Add RBM #47

base: main

Changes from all commits
@@ -43,7 +43,7 @@
 spread = AggregatedSamples.spread


-__all__ = ["GraphRestrictedBoltzmannMachine"]
+__all__ = ["GraphRestrictedBoltzmannMachine", "RestrictedBoltzmannMachine"]


 class GraphRestrictedBoltzmannMachine(torch.nn.Module):


@@ -662,3 +662,119 @@ def estimate_beta(self, spins: torch.Tensor) -> float:
        bqm = BinaryQuadraticModel.from_ising(*self.to_ising(1))
        beta = 1 / mple(bqm, (spins.detach().cpu().numpy(), self._nodes))[0]
        return beta

class RestrictedBoltzmannMachine(torch.nn.Module):
    """A Restricted Boltzmann Machine (RBM) model.

    This class defines the parameterization of a binary RBM.
    Training using Persistent Contrastive Divergence (PCD) must be
    performed externally using separate sampler and optimizer classes.

    Args:
        n_visible (int): Number of visible units.
        n_hidden (int): Number of hidden units.
    """

    def __init__(
        self,
        n_visible: int,
        n_hidden: int,
    ) -> None:
        super().__init__()

        # Model hyperparameters
        self._n_visible = n_visible
        self._n_hidden = n_hidden

        # Initialize model parameters
        # initialize weights
        self._weights = torch.nn.Parameter(
            0.1 * torch.randn(n_visible, n_hidden)
        )
        # initialize visible unit biases
        self._visible_biases = torch.nn.Parameter(
            0.5 * torch.ones(n_visible)
        )
        # initialize hidden unit biases
        self._hidden_biases = torch.nn.Parameter(
            0.5 * torch.ones(n_hidden)
        )
Comment on lines +689 to +701

Contributor: Where do the 0.1, 0.5 and 0.5 values come from? Should they be hard-coded?

Collaborator: RE "Where do the 0.1, 0.5 and 0.5 values come from? Should they be hard-coded?": I set those values arbitrarily for GRBM. There should be a better initialization scheme for RBMs, e.g., section 8.

Collaborator (Author): These values worked for the image generation example. Sure, I can experiment with 0.01 as suggested in the guide and will let you know how it affects the performance.

Collaborator: I've found the initialisation of the GRBM to be not great for my experiments, so I've had to pass initial linear and quadratic weights. @kevinchern, in your experience, have you had to do the same? If so, should we change the default initialisation?

Collaborator: @VolodyaCO good point. I've had similar experiences and found setting initial weights to 0 to be robust in general. Could you create an issue for this? Edit: actually I'll do it now. Edit 2: here's the issue, please add more details as you see fit: #48
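For reference, here is a minimal sketch of the kind of initialization scheme discussed in this thread (assuming the common "practical guide" heuristics: weights drawn from N(0, 0.01), hidden biases at zero, visible biases set from the empirical unit means of the training data). The helper name `init_rbm_parameters` and its arguments are hypothetical and not part of this PR:

```python
import torch

def init_rbm_parameters(rbm, train_data: torch.Tensor) -> None:
    """Illustrative re-initialization of an RBM's parameters (hypothetical helper).

    Args:
        rbm: A RestrictedBoltzmannMachine instance.
        train_data: Binary training data of shape (n_samples, n_visible).
    """
    with torch.no_grad():
        # Small zero-mean Gaussian weights instead of 0.1 * randn.
        rbm.weights.normal_(0.0, 0.01)
        # Hidden biases start at zero.
        rbm.hidden_biases.zero_()
        # Visible biases set to log(p / (1 - p)), where p is the empirical
        # probability that each visible unit is on in the training data.
        p = train_data.float().mean(dim=0).clamp(1e-3, 1 - 1e-3)
        rbm.visible_biases.copy_(torch.log(p / (1 - p)))
```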

    @property
    def n_visible(self) -> int:
        """Number of visible units."""
        return self._n_visible

    @property
    def n_hidden(self) -> int:
        """Number of hidden units."""
        return self._n_hidden

    @property
    def weights(self) -> torch.Tensor:
        """Weights of the RBM."""
        return self._weights

    @property
    def visible_biases(self) -> torch.Tensor:
        """Visible biases of the RBM."""
        return self._visible_biases

    @property
    def hidden_biases(self) -> torch.Tensor:
        """Hidden biases of the RBM."""
        return self._hidden_biases

    def sample_hidden(self, visible: torch.Tensor) -> torch.Tensor:
        """Sample from the distribution P(h|v).

        Args:
            visible (torch.Tensor): Tensor of shape (batch_size, n_visible)
                representing the states of visible units.

        Returns:
            torch.Tensor: Binary tensor of shape (batch_size, n_hidden) representing
                sampled hidden units.
        """
        hidden_probs = torch.sigmoid(self._hidden_biases + visible @ self._weights)
        return torch.bernoulli(hidden_probs)

    def sample_visible(self, hidden: torch.Tensor) -> torch.Tensor:
        """Sample from the distribution P(v|h).

        Args:
            hidden (torch.Tensor): Tensor of shape (batch_size, n_hidden)
                representing the states of hidden units.

        Returns:
            torch.Tensor: Binary tensor of shape (batch_size, n_visible) representing
                sampled visible units.
        """
        visible_probs = torch.sigmoid(self._visible_biases + hidden @ self._weights.t())
        return torch.bernoulli(visible_probs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Computes the RBM free energy of a batch of visible units, averaged over the batch.

        The free energy F(x) for a visible vector x is:

        .. math::
            F(x) = - x · visible_biases
                   - sum_{j=1}^{n_hidden} log(1 + exp(hidden_biases[j] + (x · weights)_j))

        Args:
            x (torch.Tensor): Tensor of shape (batch_size, n_visible) representing the visible layer.

        Returns:
            torch.Tensor: Scalar tensor representing the average free energy over the batch.
        """
        # Visible-bias term: x · visible_biases for each sample.
        v_term = (x * self._visible_biases).sum(dim=1)

        # Pre-activation of the hidden units.
        hidden_pre_activation = x @ self._weights + self._hidden_biases

        # Softplus term summed over the hidden units.
        h_term = torch.sum(torch.nn.functional.softplus(hidden_pre_activation), dim=1)

        free_energy_per_sample = -v_term - h_term

        # average over batch
        return free_energy_per_sample.mean()
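As a quick illustration of the class above (assumed usage, not part of the diff), instantiation, one Gibbs round trip, and a free-energy evaluation might look like this; the sizes are arbitrary example values:

```python
import torch

# Example sizes: 784 visible units (e.g. 28x28 binary images), 64 hidden units.
rbm = RestrictedBoltzmannMachine(n_visible=784, n_hidden=64)

visible = torch.bernoulli(0.5 * torch.ones(32, 784))  # random binary batch of 32 samples
hidden = rbm.sample_hidden(visible)                   # (32, 64), drawn from P(h|v)
reconstruction = rbm.sample_visible(hidden)           # (32, 784), drawn from P(v|h)

free_energy = rbm(visible)  # scalar tensor: free energy averaged over the batch
```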

@@ -0,0 +1,58 @@ (new file)
import torch

from dwave.plugins.torch.models.boltzmann_machine import (
    RestrictedBoltzmannMachine as RBM,
)


class PCDSampler:
Collaborator: Rename to ...
| """Persistent Contrastive Divergence (PCD) sampler for RBMs. | ||
|
Collaborator: Update docstring to reflect the name change.
    This sampler maintains a persistent Markov chain of visible states
    across minibatches and performs Gibbs sampling using the RBM's
    sampling functions.

    Args:
        rbm (RBM): The RBM model from which the sampler draws samples.
    """

    def __init__(self, rbm: RBM):
        self.rbm = rbm

        # Stores the last visible states to initialize the Markov chain in
        # Persistent Contrastive Divergence (PCD)
        self.previous_visible_values = None
Collaborator: Initialize with random +/-1 values to avoid undefined conditional sampling.

Collaborator (Author): This would never happen because it will be initialized with ...

Collaborator: Why is it not initialized here, and why is it with real-valued initial values instead of +/-1 values?

Collaborator (Author): That could have been possible if I had the ... This initialization works for RBM and is recommended by Hinton, I think. Using other initialization would really deteriorate the quality of the generated images.

Collaborator: I see. See this comment: we want to pack the sampling parameters into the sampler. This is aligned with the goal of having drop-in replacements for different samplers and models, and #58.
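To make the suggestion above concrete, here is a minimal sketch of a sampler with the sampling parameters packed into the constructor and the persistent chain initialized up front. The class name, the Bernoulli(0.5) initialization, and the argument-free sample() signature are assumptions for illustration, not the design settled in #58:

```python
import torch


class PersistentGibbsSampler:
    """Hypothetical PCD sampler variant with sampling parameters set at construction."""

    def __init__(self, rbm, batch_size: int, gibbs_steps: int):
        self.rbm = rbm
        self.gibbs_steps = gibbs_steps
        # Initialize the persistent chain with random binary states so the first
        # call to sample() never starts from an undefined chain.
        self.previous_visible_values = torch.bernoulli(
            0.5 * torch.ones(batch_size, rbm.n_visible, device=rbm.weights.device)
        )

    def sample(self) -> tuple[torch.Tensor, torch.Tensor]:
        visible = self.previous_visible_values
        hidden = None
        for _ in range(self.gibbs_steps):
            hidden = self.rbm.sample_hidden(visible)
            visible = self.rbm.sample_visible(hidden)
        # Persist the chain state for the next minibatch (PCD).
        self.previous_visible_values = visible.detach()
        return visible, hidden
```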

    def sample(
Collaborator: Pack the sampling parameters into the constructor and use ...
        self,
        batch_size: int,
        gibbs_steps: int,
        start_visible: torch.Tensor | None = None,
    ):
        """Generate a sample of visible and hidden units using Gibbs sampling.

        Args:
            batch_size (int): Number of samples to generate.
            gibbs_steps (int): Number of Gibbs sampling steps to perform.
            start_visible (torch.Tensor | None, optional): Initial visible states to
                start the Gibbs chain (shape: [batch_size, n_visible]). If None,
                a random Gaussian initialization is used.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: A tuple of (visible, hidden) from the last Gibbs step:
                - visible: (batch_size, n_visible)
                - hidden: (batch_size, n_hidden)
        """
        if start_visible is None:
            visible_values = torch.randn(
                batch_size, self.rbm.n_visible, device=self.rbm.weights.device
            )
        else:
            visible_values = start_visible

        hidden_values = None

        for _ in range(gibbs_steps):
            hidden_values = self.rbm.sample_hidden(visible_values)
            visible_values = self.rbm.sample_visible(hidden_values)

        # Store samples to initialize the next Markov chain with (PCD)
        self.previous_visible_values = visible_values.detach()

        return visible_values, hidden_values
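To show how the model and sampler are meant to be combined, here is a minimal, illustrative PCD training loop (assumed usage, not part of this PR): the loss is the difference of batch-averaged free energies between the data and the persistent-chain samples, whose gradient gives the PCD parameter update. The `data_loader` and the hyperparameter values are hypothetical:

```python
import torch

rbm = RBM(n_visible=784, n_hidden=128)      # example sizes
sampler = PCDSampler(rbm)
optimizer = torch.optim.SGD(rbm.parameters(), lr=0.01)

for batch in data_loader:  # hypothetical loader yielding (batch_size, 784) binary tensors
    optimizer.zero_grad()

    # Advance the persistent chain, restarting from the previous minibatch's samples.
    model_visible, _ = sampler.sample(
        batch_size=batch.shape[0],
        gibbs_steps=10,
        start_visible=sampler.previous_visible_values,
    )

    # PCD loss: free energy of the data minus free energy of the model samples;
    # its gradient approximates the negative log-likelihood gradient.
    loss = rbm(batch) - rbm(model_visible)
    loss.backward()
    optimizer.step()
```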
Comment: Make RBM extend GraphRestrictedBoltzmannMachine? e.g., ...

Reply: As we discussed in one of our ML tools meetings, we decided to keep these two separate to make the RBM as efficient as possible. There is no need to materialize a graph for the RBM. I would be happy to discuss this in a meeting.

Reply: Yes, that was the case early on in the development. Per yesterday's discussion, one requirement of this implementation is to have it be a drop-in replacement for other GRBM models.

Reply: I don't remember agreeing on this :) Let's discuss this in a meeting.