From c24bbdd1c933bd33deb4ac4d3e766bc905dac673 Mon Sep 17 00:00:00 2001
From: abdela47
Date: Sat, 20 Dec 2025 11:41:11 +0400
Subject: [PATCH 1/4] Add unit tests for pseudo_kl_divergence_loss

---
 tests/test_pseudo_kl_divergence.py | 138 +++++++++++++++++++++++++++++
 1 file changed, 138 insertions(+)
 create mode 100644 tests/test_pseudo_kl_divergence.py

diff --git a/tests/test_pseudo_kl_divergence.py b/tests/test_pseudo_kl_divergence.py
new file mode 100644
index 0000000..79f17b4
--- /dev/null
+++ b/tests/test_pseudo_kl_divergence.py
@@ -0,0 +1,138 @@
+"""
+Unit tests for pseudo_kl_divergence_loss.
+
+These tests verify the *statistical structure* of the pseudo-KL divergence used
+in the DVAE setting, not the correctness of the Boltzmann machine itself.
+
+In particular, we test that:
+1) The loss matches the reference decomposition:
+       pseudo_KL = cross_entropy_with_prior - entropy_of_encoder
+2) The function supports both documented spin shapes.
+3) The gradient w.r.t. encoder logits behaves as expected.
+
+The tests intentionally use deterministic dummy Boltzmann machines so that the
+behavior of pseudo_kl_divergence_loss can be validated in isolation.
+"""
+
+import torch
+import torch.nn.functional as F
+
+from dwave.plugins.torch.models.losses.kl_divergence import pseudo_kl_divergence_loss
+
+class DummyBoltzmannMachine:
+    """
+    Minimal deterministic stand-in for GraphRestrictedBoltzmannMachine.
+
+    The purpose of this class is NOT to model a real Boltzmann machine.
+    Instead, it provides a simple, deterministic quasi_objective so that
+    we can verify how pseudo_kl_divergence_loss combines its terms.
+    """
+
+    def quasi_objective(self, spins: torch.Tensor, samples: torch.Tensor) -> torch.Tensor:
+        """
+        Return a deterministic scalar depending on spins and samples.
+
+        Using a simple mean ensures:
+        - deterministic behavior
+        - no dependency on logits
+        - gradients w.r.t. logits come only from the entropy term
+        """
+        return spins.float().mean() + samples.float().mean()
+
+
+def test_pseudo_kl_matches_reference_2d_spins():
+    """
+    Verify that pseudo_kl_divergence_loss matches the reference formula
+    for 2D spins of shape (batch_size, n_spins).
+
+    This test directly reconstructs the loss as:
+        cross_entropy - entropy
+    and checks numerical equality.
+    """
+
+    bm = DummyBoltzmannMachine()
+
+    batch, n_spins = 4, 6
+    logits = torch.linspace(-2.0, 2.0, steps=batch * n_spins).reshape(batch, n_spins)
+
+    # spins: (batch_size, n_spins)
+    spins = torch.tensor(
+        [[-1,1,-1, 1,-1, 1],
+        [1,-1, 1,-1, 1,-1],
+        [-1,-1,1,1,-1,1],
+        [1,1,-1,-1,1,-1]],
+        dtype=torch.float32
+    )
+
+    samples = torch.ones(batch, n_spins, dtype=torch.float32)
+
+    out = pseudo_kl_divergence_loss(spins=spins, logits=logits, samples=samples, boltzmann_machine=bm)
+
+    probs = torch.sigmoid(logits)
+    entropy = F.binary_cross_entropy_with_logits(logits, probs)
+    cross_binary = bm.quasi_objective(spins, samples)
+    ref = cross_binary - entropy
+
+    torch.testing.assert_close(out, ref)
+
+def test_pseudo_kl_works_with_3d_spins():
+    """
+    Verify that pseudo_kl_divergence_loss supports 3D spins of shape:
+        (batch_size, n_samples, n_spins)
+
+    as documented in the function docstring.
+    """
+    bm = DummyBoltzmannMachine()
+
+    batch, n_samples, n_spins = 3, 5, 4
+    logits = torch.zeros(batch, n_spins)
+
+    # spins: (batch_size, n_samples, n_spins)
+    spins = torch.ones(batch, n_samples, n_spins)
+    samples = torch.zeros(batch, n_spins)
+
+    out = pseudo_kl_divergence_loss(spins=spins, logits=logits, samples=samples, boltzmann_machine=bm)
+
+    probs = torch.sigmoid(logits)
+    entropy = F.binary_cross_entropy_with_logits(logits, probs)
+    cross_binary = bm.quasi_objective(spins, samples)
+
+    torch.testing.assert_close(out, cross_binary - entropy)
+
+def test_pseudo_kl_gradient_matches_negative_entropy_when_cross_entropy_constant():
+    """
+    Verify gradient behavior of pseudo_kl_divergence_loss.
+
+    If the Boltzmann machine quasi_objective returns a constant value,
+    then the loss gradient w.r.t. logits must come entirely from the
+    negative entropy term.
+
+    This test ensures that pseudo_kl_divergence_loss applies the correct
+    statistical pressure on encoder logits.
+    """
+
+    class ConstantObjectiveBM:
+        def quasi_objective(self, spins: torch.Tensor, samples: torch.Tensor) -> torch.Tensor:
+            # Constant => contributes no gradient wrt logits
+            return torch.tensor(1.2345, dtype=spins.dtype, device=spins.device)
+
+    bm = ConstantObjectiveBM()
+
+    batch, n_spins = 2, 3
+
+    logits = torch.randn(batch, n_spins, requires_grad = True)
+    spins = torch.ones(batch, n_spins)
+    samples = torch.zeros(batch, n_spins)
+
+    out = pseudo_kl_divergence_loss(spins=spins, logits=logits, samples=samples, boltzmann_machine=bm)
+
+    out.backward()
+
+    # reference: gradient should be gradient of (-entropy)
+    logits2 = logits.detach().clone().requires_grad_(True)
+    # note: requires_grad is a property so requires_grad_ is used to modify in place
+    probs2 = torch.sigmoid(logits2)
+    entropy2 = F.binary_cross_entropy_with_logits(logits2, probs2)
+    (-entropy2).backward()
+
+    torch.testing.assert_close(logits.grad, logits2.grad)

From 64841de023204959d5ef750a41214ff2956d3eb0 Mon Sep 17 00:00:00 2001
From: abdela47
Date: Tue, 23 Dec 2025 17:11:35 +0400
Subject: [PATCH 2/4] Refine pseudo_kl_divergence_loss tests

---
 tests/test_pseudo_kl_divergence.py | 109 ++++++++++++++---------------
 1 file changed, 54 insertions(+), 55 deletions(-)

diff --git a/tests/test_pseudo_kl_divergence.py b/tests/test_pseudo_kl_divergence.py
index 79f17b4..59afc56 100644
--- a/tests/test_pseudo_kl_divergence.py
+++ b/tests/test_pseudo_kl_divergence.py
@@ -19,89 +19,83 @@
 
 from dwave.plugins.torch.models.losses.kl_divergence import pseudo_kl_divergence_loss
 
+
 class DummyBoltzmannMachine:
-    """
-    Minimal deterministic stand-in for GraphRestrictedBoltzmannMachine.
+    """A minimal and deterministic stand-in for GraphRestrictedBoltzmannMachine.
 
     The purpose of this class is NOT to model a real Boltzmann machine.
     Instead, it provides a simple, deterministic quasi_objective so that
     we can verify how pseudo_kl_divergence_loss combines its terms.
     """
 
-    def quasi_objective(self, spins: torch.Tensor, samples: torch.Tensor) -> torch.Tensor:
-        """
-        Return a deterministic scalar depending on spins and samples.
-
-        Using a simple mean ensures:
-        - deterministic behavior
-        - no dependency on logits
-        - gradients w.r.t. logits come only from the entropy term
+    def quasi_objective(self, spins_data: torch.Tensor, spins_model: torch.Tensor) -> torch.Tensor:
+        """Return a deterministic scalar representing a positive-minus-negative phase
+        objective, independent of encoder logits.
""" - return spins.float().mean() + samples.float().mean() + return spins_data.float().mean() - spins_model.float().mean() def test_pseudo_kl_matches_reference_2d_spins(): - """ - Verify that pseudo_kl_divergence_loss matches the reference formula - for 2D spins of shape (batch_size, n_spins). - - This test directly reconstructs the loss as: - cross_entropy - entropy - and checks numerical equality. - """ + """Match explicit cross-entropy minus entropy reference for 2D spins.""" bm = DummyBoltzmannMachine() - batch, n_spins = 4, 6 - logits = torch.linspace(-2.0, 2.0, steps=batch * n_spins).reshape(batch, n_spins) + batch_size, n_spins = 4, 6 + logits = torch.linspace(-2.0, 2.0, steps=batch_size * n_spins).reshape(batch_size, n_spins) - # spins: (batch_size, n_spins) - spins = torch.tensor( - [[-1,1,-1, 1,-1, 1], - [1,-1, 1,-1, 1,-1], - [-1,-1,1,1,-1,1], - [1,1,-1,-1,1,-1]], + # spins_data: (batch_size, n_spins) + spins_data = torch.tensor( + [[-1, 1,-1, 1,-1, 1], + [ 1,-1, 1,-1, 1,-1], + [-1,-1, 1, 1,-1, 1], + [ 1, 1,-1,-1, 1,-1]], dtype=torch.float32 ) - samples = torch.ones(batch, n_spins, dtype=torch.float32) + spins_model = torch.ones(batch_size, n_spins, dtype=torch.float32) - out = pseudo_kl_divergence_loss(spins=spins, logits=logits, samples=samples, boltzmann_machine=bm) + out = pseudo_kl_divergence_loss( + spins=spins_data, + logits=logits, + samples=spins_model, + boltzmann_machine=bm + ) probs = torch.sigmoid(logits) entropy = F.binary_cross_entropy_with_logits(logits, probs) - cross_binary = bm.quasi_objective(spins, samples) - ref = cross_binary - entropy + cross_entropy = bm.quasi_objective(spins_data, spins_model) + ref = cross_entropy - entropy torch.testing.assert_close(out, ref) -def test_pseudo_kl_works_with_3d_spins(): - """ - Verify that pseudo_kl_divergence_loss supports 3D spins of shape: - (batch_size, n_samples, n_spins) - as documented in the function docstring. - """ +def test_pseudo_kl_supports_3d_spin_shape(): + """Support 3D spins of shape (batch_size, n_samples, n_spins) as documented.""" bm = DummyBoltzmannMachine() - batch, n_samples, n_spins = 3, 5, 4 - logits = torch.zeros(batch, n_spins) + batch_size, n_samples, n_spins = 3, 5, 4 + logits = torch.zeros(batch_size, n_spins) # spins: (batch_size, n_samples, n_spins) - spins = torch.ones(batch, n_samples, n_spins) - samples = torch.zeros(batch, n_spins) + spins_data = torch.ones(batch_size, n_samples, n_spins) + spins_model = torch.zeros(batch_size, n_spins) - out = pseudo_kl_divergence_loss(spins=spins, logits=logits, samples=samples, boltzmann_machine=bm) + out = pseudo_kl_divergence_loss( + spins=spins_data, + logits=logits, + samples=spins_model, + boltzmann_machine=bm + ) probs = torch.sigmoid(logits) entropy = F.binary_cross_entropy_with_logits(logits, probs) - cross_binary = bm.quasi_objective(spins, samples) + cross_entropy = bm.quasi_objective(spins_data, spins_model) + + torch.testing.assert_close(out, cross_entropy - entropy) - torch.testing.assert_close(out, cross_binary - entropy) def test_pseudo_kl_gradient_matches_negative_entropy_when_cross_entropy_constant(): - """ - Verify gradient behavior of pseudo_kl_divergence_loss. + """Verify gradient behavior of pseudo_kl_divergence_loss. If the Boltzmann machine quasi_objective returns a constant value, then the loss gradient w.r.t. 
@@ -112,25 +106,30 @@ def test_pseudo_kl_gradient_matches_negative_entropy_when_cross_entropy_constant
     """
 
     class ConstantObjectiveBM:
-        def quasi_objective(self, spins: torch.Tensor, samples: torch.Tensor) -> torch.Tensor:
+        def quasi_objective(self, spins_data: torch.Tensor,
+                            spins_model: torch.Tensor) -> torch.Tensor:
             # Constant => contributes no gradient wrt logits
-            return torch.tensor(1.2345, dtype=spins.dtype, device=spins.device)
+            return torch.tensor(1.2345, dtype=spins_data.dtype, device=spins_data.device)
 
     bm = ConstantObjectiveBM()
 
-    batch, n_spins = 2, 3
+    batch_size, n_spins = 2, 3
 
-    logits = torch.randn(batch, n_spins, requires_grad = True)
-    spins = torch.ones(batch, n_spins)
-    samples = torch.zeros(batch, n_spins)
+    logits = torch.randn(batch_size, n_spins, requires_grad=True)
+    spins_data = torch.ones(batch_size, n_spins)
+    spins_model = torch.zeros(batch_size, n_spins)
 
-    out = pseudo_kl_divergence_loss(spins=spins, logits=logits, samples=samples, boltzmann_machine=bm)
+    out = pseudo_kl_divergence_loss(
+        spins=spins_data,
+        logits=logits,
+        samples=spins_model,
+        boltzmann_machine=bm
+    )
 
     out.backward()
 
-    # reference: gradient should be gradient of (-entropy)
+    # reference gradient from -entropy only
     logits2 = logits.detach().clone().requires_grad_(True)
-    # note: requires_grad is a property so requires_grad_ is used to modify in place
     probs2 = torch.sigmoid(logits2)
     entropy2 = F.binary_cross_entropy_with_logits(logits2, probs2)
     (-entropy2).backward()

From 54f2fdc5455e8690e9ac1e46bba9fc137004e394 Mon Sep 17 00:00:00 2001
From: Ahmed Abdelaziz <71100796+abdela47@users.noreply.github.com>
Date: Tue, 23 Dec 2025 17:18:25 +0400
Subject: [PATCH 3/4] Add copyright and Apache License header (reviewer suggestion)

Co-authored-by: Kevin Chern <32395608+kevinchern@users.noreply.github.com>
---
 tests/test_pseudo_kl_divergence.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/tests/test_pseudo_kl_divergence.py b/tests/test_pseudo_kl_divergence.py
index 59afc56..1584cec 100644
--- a/tests/test_pseudo_kl_divergence.py
+++ b/tests/test_pseudo_kl_divergence.py
@@ -1,3 +1,16 @@
+# Copyright 2025 D-Wave
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 """
 Unit tests for pseudo_kl_divergence_loss.
 
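Note for readers (not part of any patch above): the entropy term used throughout these tests relies on an identity. For a factorized Bernoulli encoder with probabilities p = sigmoid(logits), the mean Bernoulli entropy H(p) = -(p*log(p) + (1-p)*log(1-p)) equals binary cross-entropy-with-logits evaluated against its own probabilities, which is why the tests can write the entropy as F.binary_cross_entropy_with_logits(logits, probs). A minimal standalone check of this identity, with arbitrary illustrative logits values:

import torch
import torch.nn.functional as F

logits = torch.tensor([-1.0, 0.0, 2.0])
p = torch.sigmoid(logits)

# Mean Bernoulli entropy computed directly: H(p) = -(p*log(p) + (1-p)*log(1-p))
h_direct = -(p * torch.log(p) + (1 - p) * torch.log1p(-p)).mean()

# BCE-with-logits against its own probabilities reduces to the same quantity,
# because the target p equals sigmoid(logits) here.
h_bce = F.binary_cross_entropy_with_logits(logits, p)

assert torch.allclose(h_direct, h_bce)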
From 7e361bfce0003bdadb971e5150eba4059ccd30a4 Mon Sep 17 00:00:00 2001
From: abdela47
Date: Tue, 23 Dec 2025 17:30:28 +0400
Subject: [PATCH 4/4] Derive batch_size and n_spins from spins_data.shape in
 the 2D test for clarity and consistency

---
 tests/test_pseudo_kl_divergence.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/test_pseudo_kl_divergence.py b/tests/test_pseudo_kl_divergence.py
index 1584cec..98ebddb 100644
--- a/tests/test_pseudo_kl_divergence.py
+++ b/tests/test_pseudo_kl_divergence.py
@@ -53,18 +53,18 @@ def test_pseudo_kl_matches_reference_2d_spins():
 
     bm = DummyBoltzmannMachine()
 
-    batch_size, n_spins = 4, 6
-    logits = torch.linspace(-2.0, 2.0, steps=batch_size * n_spins).reshape(batch_size, n_spins)
-
     # spins_data: (batch_size, n_spins)
     spins_data = torch.tensor(
         [[-1, 1,-1, 1,-1, 1],
          [ 1,-1, 1,-1, 1,-1],
          [-1,-1, 1, 1,-1, 1],
          [ 1, 1,-1,-1, 1,-1]],
         dtype=torch.float32
     )
 
+    batch_size, n_spins = spins_data.shape
+    logits = torch.linspace(-2.0, 2.0, steps=batch_size * n_spins).reshape(batch_size, n_spins)
+
     spins_model = torch.ones(batch_size, n_spins, dtype=torch.float32)
 
     out = pseudo_kl_divergence_loss(
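Note for readers (hypothetical sketch, not part of any patch above): taken together, the tests pin the loss to the decomposition pseudo_KL = cross_entropy_with_prior - entropy_of_encoder. A minimal sketch of a function that would satisfy every assertion above is given below; the name pseudo_kl_divergence_loss_sketch is invented here, and this is a re-derivation from the test expectations, not the actual implementation in dwave.plugins.torch:

import torch
import torch.nn.functional as F

def pseudo_kl_divergence_loss_sketch(spins, logits, samples, boltzmann_machine):
    # Cross-entropy-with-prior term: the Boltzmann machine's quasi-objective,
    # evaluated on encoder spins and on samples drawn from the model.
    cross_entropy = boltzmann_machine.quasi_objective(spins, samples)

    # Entropy of the factorized Bernoulli encoder, computed via the
    # BCE-with-logits identity noted between patches 3 and 4.
    probs = torch.sigmoid(logits)
    entropy = F.binary_cross_entropy_with_logits(logits, probs)

    # pseudo_KL = cross_entropy_with_prior - entropy_of_encoder
    return cross_entropy - entropy

Substituting this sketch for the real function would make the three tests above pass, which is the precise sense in which they characterize the loss.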