diff --git a/dwave/plugins/torch/models/boltzmann_machine.py b/dwave/plugins/torch/models/boltzmann_machine.py index 8bd0350..7f9bb44 100644 --- a/dwave/plugins/torch/models/boltzmann_machine.py +++ b/dwave/plugins/torch/models/boltzmann_machine.py @@ -43,7 +43,7 @@ spread = AggregatedSamples.spread -__all__ = ["GraphRestrictedBoltzmannMachine"] +__all__ = ["GraphRestrictedBoltzmannMachine", "RestrictedBoltzmannMachine"] class GraphRestrictedBoltzmannMachine(torch.nn.Module): @@ -662,3 +662,119 @@ def estimate_beta(self, spins: torch.Tensor) -> float: bqm = BinaryQuadraticModel.from_ising(*self.to_ising(1)) beta = 1 / mple(bqm, (spins.detach().cpu().numpy(), self._nodes))[0] return beta + +class RestrictedBoltzmannMachine(torch.nn.Module): + """A Restricted Boltzmann Machine (RBM) model. + + This class defines the parameterization of a binary RBM. + Training using Persistent Contrastive Divergence (PCD) must be + performed externally using separate sampler and optimizer classes. + + Args: + n_visible (int): Number of visible units. + n_hidden (int): Number of hidden units. + """ + + def __init__( + self, + n_visible: int, + n_hidden: int, + ) -> None: + super().__init__() + + # Model hyperparameters + self._n_visible = n_visible + self._n_hidden = n_hidden + + # Initialize model parameters + # initialize weights + self._weights = torch.nn.Parameter( + 0.1 * torch.randn(n_visible, n_hidden) + ) + # initialize visible units biases. + self._visible_biases = torch.nn.Parameter( + 0.5 * torch.ones(n_visible) + ) + # initialize hidden units biases. + self._hidden_biases = torch.nn.Parameter( + 0.5 * torch.ones(n_hidden) + ) + + @property + def n_visible(self) -> int: + """Number of visible units.""" + return self._n_visible + + @property + def n_hidden(self) -> int: + """Number of hidden units.""" + return self._n_hidden + + @property + def weights(self) -> torch.Tensor: + """Weights of the RBM.""" + return self._weights + + @property + def visible_biases(self) -> torch.Tensor: + """Visible biases of the RBM.""" + return self._visible_biases + + @property + def hidden_biases(self) -> torch.Tensor: + """Hidden biases of the RBM.""" + return self._hidden_biases + + def sample_hidden(self, visible: torch.Tensor) -> torch.Tensor: + """Sample from the distribution P(h|v). + + Args: + visible (torch.Tensor): Tensor of shape (batch_size, n_visible) + representing the states of visible units. + + Returns: + torch.Tensor: Binary tensor of shape (batch_size, n_hidden) representing + sampled hidden units. + """ + hidden_probs = torch.sigmoid(self._hidden_biases + visible @ self._weights) + return torch.bernoulli(hidden_probs) + + def sample_visible(self, hidden: torch.Tensor) -> torch.Tensor: + """Sample from the distribution P(v|h). + + Args: + hidden (torch.Tensor): Tensor of shape (batch_size, n_hidden) + representing the states of hidden units. + Returns: + torch.Tensor: Binary tensor of shape (batch_size, n_visible) representing + sampled visible units. + """ + visible_probs = torch.sigmoid(self._visible_biases + hidden @ self._weights.t()) + return torch.bernoulli(visible_probs) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Computes the RBM free energy of a batch of visible units averaged over the batch. + + The free energy F(x) for a visible vector x is: + + .. math:: + F(x) = - x · visible_biases + - sum_{j=1}^{n_hidden} log(1 + exp(hidden_biases[j] + (x · weights)_j)) + + Args: + x (torch.Tensor): Tensor of shape (batch_size, n_visible) representing the visible layer. 
+ + Returns: + torch.Tensor: Scalar tensor representing the **average free energy** over the batch. + """ + + v_term = (x * self._visible_biases).sum(dim=1) + + hidden_pre_activation = x @ self._weights + self._hidden_biases + + h_term = torch.sum(torch.nn.functional.softplus(hidden_pre_activation), dim=1) + + free_energy_per_sample = -v_term - h_term + + # average over batch + return free_energy_per_sample.mean() diff --git a/dwave/plugins/torch/samplers/pcd_sampler.py b/dwave/plugins/torch/samplers/pcd_sampler.py new file mode 100644 index 0000000..364132b --- /dev/null +++ b/dwave/plugins/torch/samplers/pcd_sampler.py @@ -0,0 +1,58 @@ +import torch +from dwave.plugins.torch.models.boltzmann_machine import ( + RestrictedBoltzmannMachine as RBM, +) + +class PCDSampler: + """Persistent Contrastive Divergence (PCD) sampler for RBMs. + + This sampler maintains a persistent Markov chain of visible states + across minibatches and performs Gibbs sampling using the RBM’s + sampling functions. + + Args: + rbm (RBM): The RBM model from which the sampler draws samples. + """ + def __init__(self, rbm: RBM): + self.rbm = rbm + + # Stores the last visible states to initialize the Markov chain in Persistent Contrastive Divergence (PCD) + self.previous_visible_values = None + + def sample( + self, + batch_size: int, + gibbs_steps: int, + start_visible: torch.Tensor | None = None, + ): + """Generate a sample of visible and hidden units using gibbs sampling. + + Args: + batch_size (int): Number of samples to generate. + gibbs_steps (int): Number of Gibbs sampling steps to perform. + start_visible (torch.Tensor | None, optional): Initial visible states to + start the Gibbs chain (shape: [batch_size, n_visible]). If None, + a random Gaussian initialization is used. + + Returns: + tuple[torch.Tensor, torch.Tensor]: A tuple of (visible, hidden) from the last Gibbs step: + - visible: (batch_size, n_visible) + - hidden: (batch_size, n_hidden) + """ + if start_visible is None: + visible_values = torch.randn( + batch_size, self.rbm.n_visible, device=self.rbm.weights.device + ) + else: + visible_values = start_visible + + hidden_values = None + + for _ in range(gibbs_steps): + hidden_values = self.rbm.sample_hidden(visible_values) + visible_values = self.rbm.sample_visible(hidden_values) + + # Store samples to initialize the next Markov chain with (PCD) + self.previous_visible_values = visible_values.detach() + + return visible_values, hidden_values diff --git a/examples/rbm_image_generation.py b/examples/rbm_image_generation.py new file mode 100644 index 0000000..6e920d9 --- /dev/null +++ b/examples/rbm_image_generation.py @@ -0,0 +1,269 @@ +# Copyright 2025 D-Wave +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
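Before the MNIST example below, a minimal sketch of how the two new classes are intended to compose. The sizes and batch shape are illustrative, not taken from the patch; only the constructor, sample_hidden, sample, and forward calls introduced in the diffs above are assumed.

import torch

from dwave.plugins.torch.models.boltzmann_machine import RestrictedBoltzmannMachine
from dwave.plugins.torch.samplers.pcd_sampler import PCDSampler

# A deliberately tiny RBM; sizes are arbitrary for illustration.
rbm = RestrictedBoltzmannMachine(n_visible=6, n_hidden=4)
sampler = PCDSampler(rbm)

# Positive (data-driven) phase: clamp the visible layer to data and sample hidden units.
data = torch.randint(0, 2, (8, rbm.n_visible)).float()
hidden = rbm.sample_hidden(data)  # binary tensor of shape (8, 4)

# Negative (model-driven) phase: advance the persistent Gibbs chain.
fantasy_visible, fantasy_hidden = sampler.sample(batch_size=8, gibbs_steps=5)

# forward() returns the batch-averaged free energy (a scalar tensor);
# lower free energy means the batch is more probable under the model.
print(rbm(data))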
+
+import os
+import torch
+from torch.utils.data import DataLoader
+from dwave.plugins.torch.models.boltzmann_machine import (
+    RestrictedBoltzmannMachine as RBM,
+)
+from torchvision import transforms, datasets
+import matplotlib.pyplot as plt
+from torch.optim import SGD
+from dwave.plugins.torch.samplers.pcd_sampler import PCDSampler
+
+
+def load_binarized_mnist(dataset_path: str = "data") -> datasets.MNIST:
+    """Load the MNIST dataset and binarize it (pixels >= 0.5 become 1, else 0).
+
+    Args:
+        dataset_path (str): Path to download/store the MNIST dataset. Defaults to "data".
+
+    Returns:
+        datasets.MNIST: Binarized MNIST training dataset.
+    """
+    transform = transforms.Compose(
+        [transforms.ToTensor(), transforms.Lambda(lambda x: (x >= 0.5).float())]
+    )
+
+    train_dataset = datasets.MNIST(
+        root=dataset_path, train=True, transform=transform, download=True
+    )
+    return train_dataset
+
+
+def contrastive_divergence(
+    rbm: RBM,
+    batch: torch.Tensor,
+    n_gibbs_steps: int,
+    sampler: PCDSampler,
+    optimizer: torch.optim.Optimizer,
+) -> torch.Tensor:
+    """Perform one step of Contrastive Divergence (CD-k).
+
+    Uses Persistent Contrastive Divergence (PCD) by maintaining the last visible states
+    for Gibbs sampling across batches.
+    Gradients are applied via the provided PyTorch optimizer.
+
+    Args:
+        rbm (RBM): The RBM model being trained.
+        batch (torch.Tensor): A batch of input data of shape (batch_size, n_visible).
+        n_gibbs_steps (int): Number of Gibbs sampling steps per update.
+        sampler (PCDSampler): Sampler responsible for producing negative-phase samples.
+        optimizer (torch.optim.Optimizer): PyTorch optimizer.
+
+    Returns:
+        torch.Tensor: The reconstruction error (L1 norm) for the batch.
+    """
+    # Positive phase (data-driven)
+    hidden_probs = torch.sigmoid(batch @ rbm.weights + rbm.hidden_biases)
+
+    weight_grads = torch.matmul(batch.t(), hidden_probs)
+    visible_bias_grads = batch.clone()
+    hidden_bias_grads = hidden_probs.clone()
+
+    batch_size = batch.size(0)
+
+    # Negative phase (model-driven)
+    # Sample from the model using Gibbs sampling
+    visible_values, hidden_values = sampler.sample(
+        batch_size,
+        gibbs_steps=n_gibbs_steps,
+        start_visible=sampler.previous_visible_values,
+    )
+
+    visible_values = visible_values.detach()
+    hidden_values = hidden_values.detach()
+
+    # Compute the gradients for the negative phase
+    weight_grads -= torch.matmul(visible_values.t(), hidden_values)
+    visible_bias_grads -= visible_values
+    hidden_bias_grads -= hidden_values
+
+    # Average across the batch
+    weight_grads /= batch_size
+    visible_bias_grads = torch.mean(visible_bias_grads, dim=0)
+    hidden_bias_grads = torch.mean(hidden_bias_grads, dim=0)
+
+    # Apply gradients via optimizer
+    rbm.weights.grad = -weight_grads
+    rbm.visible_biases.grad = -visible_bias_grads
+    rbm.hidden_biases.grad = -hidden_bias_grads
+
+    optimizer.step()
+    optimizer.zero_grad()
+
+    # Compute reconstruction error (L1 norm)
+    reconstruction = rbm.sample_visible(rbm.sample_hidden(batch))
+    reconstruction = reconstruction.detach()
+    reconstruction_error = torch.sum(torch.abs(batch - reconstruction))
+
+    return reconstruction_error
+
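An aside on the hand-coded update above: because the gradient of the free energy reproduces the CD sufficient statistics, an equivalent PCD step can be written with autograd by differentiating the gap between the free energy of the data batch and that of the negative samples. The sketch below is hypothetical (the name pcd_step_autograd is not part of the patch) and relies on the imports already in this file; it matches the manual gradients up to using P(h|v) in place of the sampled hidden states in the negative phase.

def pcd_step_autograd(
    rbm: RBM,
    batch: torch.Tensor,
    sampler: PCDSampler,
    optimizer: torch.optim.Optimizer,
    n_gibbs_steps: int,
) -> None:
    # Negative-phase samples come from the persistent chain and must not
    # carry gradients of their own.
    negative_visible, _ = sampler.sample(
        batch.size(0),
        gibbs_steps=n_gibbs_steps,
        start_visible=sampler.previous_visible_values,
    )
    # rbm(x) is the batch-averaged free energy; minimizing the gap ascends
    # the stochastic estimate of the log-likelihood.
    loss = rbm(batch) - rbm(negative_visible.detach())
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()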
+
+def train_loop(
+    train_loader: DataLoader,
+    rbm: RBM,
+    n_epochs: int,
+    n_gibbs_steps: int,
+    sampler: PCDSampler,
+    optimizer: torch.optim.Optimizer,
+) -> None:
+    """Train the RBM using contrastive divergence with a given PCDSampler and optimizer.
+
+    Args:
+        train_loader (DataLoader): PyTorch DataLoader for training data.
+        rbm (RBM): Restricted Boltzmann Machine instance.
+        n_epochs (int): Number of training epochs.
+        n_gibbs_steps (int): Number of Gibbs sampling steps per CD update.
+        sampler (PCDSampler): Sampler responsible for producing negative-phase samples.
+        optimizer (torch.optim.Optimizer): PyTorch optimizer.
+    """
+    device = rbm.weights.device
+    for epoch in range(n_epochs):
+        total_error = 0
+        num_examples = 0
+        for batch, _ in train_loader:
+            # flatten input data
+            batch = batch.reshape(batch.size(0), rbm.n_visible).to(device)
+
+            # Perform one step of contrastive divergence and accumulate error
+            error = contrastive_divergence(
+                rbm, batch, n_gibbs_steps, sampler, optimizer
+            )
+            total_error += error
+            num_examples += batch.size(0)
+        average_error = total_error / num_examples  # Average reconstruction error
+        print(
+            f"Epoch {epoch + 1}/{n_epochs} - Avg reconstruction error: {average_error:.4f}"
+        )
+
+
+def generate_and_save_images(
+    sampler: PCDSampler,
+    rows: int = 8,
+    columns: int = 8,
+    steps: int = 1000,
+    output_dir: str = "samples",
+    output_filename: str = "generated_images.png",
+) -> None:
+    """Generate samples from a trained RBM and save them as a grid of images.
+
+    Args:
+        sampler (PCDSampler): Sampler used to generate samples from the trained RBM.
+        rows (int): Number of rows in the output image grid. Defaults to 8.
+        columns (int): Number of columns in the output image grid. Defaults to 8.
+        steps (int): Number of Gibbs sampling steps for generation. Defaults to 1000.
+        output_dir (str): Directory to save the generated images. Defaults to "samples".
+        output_filename (str): File name for saving the generated image grid. Defaults to "generated_images.png".
+    """
+    os.makedirs(output_dir, exist_ok=True)
+    output_path = os.path.join(output_dir, output_filename)
+
+    num_images = rows * columns
+
+    # Generate batch of images
+    samples, _ = sampler.sample(num_images, gibbs_steps=steps)
+
+    samples = samples.view(num_images, 28, 28).detach().cpu().numpy()
+
+    # Plot grid of images
+    fig, axs = plt.subplots(rows, columns, figsize=(columns, rows))
+
+    idx = 0
+    for r in range(rows):
+        for c in range(columns):
+            axs[r, c].imshow(samples[idx], cmap="gray")
+            axs[r, c].axis("off")
+            idx += 1
+
+    fig.suptitle("Generated images from RBM trained on MNIST", fontsize=18)
+    plt.tight_layout(rect=[0, 0, 1, 0.96])
+    plt.savefig(output_path, bbox_inches="tight", pad_inches=0.1)
+    plt.show()
+    print(f"Generated {num_images} samples in {output_dir}/{output_filename}")
+
+
+def train_rbm(
+    n_visible: int,
+    n_hidden: int,
+    n_gibbs_steps: int,
+    learning_rate: float,
+    momentum: float,
+    weight_decay: float,
+    n_epochs: int,
+    batch_size: int,
+    dataset_path: str = "data",
+) -> PCDSampler:
+    """Train an RBM on MNIST and return the PCD sampler used during training.
+
+    Args:
+        n_visible (int): Number of visible units.
+        n_hidden (int): Number of hidden units.
+        n_gibbs_steps (int): Number of Gibbs sampling steps per CD update.
+        learning_rate (float): Base learning rate for CD updates.
+        momentum (float): Momentum coefficient for CD updates.
+        weight_decay (float): Weight decay (L2 regularization) coefficient.
+        n_epochs (int): Number of training epochs.
+        batch_size (int): Batch size for training.
+        dataset_path (str, optional): Path to download/store the MNIST dataset. Defaults to "data".
+
+    Returns:
+        PCDSampler: The sampler used for training the RBM.
+ """ + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + print(f"Using {device} device") + + # Load MNIST data + print("Loading MNIST dataset...") + train_dataset = load_binarized_mnist(dataset_path) + + # Create data loader + train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=batch_size, shuffle=True, drop_last=True + ) + + # Initialize RBM + rbm = RBM(n_visible, n_hidden).to(device) + + # Initialize PCD Sampler + sampler = PCDSampler(rbm) + + optimizer = SGD( + [rbm.weights, rbm.visible_biases, rbm.hidden_biases], + lr=learning_rate, + momentum=momentum, + weight_decay=weight_decay, + ) + + # Train RBM + train_loop(train_loader, rbm, n_epochs, n_gibbs_steps, sampler, optimizer) + + return sampler + + +if __name__ == "__main__": + # Run an example of fitting a Restricted Boltzmann Machine to the MNIST dataset + sampler = train_rbm( + n_visible=784, + n_hidden=500, + n_gibbs_steps=10, + learning_rate=1e-3, + momentum=0.5, + weight_decay=1e-7, + n_epochs=50, + batch_size=64, + ) + # Generate and save samples + generate_and_save_images(sampler) diff --git a/releasenotes/notes/add-rbm-0b2134a1615ed5b3.yaml b/releasenotes/notes/add-rbm-0b2134a1615ed5b3.yaml new file mode 100644 index 0000000..782dd56 --- /dev/null +++ b/releasenotes/notes/add-rbm-0b2134a1615ed5b3.yaml @@ -0,0 +1,4 @@ +--- +features: + - Add ``Restricted Boltzmann Machine`` class for training RBMs + using Persistant Contrastive Divergence algorithm. diff --git a/tests/test_boltzmann_machine.py b/tests/test_boltzmann_machine.py index 069cadf..7e16b1e 100644 --- a/tests/test_boltzmann_machine.py +++ b/tests/test_boltzmann_machine.py @@ -21,6 +21,7 @@ from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine as GRBM from dwave.system.temperatures import maximum_pseudolikelihood_temperature as mple +from dwave.plugins.torch.models.boltzmann_machine import RestrictedBoltzmannMachine as RBM class TestGraphRestrictedBoltzmannMachine(unittest.TestCase): def setUp(self) -> None: @@ -407,6 +408,164 @@ def test_quasi_objective_gradient_hidden_units(self): # the sufficient statistics of the average spins. torch.testing.assert_close(grad, grad_auto) +class TestRBM(unittest.TestCase): + def setUp(self): + # Small RBM for testing + self.rbm = RBM(n_visible=4, n_hidden=3) + + # Common input data for CD tests + self.batch = torch.tensor([[1.0, 0.0, 1.0, 1.0], [0.0, 1.0, 0.0, 1.0]]) + + # Shared CD kwargs + self.cd_kwargs = dict( + epoch=0, + n_gibbs_steps=1, + learning_rate=0.1, + momentum_coefficient=0.5, + weight_decay=0.0, + n_epochs=10, + ) + + def test_sample_hidden_shape(self): + visible = torch.randint(0, 2, (5, self.rbm.n_visible)).float() + hidden = self.rbm.sample_hidden(visible) + # Ensure shape is correct + self.assertEqual(hidden.shape, (5, self.rbm.n_hidden)) + + def test_sample_hidden_binary(self): + visible = torch.randint(0, 2, (5, self.rbm.n_visible)).float() + hidden = self.rbm.sample_hidden(visible) + + # Ensure output is binary + self.assertTrue(torch.all((hidden == 0) | (hidden == 1))) + + @parameterized.expand( + [ + ("all_ones", 1000.0, 1), + ("all_zeroes", -1000.0, 0), + ] + ) + def test_sample_hidden_saturation(self, name, bias_value, expected_value): + """ + Test that sample_hidden saturates correctly when the hidden biases + are set to very large positive or negative values. 
+ + If hidden_bias[j] → +∞ + sigmoid(hidden_bias + visible @ weights) → 1 + bernoulli(1) → always 1 + + If hidden_bias[j] → -∞ + sigmoid(hidden_bias + visible @ weights) → 0 + bernoulli(0) → always 0 + """ + + # Set all hidden biases to an extreme constant + with torch.no_grad(): + self.rbm._hidden_biases.fill_(bias_value) + + # The visible input does not matter in saturation conditions + visible = torch.zeros(5, self.rbm.n_visible) + + # Sample hidden units + hidden = self.rbm.sample_hidden(visible) + + # Assert that all hidden units match the expected saturated value + self.assertTrue(torch.all(hidden == expected_value)) + + def test_sample_visible_shape(self): + hidden = torch.randint(0, 2, (5, self.rbm.n_hidden)).float() + visible = self.rbm.sample_visible(hidden) + + # Ensure shape is correct + self.assertEqual(visible.shape, (5, self.rbm.n_visible)) + + def test_sample_visible_binary(self): + hidden = torch.randint(0, 2, (5, self.rbm.n_hidden)).float() + visible = self.rbm.sample_visible(hidden) + + # Ensure output is binary + self.assertTrue(torch.all((visible == 0) | (visible == 1)).item()) + + @parameterized.expand( + [ + ("all_ones", 1000.0, 1), + ("all_zeroes", -1000.0, 0), + ] + ) + def test_sample_visible_saturation(self, name, bias_value, expected_value): + """ + Test that sample_visible saturates correctly when the visible biases + are set to very large positive or negative values. + + If visible_bias → +∞: sigmoid → 1 → bernoulli(1) → always 1 + If visible_bias → -∞: sigmoid → 0 → bernoulli(0) → always 0 + """ + + # Large positive/negative bias makes sigmoid output deterministic + with torch.no_grad(): + self.rbm._visible_biases.fill_(bias_value) + + # Hidden input doesn't matter when biases dominate + hidden = torch.zeros(5, self.rbm.n_hidden) + + visible = self.rbm.sample_visible(hidden) + + self.assertTrue(torch.all(visible == expected_value).item()) + + def test_forward_scalar_output(self): + """Forward should return a scalar tensor.""" + batch = torch.randn(5, self.rbm.n_visible) + out = self.rbm.forward(batch) + self.assertIsInstance(out, torch.Tensor) + self.assertEqual(out.ndim, 0) + + def test_forward_zero_weights_biases(self): + """ + Check free energy when all weights and biases are zero. + Analytic test: all weights & biases = 0 + Free energy becomes: + F(v) = - sum_j softplus(0) = -n_hidden * log(2) + """ + with torch.no_grad(): + self.rbm._weights[:] = 0 + self.rbm._visible_biases[:] = 0 + self.rbm._hidden_biases[:] = 0 + v = torch.tensor([[1.0, 0.0, 1.0, 1.0]]) # value doesn't matter + out = self.rbm.forward(v) + expected = -self.rbm.n_hidden * torch.log(torch.tensor(2.0)) + self.assertTrue(torch.allclose(out, expected)) + + def test_forward_ordering_bias(self): + """ + Free energy ordering test: + If visible_bias is very positive, visible=1 must yield + much lower free energy than visible=0. 
+ """ + # Create a tiny RBM for easy testing + rbm = RBM(n_visible=1, n_hidden=1) + with torch.no_grad(): + rbm._weights[:] = 0 + rbm._hidden_biases[:] = 0 + rbm._visible_biases[:] = 1000.0 + + f1 = rbm.forward(torch.tensor([[1.0]])) + f0 = rbm.forward(torch.tensor([[0.0]])) + self.assertLess(f1, f0) + + def test_forward_small_numeric_case(self): + """Check free energy against manual calculation for 1 visible and 1 hidden unit.""" + rbm = RBM(n_visible=1, n_hidden=1) + with torch.no_grad(): + rbm._weights[:] = 2.0 + rbm._visible_biases[:] = 1.0 + rbm._hidden_biases[:] = -1.0 + + v = torch.tensor([[1.0]]) + # Independent manual calculation + expected = -1.0 - torch.nn.functional.softplus(torch.tensor(1.0)) + out = rbm.forward(v) + self.assertTrue(torch.allclose(out, expected)) + if __name__ == "__main__": unittest.main()