-
Notifications
You must be signed in to change notification settings - Fork 11
[WIP] Add samplers submodule #58
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,16 @@ | ||
| # Copyright 2025 D-Wave | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from dwave.plugins.torch.samplers.dimod_sampler import * | ||
| from dwave.plugins.torch.samplers._base import * |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,104 @@ | ||
| # Copyright 2025 D-Wave | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import abc | ||
| import copy | ||
|
|
||
| import torch | ||
|
|
||
| from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine | ||
|
|
||
|
|
||
| __all__ = ["TorchSampler"] | ||
|
|
||
|
|
||
| class TorchSampler(abc.ABC): | ||
| """Base class for all PyTorch plugin samplers.""" | ||
kevinchern marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| def __init__(self, refresh: bool = True) -> None: | ||
| self._parameters = {} | ||
| self._modules = {} | ||
|
|
||
| if refresh: | ||
| self.refresh_parameters() | ||
|
|
||
| def parameters(self): | ||
| """Parameters in the sampler.""" | ||
| for p in self._parameters.values(): | ||
| yield p | ||
|
|
||
| def modules(self): | ||
| """Modules in the sampler.""" | ||
| for m in self._modules.values(): | ||
| yield m | ||
|
|
||
| @abc.abstractmethod | ||
| def sample(self, x: torch.Tensor | None = None) -> torch.Tensor: | ||
| """Abstract sample method.""" | ||
|
|
||
| def to(self, *args, **kwargs): | ||
thisac marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| """Performs Tensor dtype and/or device conversion on sampler parameters. | ||
|
|
||
| See :meth:`torch.Tensor.to` for usage details.""" | ||
| # perform a shallow copy of the sampler to be returned | ||
| sampler = copy.copy(self) | ||
| parameters = {} | ||
| modules = {} | ||
|
|
||
| for name, p in self._parameters.items(): | ||
| new_p = p.to(*args, **kwargs) | ||
|
|
||
| # set attribute and update parameters | ||
| setattr(sampler, name, new_p) | ||
| parameters[name] = new_p | ||
|
|
||
| for name, m in self._modules.items(): | ||
| new_m = m.to(*args, **kwargs) | ||
|
|
||
| # set attribute and update modules | ||
| setattr(sampler, name, new_m) | ||
| modules[name] = new_m | ||
|
|
||
| sampler._parameters = parameters | ||
| sampler._modules = modules | ||
|
|
||
| return sampler | ||
|
|
||
| def refresh_parameters(self, replace=True, clear=True): | ||
| """Refreshes the parameters and modules attributes in-place. | ||
|
|
||
| Searches the sampler for any initialized torch parameters and modules | ||
| and adds them to the :attr:`TorchSampler_parameters` attribute, which | ||
| is used to update device or dtype using the | ||
| :meth:`TorchSampler.to` method. | ||
|
|
||
| Args: | ||
| replace: Replace any previous parameters with new values. | ||
| clear: Clear the parameters attribute before adding new ones. | ||
| """ | ||
| if clear: | ||
| self._parameters.clear() | ||
| self._modules.clear() | ||
|
|
||
| for attr_, val in self.__dict__.items(): | ||
| # NOTE: Only refreshes torch parameters and modules, but _not_ any | ||
| # GRBM models. Can be generalized if plugin gets a baseclass module. | ||
kevinchern marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| if replace or attr_ not in self._parameters: | ||
| if isinstance(val, torch.Tensor): | ||
| self._parameters[attr_] = val | ||
| elif ( | ||
| isinstance(val, torch.nn.Module) and | ||
| not isinstance(val, GraphRestrictedBoltzmannMachine) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. While I don't see an immediate use case, I can imagine a scenario where there are two
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Any other GRBMs could in this case just be added manually for this to work. I see the |
||
| ): | ||
| self._modules[attr_] = val | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,107 @@ | ||
| # Copyright 2025 D-Wave | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| from __future__ import annotations | ||
|
|
||
| from typing import TYPE_CHECKING, Any | ||
|
|
||
| from dimod import Sampler | ||
| import torch | ||
|
|
||
| from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine | ||
| from dwave.plugins.torch.samplers._base import TorchSampler | ||
| from dwave.plugins.torch.utils import sampleset_to_tensor | ||
| from hybrid.composers import AggregatedSamples | ||
|
|
||
| if TYPE_CHECKING: | ||
| import dimod | ||
| from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine | ||
|
|
||
|
|
||
| __all__ = ["DimodSampler"] | ||
|
|
||
|
|
||
| class DimodSampler(TorchSampler): | ||
| """PyTorch plugin wrapper for a dimod sampler. | ||
|
|
||
| Args: | ||
| module (GraphRestrictedBoltzmannMachine): GraphRestrictedBoltzmannMachine module. Requires the | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Any particular reason to name it
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not really. Remnant from wanting it to be a generic module but not having a reasonable base class to use. |
||
| methods ``to_ising`` and ``nodes``. | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. With this new interface, |
||
| sampler (dimod.Sampler): Dimod sampler. | ||
| prefactor (float): The prefactor for which the Hamiltonian is scaled by. | ||
| This quantity is typically the temperature at which the sampler operates | ||
| at. Standard CPU-based samplers such as Metropolis- or Gibbs-based | ||
| samplers will often default to sampling at an unit temperature, thus a | ||
| unit prefactor should be used. In the case of a quantum annealer, a | ||
| reasonable choice of a prefactor is 1/beta where beta is the effective | ||
| inverse temperature and can be estimated using | ||
| :meth:`GraphRestrictedBoltzmannMachine.estimate_beta`. | ||
| linear_range (tuple[float, float], optional): Linear weights are clipped to | ||
| ``linear_range`` prior to sampling. This clipping occurs after the ``prefactor`` | ||
| scaling has been applied. When None, no clipping is applied. Defaults to None. | ||
| quadratic_range (tuple[float, float], optional): Quadratic weights are clipped to | ||
| ``quadratic_range`` prior to sampling. This clipping occurs after the ``prefactor`` | ||
| scaling has been applied. When None, no clipping is applied.Defaults to None. | ||
| sample_kwargs (dict[str, Any]): Dictionary containing optional arguments for the dimod sampler. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| module: GraphRestrictedBoltzmannMachine, | ||
| sampler: dimod.Sampler, | ||
| prefactor: float | None = None, | ||
| linear_range: tuple[float, float] | None = None, | ||
| quadratic_range: tuple[float, float] | None = None, | ||
| sample_kwargs: dict[str, Any] | None = None | ||
| ) -> None: | ||
| self._module = module | ||
|
|
||
| # use default prefactor value of 1 | ||
| self._prefactor = prefactor or 1 | ||
|
|
||
| self._linear_range = linear_range | ||
| self._quadratic_range = quadratic_range | ||
|
|
||
| self._sampler = sampler | ||
| self._sampler_params = sample_kwargs or {} | ||
|
|
||
| # cached sample_set from latest sample | ||
| self._sample_set = None | ||
|
|
||
| # adds all torch parameters to 'self._parameters' for automatic device/dtype | ||
| # update support unless 'refresh_parameters = False' | ||
| super().__init__() | ||
|
|
||
| def sample(self, x: torch.Tensor | None = None) -> torch.Tensor: | ||
| """Sample from the dimod sampler and return the corresponding tensor. | ||
|
|
||
| The sample set returned from the latest sample call is stored in :func:`DimodSampler.sample_set` | ||
| which is overwritten by subsequent calls. | ||
|
|
||
| Args: | ||
| x (torch.Tensor): TODO | ||
| """ | ||
| h, J = self._module.to_ising(self._prefactor, self._linear_range, self._quadratic_range) | ||
| self._sample_set = AggregatedSamples.spread(self._sampler.sample_ising(h, J, **self._sampler_params)) | ||
|
|
||
| # use same device as modules linear | ||
| device = self._module._linear.device | ||
| return sampleset_to_tensor(self._module.nodes, self._sample_set, device) | ||
|
|
||
| @property | ||
| def sample_set(self) -> dimod.SampleSet: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove property and keep the sample set a hidden attribute; the public interface should not be concerned with dimod objects.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This was requested by @VolodyaCO for the GRBM class previously, so it seems useful, no?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @VolodyaCO what's the use case? |
||
| """The sample set returned from the latest sample call.""" | ||
| if self._sample_set is None: | ||
| raise AttributeError("no samples found; call 'sample()' first") | ||
|
|
||
| return self._sample_set | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,15 @@ | ||
| # Copyright 2025 D-Wave | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import unittest |
Uh oh!
There was an error while loading. Please reload this page.