Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 1 addition & 79 deletions dwave/plugins/torch/models/boltzmann_machine.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@
import torch

if TYPE_CHECKING:
from dwave.plugins.torch.samplers.dimod_sampler import DimodSampler
from dimod import Sampler, SampleSet

from dimod import BinaryQuadraticModel
from hybrid.composers import AggregatedSamples

from dwave.plugins.torch.utils import sampleset_to_tensor
from dwave.system.temperatures import maximum_pseudolikelihood_temperature as mple

spread = AggregatedSamples.spread
Expand Down Expand Up @@ -261,84 +261,6 @@ def theta(self) -> torch.Tensor:
by the model's input ``nodes`` and ``edges``."""
return torch.cat([self._linear, self._quadratic])

@overload
def sample(self, sampler: Sampler, as_tensor: Literal[True], **kwargs) -> torch.Tensor: ...

@overload
def sample(self, sampler: Sampler, as_tensor: Literal[False], **kwargs) -> SampleSet: ...

def sample(
self,
sampler: Sampler,
*,
prefactor: float,
linear_range: Optional[tuple[float, float]] = None,
quadratic_range: Optional[tuple[float, float]] = None,
device: Optional[torch.device] = None,
sample_params: Optional[dict] = None,
as_tensor: bool = True,
) -> Union[torch.Tensor, SampleSet]:
"""Sample from the Boltzmann machine.

This method samples and converts a sample of spins to tensors and ensures they
are not aggregated---provided the aggregation information is retained in the
sample set.

Args:
sampler (Sampler): The sampler used to sample from the model.
prefactor (float): The prefactor for which the Hamiltonian is scaled by.
This quantity is typically the temperature at which the sampler operates
at. Standard CPU-based samplers such as Metropolis- or Gibbs-based
samplers will often default to sampling at an unit temperature, thus a
unit prefactor should be used. In the case of a quantum annealer, a
reasonable choice of a prefactor is 1/beta where beta is the effective
inverse temperature and can be estimated using
:meth:`GraphRestrictedBoltzmannMachine.estimate_beta`.
linear_range (tuple[float, float], optional): Linear weights are clipped to
``linear_range`` prior to sampling. This clipping occurs after the ``prefactor``
scaling has been applied. When None, no clipping is applied. Defaults to None.
quadratic_range (tuple[float, float], optional): Quadratic weights are clipped to
``quadratic_range`` prior to sampling. This clipping occurs after the ``prefactor``
scaling has been applied. When None, no clipping is applied.Defaults to None.
device (torch.device, optional): The device of the constructed tensor.
If ``None`` and data is a tensor then the device of data is used.
If ``None`` and data is not a tensor then the result tensor is
constructed on the current device.
sample_params (dict, optional): Parameters of the `sampler.sample` method.
as_tensor (bool): Whether to return the sampleset as a tensor.
Defaults to ``True``. If ``False`` returns a ``dimod.SampleSet``.

Returns:
torch.Tensor | SampleSet: Spins sampled from the model
(shape prescribed by ``sampler`` and ``sample_params``).
"""
if sample_params is None:
sample_params = dict()
h, J = self.to_ising(prefactor, linear_range, quadratic_range)
sample_set = spread(sampler.sample_ising(h, J, **sample_params))

if as_tensor:
return self.sampleset_to_tensor(sample_set, device=device)

return sample_set

def sampleset_to_tensor(
self, sample_set: SampleSet, device: Optional[torch.device] = None
) -> torch.Tensor:
"""Converts a ``dimod.SampleSet`` to a ``torch.Tensor`` using the node order of the class.

Args:
sample_set (dimod.SampleSet): A sample set.
device (torch.device, optional): The device of the constructed tensor.
If ``None`` and data is a tensor then the device of data is used.
If ``None`` and data is not a tensor then the result tensor is constructed
on the current device.

Returns:
torch.Tensor: The sample set as a ``torch.Tensor``.
"""
return sampleset_to_tensor(self._nodes, sample_set, device)

def quasi_objective(
self,
s_observed: torch.Tensor,
Expand Down
16 changes: 16 additions & 0 deletions dwave/plugins/torch/samplers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright 2025 D-Wave
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dwave.plugins.torch.samplers.dimod_sampler import *
from dwave.plugins.torch.samplers._base import *
104 changes: 104 additions & 0 deletions dwave/plugins/torch/samplers/_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# Copyright 2025 D-Wave
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
import copy

import torch

from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine


__all__ = ["TorchSampler"]


class TorchSampler(abc.ABC):
"""Base class for all PyTorch plugin samplers."""

def __init__(self, refresh: bool = True) -> None:
self._parameters = {}
self._modules = {}

if refresh:
self.refresh_parameters()

def parameters(self):
"""Parameters in the sampler."""
for p in self._parameters.values():
yield p

def modules(self):
"""Modules in the sampler."""
for m in self._modules.values():
yield m

@abc.abstractmethod
def sample(self, x: torch.Tensor | None = None) -> torch.Tensor:
"""Abstract sample method."""

def to(self, *args, **kwargs):
"""Performs Tensor dtype and/or device conversion on sampler parameters.

See :meth:`torch.Tensor.to` for usage details."""
# perform a shallow copy of the sampler to be returned
sampler = copy.copy(self)
parameters = {}
modules = {}

for name, p in self._parameters.items():
new_p = p.to(*args, **kwargs)

# set attribute and update parameters
setattr(sampler, name, new_p)
parameters[name] = new_p

for name, m in self._modules.items():
new_m = m.to(*args, **kwargs)

# set attribute and update modules
setattr(sampler, name, new_m)
modules[name] = new_m

sampler._parameters = parameters
sampler._modules = modules

return sampler

def refresh_parameters(self, replace=True, clear=True):
"""Refreshes the parameters and modules attributes in-place.

Searches the sampler for any initialized torch parameters and modules
and adds them to the :attr:`TorchSampler_parameters` attribute, which
is used to update device or dtype using the
:meth:`TorchSampler.to` method.

Args:
replace: Replace any previous parameters with new values.
clear: Clear the parameters attribute before adding new ones.
"""
if clear:
self._parameters.clear()
self._modules.clear()

for attr_, val in self.__dict__.items():
# NOTE: Only refreshes torch parameters and modules, but _not_ any
# GRBM models. Can be generalized if plugin gets a baseclass module.
if replace or attr_ not in self._parameters:
if isinstance(val, torch.Tensor):
self._parameters[attr_] = val
elif (
isinstance(val, torch.nn.Module) and
not isinstance(val, GraphRestrictedBoltzmannMachine)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While I don't see an immediate use case, I can imagine a scenario where there are two GRBMs and one could be replaced.
A potential solution is to accept a list of attributes to ignore, which would also addresses the preceding comment.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any other GRBMs could in this case just be added manually for this to work. I see the refresh method as a non-optimal helper for scanning the module for parameters that should be updated, but it should be checked manually as well in any subclass implementation.

):
self._modules[attr_] = val
107 changes: 107 additions & 0 deletions dwave/plugins/torch/samplers/dimod_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Copyright 2025 D-Wave
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import TYPE_CHECKING, Any

from dimod import Sampler
import torch

from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine
from dwave.plugins.torch.samplers._base import TorchSampler
from dwave.plugins.torch.utils import sampleset_to_tensor
from hybrid.composers import AggregatedSamples

if TYPE_CHECKING:
import dimod
from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine


__all__ = ["DimodSampler"]


class DimodSampler(TorchSampler):
"""PyTorch plugin wrapper for a dimod sampler.

Args:
module (GraphRestrictedBoltzmannMachine): GraphRestrictedBoltzmannMachine module. Requires the
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any particular reason to name it module instead of grbm?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not really. Remnant from wanting it to be a generic module but not having a reasonable base class to use.

methods ``to_ising`` and ``nodes``.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With this new interface, to_ising method should be removed from GRBM and implemented elsewhere as a function (not in this sampler either).

sampler (dimod.Sampler): Dimod sampler.
prefactor (float): The prefactor for which the Hamiltonian is scaled by.
This quantity is typically the temperature at which the sampler operates
at. Standard CPU-based samplers such as Metropolis- or Gibbs-based
samplers will often default to sampling at an unit temperature, thus a
unit prefactor should be used. In the case of a quantum annealer, a
reasonable choice of a prefactor is 1/beta where beta is the effective
inverse temperature and can be estimated using
:meth:`GraphRestrictedBoltzmannMachine.estimate_beta`.
linear_range (tuple[float, float], optional): Linear weights are clipped to
``linear_range`` prior to sampling. This clipping occurs after the ``prefactor``
scaling has been applied. When None, no clipping is applied. Defaults to None.
quadratic_range (tuple[float, float], optional): Quadratic weights are clipped to
``quadratic_range`` prior to sampling. This clipping occurs after the ``prefactor``
scaling has been applied. When None, no clipping is applied.Defaults to None.
sample_kwargs (dict[str, Any]): Dictionary containing optional arguments for the dimod sampler.
"""

def __init__(
self,
module: GraphRestrictedBoltzmannMachine,
sampler: dimod.Sampler,
prefactor: float | None = None,
linear_range: tuple[float, float] | None = None,
quadratic_range: tuple[float, float] | None = None,
sample_kwargs: dict[str, Any] | None = None
) -> None:
self._module = module

# use default prefactor value of 1
self._prefactor = prefactor or 1

self._linear_range = linear_range
self._quadratic_range = quadratic_range

self._sampler = sampler
self._sampler_params = sample_kwargs or {}

# cached sample_set from latest sample
self._sample_set = None

# adds all torch parameters to 'self._parameters' for automatic device/dtype
# update support unless 'refresh_parameters = False'
super().__init__()

def sample(self, x: torch.Tensor | None = None) -> torch.Tensor:
"""Sample from the dimod sampler and return the corresponding tensor.

The sample set returned from the latest sample call is stored in :func:`DimodSampler.sample_set`
which is overwritten by subsequent calls.

Args:
x (torch.Tensor): TODO
"""
h, J = self._module.to_ising(self._prefactor, self._linear_range, self._quadratic_range)
self._sample_set = AggregatedSamples.spread(self._sampler.sample_ising(h, J, **self._sampler_params))

# use same device as modules linear
device = self._module._linear.device
return sampleset_to_tensor(self._module.nodes, self._sample_set, device)

@property
def sample_set(self) -> dimod.SampleSet:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove property and keep the sample set a hidden attribute; the public interface should not be concerned with dimod objects.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was requested by @VolodyaCO for the GRBM class previously, so it seems useful, no?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@VolodyaCO what's the use case?

"""The sample set returned from the latest sample call."""
if self._sample_set is None:
raise AttributeError("no samples found; call 'sample()' first")

return self._sample_set
15 changes: 15 additions & 0 deletions tests/test_samplers/test_annealing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright 2025 D-Wave
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
Loading