Add RBM #47

Changes from 1 commit
`.gitignore`:

```diff
@@ -66,4 +66,10 @@ venv.bak/
 dmypy.json
+# aim
+*.aim*
+# Datasets
+data/*
+# Generated images
+samples/*
```
```diff
@@ -43,7 +43,7 @@
 spread = AggregatedSamples.spread


-__all__ = ["GraphRestrictedBoltzmannMachine"]
+__all__ = ["GraphRestrictedBoltzmannMachine", "RestrictedBoltzmannMachine"]


 class GraphRestrictedBoltzmannMachine(torch.nn.Module):
@@ -662,3 +662,277 @@ def estimate_beta(self, spins: torch.Tensor) -> float:
         bqm = BinaryQuadraticModel.from_ising(*self.to_ising(1))
         beta = 1 / mple(bqm, (spins.detach().cpu().numpy(), self._nodes))[0]
         return beta
+
+
+class RestrictedBoltzmannMachine(torch.nn.Module):
```
**Collaborator:** Make RBM extend `GraphRestrictedBoltzmannMachine`.

**Author:** As we discussed in one of our ML tools meetings, we decided to keep these two separate to make the RBM as efficient as possible. There is no need to materialize a graph for the RBM. I would be happy to discuss this in a meeting.

**Collaborator:** Yes, that was the case early on in development. Per yesterday's discussion, one requirement of this implementation is that it be a drop-in replacement for other GRBM models.

**Author:** I don't remember agreeing on this :) Let's discuss this in a meeting.
```python
    """A Restricted Boltzmann Machine (RBM) model.

    This class defines the parameterization and inference of a binary RBM.

    Training is performed using Persistent Contrastive Divergence (PCD).

    Args:
        n_visible (int): Number of visible units.
        n_hidden (int): Number of hidden units.
    """

    def __init__(
        self,
        n_visible: int,
        n_hidden: int,
    ) -> None:
        super().__init__()

        # Model hyperparameters
        self._n_visible = n_visible
        self._n_hidden = n_hidden

        # Initialize model parameters
        # initialize weights
        self._weights = torch.nn.Parameter(
            0.1 * torch.randn(n_visible, n_hidden)
        )
        # initialize visible unit biases
        self._visible_biases = torch.nn.Parameter(
            0.5 * torch.ones(n_visible)
        )
        # initialize hidden unit biases
        self._hidden_biases = torch.nn.Parameter(
            0.5 * torch.ones(n_hidden)
        )
```
Comment on lines +689 to +701:

**Contributor:** Where do the 0.1, 0.5 and 0.5 values come from? Should they be hard-coded?

**Collaborator:** RE "Where do the 0.1, 0.5 and 0.5 values come from? Should they be hard-coded?": I set those values arbitrarily for GRBM. There should be a better initialization scheme for RBMs, e.g., section 8 of the guide.

**Author:** These values worked for the image generation example. Sure, I can experiment with 0.01 as suggested in the guide and will let you know how it affects the performance.

**Collaborator (VolodyaCO):** I've found the initialisation of the GRBM to be not great for my experiments, so I've had to pass initial linear and quadratic weights. @kevinchern, in your experience, have you had to do the same? If so, should we change the default initialisation?

**Collaborator (kevinchern):** @VolodyaCO good point. I've had similar experiences and found setting initial weights to 0 to be robust in general. Could you create an issue for this? Edit: actually I'll do it now. Edit 2: here's the issue; please add more details as you see fit: #48
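As a concrete reference for this thread, here is a minimal sketch of the smaller-scale initialisation discussed above (0.01-scale weights and zero biases, as in the cited guide). The function name and the optional data-dependent visible-bias term are illustrative, not part of this PR:

```python
from typing import Optional

import torch


def init_rbm_parameters(
    n_visible: int,
    n_hidden: int,
    data_mean: Optional[torch.Tensor] = None,
):
    """Illustrative initialisation: N(0, 0.01^2) weights and zero biases."""
    weights = torch.nn.Parameter(0.01 * torch.randn(n_visible, n_hidden))
    hidden_biases = torch.nn.Parameter(torch.zeros(n_hidden))
    if data_mean is None:
        visible_biases = torch.nn.Parameter(torch.zeros(n_visible))
    else:
        # Optional data-dependent start: log(p / (1 - p)) for a unit that is
        # on with probability p in the training data, clamped for stability.
        p = data_mean.clamp(1e-3, 1 - 1e-3)
        visible_biases = torch.nn.Parameter(torch.log(p / (1 - p)))
    return weights, visible_biases, hidden_biases
```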
```python
        # Stores the last visible states to initialize the Markov chain in
        # Persistent Contrastive Divergence (PCD)
        self.register_buffer("_previous_visible_values", None)

        # Initialize momenta tensors for momentum-based updates (all start at 0)
        self.register_buffer("_weight_momenta", torch.zeros(n_visible, n_hidden))
        self.register_buffer("_visible_bias_momenta", torch.zeros(n_visible))
        self.register_buffer("_hidden_bias_momenta", torch.zeros(n_hidden))

    @property
    def n_visible(self) -> int:
        """Number of visible units."""
        return self._n_visible

    @property
    def n_hidden(self) -> int:
        """Number of hidden units."""
        return self._n_hidden

    @property
    def weights(self) -> torch.Tensor:
        """Weights of the RBM."""
        return self._weights

    @property
    def visible_biases(self) -> torch.Tensor:
        """Visible biases of the RBM."""
        return self._visible_biases

    @property
    def hidden_biases(self) -> torch.Tensor:
        """Hidden biases of the RBM."""
        return self._hidden_biases

    @property
    def previous_visible_values(self) -> torch.Tensor:
        """Previous visible values used in Persistent Contrastive Divergence (PCD)."""
        return self._previous_visible_values

    @property
    def weight_momenta(self) -> torch.Tensor:
        """Weight momenta of the RBM."""
        return self._weight_momenta

    @property
    def visible_bias_momenta(self) -> torch.Tensor:
        """Visible bias momenta of the RBM."""
        return self._visible_bias_momenta

    @property
    def hidden_bias_momenta(self) -> torch.Tensor:
        """Hidden bias momenta of the RBM."""
        return self._hidden_bias_momenta
```
```python
    def _sample_hidden(self, visible: torch.Tensor) -> torch.Tensor:
        """Sample from the distribution P(h|v).

        Args:
            visible (torch.Tensor): Tensor of shape (batch_size, n_visible)
                representing the states of visible units.

        Returns:
            torch.Tensor: Binary tensor of shape (batch_size, n_hidden) representing
                sampled hidden units.
        """
```
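The method body is not visible in this rendering. As a rough sketch, the conditional the docstring describes is typically implemented as a sigmoid activation followed by a Bernoulli draw; this is an assumption, not necessarily the PR's exact code:

```python
        # Sketch only (assumed implementation): activation probabilities for
        # the hidden units given the visible states, then a Bernoulli sample.
        probabilities = torch.sigmoid(visible @ self._weights + self._hidden_biases)
        return torch.bernoulli(probabilities)
```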
**Collaborator** (outdated docstring suggestion):

```python
            hidden (torch.Tensor): Tensor of shape (batch_size, n_hidden)
                representing the states of hidden units.
```
**Collaborator** (outdated): Should this be named just `sample` instead, to conform with the GRBM class?
**Collaborator** (outdated): If `start_visible != None`, then `batch_size` isn't required, right? You could make that optional as well, unless there's a reasonable default value to use (e.g., `batch_size=1`). Similarly, would it make sense to have `gibbs_steps` default to 1? I noticed that a test was using

```python
hidden = RBM._sample_hidden()
```

which could in that case be written as

```python
_, hidden = RBM.generate_sample()
```

**Author:** Regarding `batch_size`, you're right. I'll make it optional. Regarding `gibbs_steps`, I'd prefer not to set a default value. I want users to make an explicit choice rather than unknowingly relying on a default of 1 (as oftentimes we need more steps for our experiments). That test example just shows that using 1 step with `generate_sample` is like generating with one `_sample_hidden` call.
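To illustrate the equivalence mentioned in that reply, a hypothetical usage sketch; the `generate_sample` signature and return value are assumed from this discussion and may differ from the final API:

```python
# Assumed signature: generate_sample(batch_size, gibbs_steps, start_visible=None),
# returning (visible, hidden). Sizes here are arbitrary.
rbm = RestrictedBoltzmannMachine(n_visible=784, n_hidden=128)

# One Gibbs step from a fresh chain...
visible, hidden = rbm.generate_sample(batch_size=32, gibbs_steps=1)

# ...samples its hidden layer the same way a single conditional call does:
hidden_direct = rbm._sample_hidden(visible)
```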
**Collaborator:** This should be a public method. Also, from the docstrings, it was difficult for me to infer how to use this while training, because of `epoch` and `n_epochs` (it isn't clear how this information is used: to compute a decayed learning rate). I think it would be a good addition to have an example in the docstring, something like

```python
for epoch in range(n_epochs):
    for batch in dataloader:
        rbm.contrastive_divergence(batch, epoch, n_gibbs_steps, learning_rate,
                                   momentum_coefficient, weight_decay, n_epochs)
```

**Author:** You're right. I will make it public. And I guess it's good to add this to the docstring. Or what about referring to the example in `rbm_image_generation.py`, where it's used in a real example?

**Collaborator:** Maybe we can have both references: a quick use-it-like-this, and a reference to the image generation example too.
**Collaborator** (outdated suggestion, docstring style):

```diff
-        """
-        Perform one step of Contrastive Divergence (CD-k) with momentum and weight decay.
+        """Perform one step of Contrastive Divergence (CD-k) with momentum and weight decay.

         Uses Persistent Contrastive Divergence (PCD) by maintaining the last visible states
         for Gibbs sampling across batches.
```
**Collaborator** (outdated): Shouldn't all calculations in this method be wrapped in a no-grad context?

**Author:** Well, the only part that needs to be in `torch.no_grad` is the parameter-update part. The rest can also be in it.

**Collaborator:** Doesn't sampling visible from hidden and hidden from visible also trigger gradient tracking?
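A minimal sketch of the pattern under discussion. The sampling helpers, buffer names, and hyperparameters are taken or assumed from the diff above (in particular, a `_sample_visible` counterpart of `_sample_hidden` is assumed); this is not the PR's implementation:

```python
import torch


def pcd_update(rbm, batch, learning_rate=0.05, momentum_coefficient=0.5,
               weight_decay=1e-4):
    """Sketch of a CD/PCD weight update wrapped in a no-grad context."""
    # No autograd graph is needed for CD-style updates, so running the whole
    # step under no_grad avoids tracking gradients through the Gibbs sampling.
    with torch.no_grad():
        # Positive phase: statistics from the data batch.
        positive_hidden = rbm._sample_hidden(batch)
        positive_associations = batch.T @ positive_hidden

        # Negative phase: one Gibbs step (visible from hidden, then hidden
        # from visible).
        negative_visible = rbm._sample_visible(positive_hidden)
        negative_hidden = rbm._sample_hidden(negative_visible)
        negative_associations = negative_visible.T @ negative_hidden

        # Momentum-smoothed gradient estimate, applied in place with weight decay.
        gradient = (positive_associations - negative_associations) / batch.shape[0]
        rbm._weight_momenta.mul_(momentum_coefficient).add_(gradient)
        rbm._weights += learning_rate * (rbm._weight_momenta
                                         - weight_decay * rbm._weights)
```

Bias updates would follow the same pattern with the `_visible_bias_momenta` and `_hidden_bias_momenta` buffers.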
**Collaborator** (outdated): Nitpick, but just `error` makes it sound like an actual error is being returned, not data.

```diff
-        error = torch.sum(torch.abs(batch - reconstruction))
-        return error
+        reconstruction_error = torch.sum(torch.abs(batch - reconstruction))
+        return reconstruction_error
```
**Collaborator** (outdated): To keep the same as the GRBM:

```diff
-    def forward(self, visible: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
```
**Collaborator** (outdated suggestion, docstring style):

```diff
-        """
-        Computes the RBM free energy of a batch of visible units averaged over the batch.
+        """Computes the RBM free energy of a batch of visible units averaged over the batch.
```
**Collaborator** (outdated suggestion): wrap the formula in a math directive:

```diff
-    F(visible) = - visible · visible_biases
-                 - sum_{j=1}^{n_hidden} log(1 + exp(hidden_biases[j] + (visible · weights)_j))
+    .. math::
+
+        F(visible) = - visible · visible_biases
+                     - sum_{j=1}^{n_hidden} log(1 + exp(hidden_biases[j] + (visible · weights)_j))
```
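For clarity, the formula in that suggestion is the standard binary-RBM free energy. A self-contained sketch using the parameter names from the diff; the PR computes this as a method, so the standalone function here is purely illustrative:

```python
import torch
import torch.nn.functional as F


def free_energy(
    visible: torch.Tensor,
    weights: torch.Tensor,
    visible_biases: torch.Tensor,
    hidden_biases: torch.Tensor,
) -> torch.Tensor:
    """Illustrative batch-averaged free energy of a binary RBM."""
    # -visible · visible_biases
    visible_term = -(visible @ visible_biases)
    # -sum_j log(1 + exp(hidden_biases[j] + (visible @ weights)_j)),
    # computed stably via softplus(x) = log(1 + exp(x))
    hidden_term = -F.softplus(visible @ weights + hidden_biases).sum(dim=-1)
    return (visible_term + hidden_term).mean()  # average over the batch
```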
**Collaborator:** I'm not sure about adding such generic folders to the gitignore. Are these only created when running the examples? In that case I'd either leave it up to the developer not to commit these, or put them, e.g., in `examples/_data/` and `examples/_samples/`.

**Author:** Yes, this directory is only created when running the examples. I wasn't planning to add it to the .gitignore either, but I included it to get your input during the review. I'll remove it then. Thanks for the feedback!