-
Notifications
You must be signed in to change notification settings - Fork 11
Add unit tests for pseudo_kl_divergence_loss #59
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
abdela47
wants to merge
4
commits into
dwavesystems:main
Choose a base branch
from
abdela47:test/pseudo-kl-unit-tests
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
c24bbdd
Add unit tests for pseudo_kl_divergence_loss
abdela47 64841de
Refine pseudo_kl_divergence_loss tests
abdela47 54f2fdc
Copyright & Apache License Reviewer Suggestion
abdela47 7e361bf
updated the 2D test to derive batch_size and n_spins directly from sp…
abdela47 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,151 @@ | ||||||||||||
| # Copyright 2025 D-Wave | ||||||||||||
| # | ||||||||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||||||||
| # you may not use this file except in compliance with the License. | ||||||||||||
| # You may obtain a copy of the License at | ||||||||||||
| # | ||||||||||||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||||||||||||
| # | ||||||||||||
| # Unless required by applicable law or agreed to in writing, software | ||||||||||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||||||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||||||||
| # See the License for the specific language governing permissions and | ||||||||||||
| # limitations under the License. | ||||||||||||
| """ | ||||||||||||
| Unit tests for pseudo_kl_divergence_loss. | ||||||||||||
|
|
||||||||||||
| These tests verify the *statistical structure* of the pseudo-KL divergence used | ||||||||||||
| in the DVAE setting, not the correctness of the Boltzmann machine itself. | ||||||||||||
|
|
||||||||||||
| In particular, we test that: | ||||||||||||
| 1) The loss matches the reference decomposition: | ||||||||||||
| pseudo_KL = cross_entropy_with_prior - entropy_of_encoder | ||||||||||||
| 2) The function supports both documented spin shapes. | ||||||||||||
| 3) The gradient w.r.t. encoder logits behaves as expected. | ||||||||||||
|
|
||||||||||||
| The tests intentionally use deterministic dummy Boltzmann machines to isolate | ||||||||||||
| and validate the behavior of pseudo_kl_divergence_loss in isolation. | ||||||||||||
| """ | ||||||||||||
|
|
||||||||||||
| import torch | ||||||||||||
| import torch.nn.functional as F | ||||||||||||
|
|
||||||||||||
| from dwave.plugins.torch.models.losses.kl_divergence import pseudo_kl_divergence_loss | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| class DummyBoltzmannMachine: | ||||||||||||
| """A minimal and deterministic stand-in for GraphRestrictedBoltzmannMachine. | ||||||||||||
|
|
||||||||||||
| The purpose of this class is NOT to model a real Boltzmann machine. | ||||||||||||
| Instead, it provides a simple, deterministic quasi_objective so that | ||||||||||||
| we can verify how pseudo_kl_divergence_loss combines its terms. | ||||||||||||
| """ | ||||||||||||
|
|
||||||||||||
| def quasi_objective(self, spins_data: torch.Tensor, spins_model: torch.Tensor) -> torch.Tensor: | ||||||||||||
| """Return a deterministic scalar representing a positive-minus-negative phase | ||||||||||||
| objective, independent of encoder logits. | ||||||||||||
| """ | ||||||||||||
| return spins_data.float().mean() - spins_model.float().mean() | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def test_pseudo_kl_matches_reference_2d_spins(): | ||||||||||||
| """Match explicit cross-entropy minus entropy reference for 2D spins.""" | ||||||||||||
|
|
||||||||||||
| bm = DummyBoltzmannMachine() | ||||||||||||
|
|
||||||||||||
| # spins_data: (batch_size, n_spins) | ||||||||||||
| spins_data = torch.tensor( | ||||||||||||
| [[-1, 1,-1, 1,-1, 1], | ||||||||||||
| [ 1,-1, 1,-1, 1,-1], | ||||||||||||
| [-1,-1, 1, 1,-1, 1], | ||||||||||||
| [ 1, 1,-1,-1, 1,-1]], | ||||||||||||
| dtype=torch.float32 | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
| batch_size, n_spins = spins_data.shape | ||||||||||||
| logits = torch.linspace(-2.0, 2.0, steps=batch_size * n_spins).reshape(batch_size, n_spins) | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| spins_model = torch.ones(batch_size, n_spins, dtype=torch.float32) | ||||||||||||
|
Comment on lines
+67
to
+69
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||
|
|
||||||||||||
| out = pseudo_kl_divergence_loss( | ||||||||||||
| spins=spins_data, | ||||||||||||
| logits=logits, | ||||||||||||
| samples=spins_model, | ||||||||||||
| boltzmann_machine=bm | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
| probs = torch.sigmoid(logits) | ||||||||||||
| entropy = F.binary_cross_entropy_with_logits(logits, probs) | ||||||||||||
| cross_entropy = bm.quasi_objective(spins_data, spins_model) | ||||||||||||
| ref = cross_entropy - entropy | ||||||||||||
|
|
||||||||||||
| torch.testing.assert_close(out, ref) | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def test_pseudo_kl_supports_3d_spin_shape(): | ||||||||||||
| """Support 3D spins of shape (batch_size, n_samples, n_spins) as documented.""" | ||||||||||||
| bm = DummyBoltzmannMachine() | ||||||||||||
|
|
||||||||||||
| batch_size, n_samples, n_spins = 3, 5, 4 | ||||||||||||
| logits = torch.zeros(batch_size, n_spins) | ||||||||||||
|
|
||||||||||||
| # spins: (batch_size, n_samples, n_spins) | ||||||||||||
| spins_data = torch.ones(batch_size, n_samples, n_spins) | ||||||||||||
| spins_model = torch.zeros(batch_size, n_spins) | ||||||||||||
|
|
||||||||||||
| out = pseudo_kl_divergence_loss( | ||||||||||||
| spins=spins_data, | ||||||||||||
| logits=logits, | ||||||||||||
| samples=spins_model, | ||||||||||||
| boltzmann_machine=bm | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
| probs = torch.sigmoid(logits) | ||||||||||||
| entropy = F.binary_cross_entropy_with_logits(logits, probs) | ||||||||||||
| cross_entropy = bm.quasi_objective(spins_data, spins_model) | ||||||||||||
|
|
||||||||||||
| torch.testing.assert_close(out, cross_entropy - entropy) | ||||||||||||
|
|
||||||||||||
|
|
||||||||||||
| def test_pseudo_kl_gradient_matches_negative_entropy_when_cross_entropy_constant(): | ||||||||||||
| """Verify gradient behavior of pseudo_kl_divergence_loss. | ||||||||||||
|
|
||||||||||||
| If the Boltzmann machine quasi_objective returns a constant value, | ||||||||||||
| then the loss gradient w.r.t. logits must come entirely from the | ||||||||||||
| negative entropy term. | ||||||||||||
|
|
||||||||||||
| This test ensures that pseudo_kl_divergence_loss applies the correct | ||||||||||||
| statistical pressure on encoder logits. | ||||||||||||
| """ | ||||||||||||
|
|
||||||||||||
| class ConstantObjectiveBM: | ||||||||||||
| def quasi_objective(self, spins_data: torch.Tensor, | ||||||||||||
| spins_model: torch.Tensor) -> torch.Tensor: | ||||||||||||
| # Constant => contributes no gradient wrt logits | ||||||||||||
| return torch.tensor(1.2345, dtype=spins_data.dtype, device=spins_data.device) | ||||||||||||
|
|
||||||||||||
| bm = ConstantObjectiveBM() | ||||||||||||
|
|
||||||||||||
| batch_size, n_spins = 2, 3 | ||||||||||||
|
|
||||||||||||
| logits = torch.randn(batch_size, n_spins, requires_grad=True) | ||||||||||||
| spins_data = torch.ones(batch_size, n_spins) | ||||||||||||
| spins_model = torch.zeros(batch_size, n_spins) | ||||||||||||
|
|
||||||||||||
| out = pseudo_kl_divergence_loss( | ||||||||||||
| spins=spins_data, | ||||||||||||
| logits=logits, | ||||||||||||
| samples=spins_model, | ||||||||||||
| boltzmann_machine=bm | ||||||||||||
| ) | ||||||||||||
|
|
||||||||||||
| out.backward() | ||||||||||||
|
|
||||||||||||
| # reference gradient from -entropy only | ||||||||||||
| logits2 = logits.detach().clone().requires_grad_(True) | ||||||||||||
| probs2 = torch.sigmoid(logits2) | ||||||||||||
| entropy2 = F.binary_cross_entropy_with_logits(logits2, probs2) | ||||||||||||
| (-entropy2).backward() | ||||||||||||
|
|
||||||||||||
| torch.testing.assert_close(logits.grad, logits2.grad) | ||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don't forget to add an if __name__ == "__main__":
unittest.main()at the end of this file to make sure the tests are run correctly. |
||||||||||||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.