# Copyright 2025 D-Wave
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from torch import nn

from dwave.plugins.torch.utils import bit2spin_soft, spin2bit_soft, straight_through_bitrounding


class StraightThroughTanh(nn.Module):
    """Tanh activation quantized to hard spins with a straight-through estimator.

    Forward pass: ``tanh`` squashes inputs to fuzzy spins in [-1, 1], which are
    mapped to fuzzy bits in [0, 1], rounded to hard bits {0, 1}, and mapped back
    to hard spins {-1, +1}. Backward pass: gradients flow through the rounding
    step as identity (straight-through).
    """

    def __init__(self):
        # BUG FIX: ``super().__init__(self, vars())`` raised TypeError because
        # nn.Module.__init__ accepts no such positional arguments.
        super().__init__()
        # NOTE: attribute name kept as ``hth`` (despite holding a Tanh) so
        # state_dict keys stay stable across the module variants in this file.
        self.hth = nn.Tanh()

    def forward(self, x):
        """Return hard spins {-1, +1} with straight-through gradients."""
        fuzzy_spins = self.hth(x)
        fuzzy_bits = spin2bit_soft(fuzzy_spins)
        bits = straight_through_bitrounding(fuzzy_bits)
        spins = bit2spin_soft(bits)
        return spins


class StraightThroughHardTanh(nn.Module):
    """Hardtanh activation quantized to hard spins with a straight-through estimator.

    Identical to :class:`StraightThroughTanh` except the squashing nonlinearity
    is ``nn.Hardtanh`` (piecewise-linear clamp to [-1, 1]).
    """

    def __init__(self):
        # BUG FIX: ``super().__init__(self, vars())`` raised TypeError because
        # nn.Module.__init__ accepts no such positional arguments.
        super().__init__()
        self.hth = nn.Hardtanh()

    def forward(self, x):
        """Return hard spins {-1, +1} with straight-through gradients."""
        fuzzy_spins = self.hth(x)
        fuzzy_bits = spin2bit_soft(fuzzy_spins)
        bits = straight_through_bitrounding(fuzzy_bits)
        spins = bit2spin_soft(bits)
        return spins


class Bit2SpinSoft(nn.Module):
    """Module wrapper around :func:`bit2spin_soft` ([0, 1] -> [-1, 1])."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return bit2spin_soft(x)


class Spin2BitSoft(nn.Module):
    """Module wrapper around :func:`spin2bit_soft` ([-1, 1] -> [0, 1])."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return spin2bit_soft(x)
def straight_through_bitrounding(fuzzy_bits):
    """Round fuzzy bits to hard bits {0, 1} with straight-through gradients.

    Args:
        fuzzy_bits: Tensor with all values in [0, 1].

    Returns:
        Tensor whose forward value is ``fuzzy_bits.round()`` but whose
        gradient w.r.t. ``fuzzy_bits`` is the identity (the detached residual
        contributes zero gradient).

    Raises:
        ValueError: If any value lies outside [0, 1].
    """
    if not ((fuzzy_bits >= 0) & (fuzzy_bits <= 1)).all():
        raise ValueError(f"Inputs should be in [0, 1]: {fuzzy_bits}")
    # Straight-through estimator: forward == round(x), backward == identity.
    bits = fuzzy_bits + (fuzzy_bits.round() - fuzzy_bits).detach()
    return bits


def bit2spin_soft(b):
    """Affinely map (possibly fuzzy) bits in [0, 1] to spins in [-1, 1].

    Raises:
        ValueError: If any value lies outside [0, 1].
    """
    if not ((b >= 0) & (b <= 1)).all():
        raise ValueError(f"Not all inputs are in [0, 1]: {b}")
    return b * 2.0 - 1.0


def spin2bit_soft(s):
    """Affinely map (possibly fuzzy) spins in [-1, 1] to bits in [0, 1].

    Raises:
        ValueError: If any value lies outside [-1, 1].
    """
    if (s.abs() > 1).any():
        raise ValueError(f"Not all inputs are in [-1, 1]: {s}")
    return (s + 1.0) / 2.0


def rands_like(x):
    """Return uniform random spins {-1, +1} with the shape and device of ``x``."""
    return rands(x.shape, device=x.device)


def randb_like(x):
    """Return uniform random bits {0, 1} with the shape and device of ``x``."""
    return randb(x.shape, device=x.device)


def randb(shape, device=None):
    """Return uniform random bits {0, 1} of the given shape.

    Args:
        shape: Tuple of ints, or a bare int for a 1-D tensor.
        device: Optional torch device for the result.
    """
    # CONSISTENCY FIX: accept a bare int shape, matching ``rands``;
    # previously ``randb(5)`` raised while ``rands(5)`` worked.
    if isinstance(shape, int):
        shape = (shape,)
    return torch.randint(0, 2, shape, device=device)


def rands(shape, device=None):
    """Return uniform random spins {-1, +1} of the given shape.

    Args:
        shape: Tuple of ints, or a bare int for a 1-D tensor.
        device: Optional torch device for the result.
    """
    if isinstance(shape, int):
        shape = (shape,)
    return bit2spin_soft(torch.randint(0, 2, shape, device=device))