1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- from typing import Optional
1615
1716import torch
18- from dimod import Sampler
1917
2018from dwave .plugins .torch .models .boltzmann_machine import GraphRestrictedBoltzmannMachine
2119
2523def pseudo_kl_divergence_loss (
2624 spins : torch .Tensor ,
2725 logits : torch .Tensor ,
26+ samples : torch .Tensor ,
2827 boltzmann_machine : GraphRestrictedBoltzmannMachine ,
29- sampler : Sampler ,
30- sample_kwargs : dict ,
31- prefactor : Optional [float ] = None ,
32- linear_range : Optional [tuple [float , float ]] = None ,
33- quadratic_range : Optional [tuple [float , float ]] = None ,
3428):
3529 """A pseudo Kullback-Leibler divergence loss function for a discrete autoencoder with a
3630 Boltzmann machine prior.
@@ -47,32 +41,11 @@ def pseudo_kl_divergence_loss(
4741 logits are the raw output of the encoder.
4842 boltzmann_machine (GraphRestrictedBoltzmannMachine): An instance of a Boltzmann
4943 machine.
50- sampler (Sampler): A sampler used for generating samples.
51- sample_kwargs (dict): Additional keyword arguments for the ``sampler.sample``
52- method.
53- prefactor (float, optional): A scaling applied to the Hamiltonian weights
54- (linear and quadratic weights). When None, no scaling is applied. Defaults
55- to None.
56- linear_range (tuple[float, float], optional): Linear weights are clipped to
57- ``linear_range`` prior to sampling. This clipping occurs after the
58- ``prefactor`` scaling has been applied. When None, no clipping is applied.
59- Defaults to None.
60- quadratic_range (tuple[float, float], optional): Quadratic weights are clipped
61- to ``quadratic_range`` prior to sampling. This clipping occurs after the
62- ``prefactor`` scaling has been applied. When None, no clipping is applied.
63- Defaults to None.
44+ samples (torch.Tensor): A tensor of samples from the Boltzmann machine.
6445
6546 Returns:
6647 torch.Tensor: The computed pseudo KL divergence loss.
6748 """
68- samples = boltzmann_machine .sample (
69- sampler = sampler ,
70- device = spins .device ,
71- prefactor = prefactor if prefactor is not None else 1.0 ,
72- linear_range = linear_range ,
73- quadratic_range = quadratic_range ,
74- sample_params = sample_kwargs ,
75- )
7649 probabilities = torch .sigmoid (logits )
7750 entropy = torch .nn .functional .binary_cross_entropy_with_logits (logits , probabilities )
7851 cross_entropy = boltzmann_machine .quasi_objective (spins , samples )
0 commit comments