Commit 49c2fb0

VolodyaCO and Vladimir Vargas Calderón authored
Decouple sampling from pseudo_kl_divergence_loss. (#34)
Co-authored-by: Vladimir Vargas Calderón <vvargasc@dwavesys.com>
1 parent 6d59b84 commit 49c2fb0

File tree

3 files changed: +16 −31 lines changed


dwave/plugins/torch/models/losses/kl_divergence.py

Lines changed: 2 additions & 29 deletions
@@ -12,10 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional
 
 import torch
-from dimod import Sampler
 
 from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine
 
@@ -25,12 +23,8 @@
 def pseudo_kl_divergence_loss(
     spins: torch.Tensor,
     logits: torch.Tensor,
+    samples: torch.Tensor,
     boltzmann_machine: GraphRestrictedBoltzmannMachine,
-    sampler: Sampler,
-    sample_kwargs: dict,
-    prefactor: Optional[float] = None,
-    linear_range: Optional[tuple[float, float]] = None,
-    quadratic_range: Optional[tuple[float, float]] = None,
 ):
     """A pseudo Kullback-Leibler divergence loss function for a discrete autoencoder with a
     Boltzmann machine prior.
@@ -47,32 +41,11 @@ def pseudo_kl_divergence_loss(
             logits are the raw output of the encoder.
         boltzmann_machine (GraphRestrictedBoltzmannMachine): An instance of a Boltzmann
             machine.
-        sampler (Sampler): A sampler used for generating samples.
-        sample_kwargs (dict): Additional keyword arguments for the ``sampler.sample``
-            method.
-        prefactor (float, optional): A scaling applied to the Hamiltonian weights
-            (linear and quadratic weights). When None, no scaling is applied. Defaults
-            to None.
-        linear_range (tuple[float, float], optional): Linear weights are clipped to
-            ``linear_range`` prior to sampling. This clipping occurs after the
-            ``prefactor`` scaling has been applied. When None, no clipping is applied.
-            Defaults to None.
-        quadratic_range (tuple[float, float], optional): Quadratic weights are clipped
-            to ``quadratic_range`` prior to sampling. This clipping occurs after the
-            ``prefactor`` scaling has been applied. When None, no clipping is applied.
-            Defaults to None.
+        samples (torch.Tensor): A tensor of samples from the Boltzmann machine.
 
     Returns:
         torch.Tensor: The computed pseudo KL divergence loss.
     """
-    samples = boltzmann_machine.sample(
-        sampler=sampler,
-        device=spins.device,
-        prefactor=prefactor if prefactor is not None else 1.0,
-        linear_range=linear_range,
-        quadratic_range=quadratic_range,
-        sample_params=sample_kwargs,
-    )
     probabilities = torch.sigmoid(logits)
     entropy = torch.nn.functional.binary_cross_entropy_with_logits(logits, probabilities)
     cross_entropy = boltzmann_machine.quasi_objective(spins, samples)
Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+---
+upgrade:
+  - |
+    ``dwave.plugins.torch.models.losses.kl_divergence.pseudo_kl_divergence_loss``
+    no longer uses the Graph-Restricted Boltzmann Machine to generate Boltzmann
+    samples internally. Instead, the samples must be provided as an argument
+    to the function. This is a breaking change.
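
For callers migrating to the new signature, the sampling step moves out of the loss function and its result is passed in explicitly. A minimal sketch of the new call pattern, assuming `boltzmann_machine` is an existing GraphRestrictedBoltzmannMachine, `sampler` is the caller's dimod sampler, and `spins`/`logits` are tensors the caller already computes (the `sample` keyword arguments mirror the updated test below):

    from dwave.plugins.torch.models.losses.kl_divergence import pseudo_kl_divergence_loss

    # Sampling is now the caller's responsibility: draw Boltzmann samples first
    # (keyword values copied from the updated test; adjust as needed).
    samples = boltzmann_machine.sample(
        sampler,
        as_tensor=True,
        prefactor=1.0,
        sample_params=dict(num_sweeps=10, seed=1234, num_reads=100),
    )

    # Then pass the samples explicitly, alongside the spins and encoder logits.
    kl_loss = pseudo_kl_divergence_loss(spins, logits, samples, boltzmann_machine)

Decoupling the two steps means, for example, that one batch of samples can be drawn once and reused, and that sampling cost can be tuned independently of the loss computation.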

tests/test_dvae_winci2020.py

Lines changed: 7 additions & 2 deletions
@@ -129,12 +129,17 @@ def test_train(self, n_latent_dims):
 
             discretes = discretes.reshape(discretes.shape[0], -1)
             latents = latents.reshape(latents.shape[0], -1)
+            samples = self.boltzmann_machine.sample(
+                self.sampler_sa,
+                as_tensor=True,
+                prefactor=1.0,
+                sample_params=dict(num_sweeps=10, seed=1234, num_reads=100),
+            )
             kl_loss = pseudo_kl_divergence_loss(
                 discretes,
                 latents,
+                samples,
                 self.boltzmann_machine,
-                self.sampler_sa,
-                dict(num_sweeps=10, seed=1234, num_reads=100),
             )
             loss = loss + 1e-1 * kl_loss
             optimiser.zero_grad()
