diff --git a/dwave/plugins/torch/nn/functional.py b/dwave/plugins/torch/nn/functional.py new file mode 100755 index 0000000..8327f7a --- /dev/null +++ b/dwave/plugins/torch/nn/functional.py @@ -0,0 +1,84 @@ +# Copyright 2025 D-Wave +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Functional interface.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from dwave.plugins.torch.nn.modules.kernels import Kernel + +import torch + +__all__ = ["maximum_mean_discrepancy_loss"] + + + +def maximum_mean_discrepancy_loss(x: torch.Tensor, y: torch.Tensor, kernel: Kernel) -> torch.Tensor: + """Estimates the squared maximum mean discrepancy (MMD) given two samples ``x`` and ``y``. + + The `squared MMD `_ is defined as + + .. math:: + MMD^2(X, Y) = |E_{x\sim p}[\varphi(x)] - E_{y\sim q}[\varphi(y)] |^2, + + where :math:`\varphi` is a feature map associated with the kernel function + :math:`k(x, y) = \langle \varphi(x), \varphi(y) \rangle`, and :math:`p` and :math:`q` are the + distributions of the samples. It follows that, in terms of the kernel function, the squared MMD + can be computed as + + .. math:: + E_{x, x'\sim p}[k(x, x')] + E_{y, y'\sim q}[k(y, y')] - 2E_{x\sim p, y\sim q}[k(x, y)]. + + If :math:`p = q`, then :math:`MMD^2(X, Y) = 0`. This motivates the squared MMD as a loss + function for minimizing the distance between the model distribution and data distribution. + + For more information, see + Gretton, A., Borgwardt, K. M., Rasch, M. J., Schölkopf, B., & Smola, A. (2012). + A kernel two-sample test. The journal of machine learning research, 13(1), 723-773. + + Args: + x (torch.Tensor): A (n_x, f1, f2, ..., fk) tensor of samples from distribution p. + y (torch.Tensor): A (n_y, f1, f2, ..., fk) tensor of samples from distribution q. + kernel (Kernel): A kernel function object. + + Raises: + ValueError: If the sample size of ``x`` or ``y`` is less than two. + ValueError: If shape of ``x`` and ``y`` mismatch (excluding batch size) + + Returns: + torch.Tensor: The squared maximum mean discrepancy estimate. + """ + num_x = x.shape[0] + num_y = y.shape[0] + if num_x < 2 or num_y < 2: + raise ValueError( + "Sample size of ``x`` and ``y`` must be at least two. " + f"Got, respectively, {x.shape} and {y.shape}." + ) + if x.shape[1:] != y.shape[1:]: + raise ValueError( + "Input dimensions must match. You are trying to compute " + f"the kernel between tensors of shape {x.shape} and {y.shape}." + ) + xy = torch.cat([x, y], dim=0) + kernel_matrix = kernel(xy, xy) + kernel_xx = kernel_matrix[:num_x, :num_x] + kernel_yy = kernel_matrix[num_x:, num_x:] + kernel_xy = kernel_matrix[:num_x, num_x:] + xx = (kernel_xx.sum() - kernel_xx.trace()) / (num_x * (num_x - 1)) + yy = (kernel_yy.sum() - kernel_yy.trace()) / (num_y * (num_y - 1)) + xy = kernel_xy.sum() / (num_x * num_y) + return xx + yy - 2 * xy diff --git a/dwave/plugins/torch/nn/modules/kernels.py b/dwave/plugins/torch/nn/modules/kernels.py new file mode 100755 index 0000000..6119cb5 --- /dev/null +++ b/dwave/plugins/torch/nn/modules/kernels.py @@ -0,0 +1,151 @@ +# Copyright 2025 D-Wave +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Kernel functions.""" + +from abc import ABC, abstractmethod + +import torch +import torch.nn as nn + +from dwave.plugins.torch.nn.modules.utils import store_config + +__all__ = ["Kernel", "GaussianKernel"] + + +class Kernel(ABC, nn.Module): + """Base class for kernels. + + `Kernels `_ are functions that compute a similarity + measure between data points. Any ``Kernel`` subclass must implement the ``_kernel`` method, + which computes the kernel matrix for a given input multi-dimensional tensor with shape + (n, f1, f2, ...), where n is the number of items and f1, f2, ... are feature dimensions, so that + the output is a tensor of shape (n, n) containing the pairwise kernel values. + """ + + @abstractmethod + def _kernel(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """Perform a pairwise kernel evaluation over samples. + + Computes the kernel matrix for inputs of shape (nx, f1, f2, ..., fk) and + (ny, f1, f2, ..., fk), whose shape is (nx, ny) + containing the pairwise kernel values. + + Args: + x (torch.Tensor): A (nx, f1, f2, ..., fk) tensor. + y (torch.Tensor): A (ny, f1, f2, ..., fk) tensor. + + Returns: + torch.Tensor: A (nx, ny) tensor. + """ + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """Computes kernels for all pairs between and within ``x`` and ``y``. + + In general, ``x`` and ``y`` are (n_x, f1, f2, ..., fk)- and (n_y, f1, f2, ..., fk)-shaped + tensors, and the output is a (n_x + n_y, n_x + n_y)-shaped tensor containing pairwise kernel + evaluations. + + Args: + x (torch.Tensor): A (n_x, f1, f2, ..., fk) tensor. + y (torch.Tensor): A (n_y, f1, f2, ..., fk) tensor. + + Raises: + ValueError: If shape of ``x`` and ``y`` mismatch (excluding batch size) + + Returns: + torch.Tensor: A (n_x + n_y, n_x + n_y) tensor. + """ + if x.shape[1:] != y.shape[1:]: + raise ValueError( + "Input dimensions must match. You are trying to compute " + f"the kernel between tensors of shape {x.shape} and {y.shape}." + ) + if x.shape[0] < 2 or y.shape[0] < 2: + raise ValueError( + "Sample size of ``x`` and ``y`` must be at least two. " + f"Got, respectively, {x.shape} and {y.shape}." + ) + return self._kernel(x, y) + + +class GaussianKernel(Kernel): + """The Gaussian kernel. + + This kernel between two data points x and y is defined as + :math:`k(x, y) = exp(-||x-y||^2 / (2 * \sigma))`, where :math:`\sigma` is the bandwidth + parameter. + + This implementation considers aggregating multiple Gaussian kernels with different + bandwidths. The bandwidths are determined by multiplying a base bandwidth with a set of + multipliers. The base bandwidth can be provided directly or estimated from the data using the + average distance between samples. + + Args: + n_kernels (int): Number of kernel bandwidths to use. + factor (int | float): Multiplicative factor to generate bandwidths. The bandwidths are + computed as :math:`\sigma_i = \sigma * factor^{i - n\_kernels // 2}` for + :math:`i` in ``[0, n\_kernels - 1]``. Defaults to 2.0. + bandwidth (float | None): Base bandwidth parameter. If ``None``, the bandwidth is computed + from the data (without gradients). Defaults to ``None``. + """ + + @store_config + def __init__( + self, n_kernels: int, factor: int | float = 2.0, bandwidth: float | None = None + ): + super().__init__() + factors = factor ** (torch.arange(n_kernels) - n_kernels // 2) + self.register_buffer("factors", factors) + self.bandwidth = bandwidth + + @torch.no_grad() + def _get_bandwidth(self, distance_matrix: torch.Tensor) -> torch.Tensor | float: + """Heuristically determine a bandwidth parameter as the average distance between samples. + + Computes the base bandwidth parameter as the average distance between samples if the + bandwidth is not provided during initialization. Otherwise, returns the provided bandwidth. + See https://arxiv.org/abs/1707.07269 for more details about the motivation behind taking + the average distance as the bandwidth. + + Args: + distance_matrix (torch.Tensor): A (n, n) tensor representing the pairwise + L2 distances between samples. If it is ``None`` and the bandwidth is not provided, + an error will be raised. Defaults to ``None``. + + Returns: + torch.Tensor | float: The base bandwidth parameter. + """ + if self.bandwidth is None: + num_samples = distance_matrix.shape[0] + return distance_matrix.sum() / (num_samples**2 - num_samples) + return self.bandwidth + + def _kernel(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """Compute the Gaussian kernel between ``x`` and ``y``. + + .. math:: + k(x, y) = \sum_{i=1}^{num\_features} exp(-||x-y||^2 / (2 * \sigma_i)), + + where :math:`\sigma_i` are the bandwidths. + + Args: + x (torch.Tensor): A (nx, f1, f2, ..., fk) tensor. + y (torch.Tensor): A (ny, f1, f2, ..., fk) tensor. + + Returns: + torch.Tensor: A (nx, ny) tensor representing the kernel matrix. + """ + distance_matrix = torch.cdist(x.flatten(1), y.flatten(1), p=2) + bandwidth = self._get_bandwidth(distance_matrix.detach()) * self.factors + return torch.exp(-distance_matrix.unsqueeze(0) / bandwidth.reshape(-1, 1, 1)).sum(dim=0) diff --git a/dwave/plugins/torch/nn/modules/loss.py b/dwave/plugins/torch/nn/modules/loss.py new file mode 100755 index 0000000..45488ec --- /dev/null +++ b/dwave/plugins/torch/nn/modules/loss.py @@ -0,0 +1,56 @@ +# Copyright 2025 D-Wave +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch +import torch.nn as nn + +from dwave.plugins.torch.nn.functional import maximum_mean_discrepancy_loss as mmd_loss +from dwave.plugins.torch.nn.modules.utils import store_config + +if TYPE_CHECKING: + from dwave.plugins.torch.nn.modules.kernels import Kernel + +__all__ = ["MaximumMeanDiscrepancyLoss"] + + +class MaximumMeanDiscrepancyLoss(nn.Module): + """An unbiased estimator for the squared maximum mean discrepancy (MMD) as a loss function. + + This uses the ``dwave.plugins.torch.nn.functional.maximum_mean_discrepancy_loss`` function to + compute the loss. + + Args: + kernel (Kernel): A kernel function object. + """ + + @store_config + def __init__(self, kernel: Kernel) -> None: + super().__init__() + self.kernel = kernel + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """Computes the MMD loss between two sets of samples ``x`` and ``y``. + + Args: + x (torch.Tensor): A (n_x, f1, f2, ...) tensor of samples from distribution p. + y (torch.Tensor): A (n_y, f1, f2, ...) tensor of samples from distribution q. + + Returns: + torch.Tensor: The computed MMD loss. + """ + return mmd_loss(x, y, self.kernel) diff --git a/releasenotes/notes/add-mmd-loss-function-3fa9e9a2cb452391.yaml b/releasenotes/notes/add-mmd-loss-function-3fa9e9a2cb452391.yaml new file mode 100644 index 0000000..46ea631 --- /dev/null +++ b/releasenotes/notes/add-mmd-loss-function-3fa9e9a2cb452391.yaml @@ -0,0 +1,10 @@ +--- +features: + - | + Add a ``MaximumMeanDiscrepancyLoss`` in ``dwave.plugins.torch.nn.loss`` for estimating the + squared maximum mean discrepancy (MMD) for a given kernel and two samples. + Its functional counterpart ``maximum_mean_discrepancy_loss`` is in + ``dwave.plugins.torch.nn.functional``. + Kernels reside in ``dwave.plugins.torch.nn.modules.kernels``. This enables, for example, + training discrete autoencoders to match the distribution of a target distribution (e.g., prior). + diff --git a/tests/requirements.txt b/tests/requirements.txt index b2bd102..d7abc8f 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -1,3 +1,4 @@ coverage codecov parameterized +einops diff --git a/tests/test_dvae_winci2020.py b/tests/test_dvae_winci2020.py index 8cd5db8..01825b9 100644 --- a/tests/test_dvae_winci2020.py +++ b/tests/test_dvae_winci2020.py @@ -15,13 +15,16 @@ import unittest import torch +from einops import repeat from parameterized import parameterized from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine -from dwave.plugins.torch.models.discrete_variational_autoencoder import ( - DiscreteVariationalAutoencoder as DVAE, -) +from dwave.plugins.torch.models.discrete_variational_autoencoder import \ + DiscreteVariationalAutoencoder as DVAE from dwave.plugins.torch.models.losses.kl_divergence import pseudo_kl_divergence_loss +from dwave.plugins.torch.nn.functional import maximum_mean_discrepancy_loss as mmd_loss +from dwave.plugins.torch.nn.modules.kernels import GaussianKernel +from dwave.plugins.torch.nn.modules.loss import MaximumMeanDiscrepancyLoss as MMDLoss from dwave.samplers import SimulatedAnnealingSampler @@ -84,6 +87,28 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: self.dvaes = {i: DVAE(self.encoders[i], self.decoders[i]) for i in latent_dims_list} + # Now we also create a DVAE with a trainable Encoder + def deterministic_latent_to_discrete(logits: torch.Tensor, n_samples: int) -> torch.Tensor: + # straight-through estimator that maps positive logits to 1 and negative logits to -1 + hard = torch.sign(logits) + soft = logits + result = hard - soft.detach() + soft + # Now we need to repeat the result n_samples times along a new dimension + return repeat(result, "b ... -> b n ...", n=n_samples) + + self.dvae_with_trainable_encoder = DVAE( + encoder=torch.nn.Linear(input_features, latent_features), + decoder=Decoder(latent_features, input_features), + latent_to_discrete=deterministic_latent_to_discrete, + ) + + self.fixed_boltzmann_machine = GraphRestrictedBoltzmannMachine( + nodes=(0, 1), + edges=[(0, 1)], + linear={0: 0.0, 1: 0.0}, + quadratic={(0, 1): 0.0}, + ) # Creates a uniform distribution over spin strings of length 2 + self.boltzmann_machine = GraphRestrictedBoltzmannMachine( nodes=(0, 1), edges=[(0, 1)], @@ -110,6 +135,49 @@ def test_mappings(self): # map [0, 1] to [-1, 1]: torch.testing.assert_close(torch.tensor([-1, 1]).float(), discretes[3]) + @parameterized.expand([True, False]) + def test_train_encoder_with_mmd(self, use_mmd_loss_class: bool = False): + """Test training the encoder of the DVAE with MMD loss and fixed decoder and GRBM prior.""" + dvae = self.dvae_with_trainable_encoder + optimiser = torch.optim.SGD(dvae.encoder.parameters(), lr=0.01, momentum=0.9) + kernel = GaussianKernel(n_kernels=5, factor=2.0, bandwidth=None) + # Before training, the encoder will not map data points to the correct spin strings: + expected_set = {(1.0, 1.0), (1.0, -1.0), (-1.0, -1.0), (-1.0, 1.0)} + _, discretes, _ = dvae(self.data, n_samples=1) + discretes = discretes.squeeze(1) + discretes_set = {tuple(row.tolist()) for row in discretes} + self.assertNotEqual(discretes_set, expected_set) + mmd_loss_module = None + # Train the encoder so that the latent distribution matches the prior GRBM distribution + for _ in range(1000): + optimiser.zero_grad() + _, discretes, _ = dvae(self.data, n_samples=1) + discretes = discretes.reshape(discretes.shape[0], -1) + prior_samples = self.fixed_boltzmann_machine.sample( + sampler=self.sampler_sa, + as_tensor=True, + device=discretes.device, + prefactor=1.0, + linear_range=None, + quadratic_range=None, + sample_params=dict(num_sweeps=10, seed=1234, num_reads=100), + ) + if use_mmd_loss_class: + if mmd_loss_module is None: + mmd_loss_module = MMDLoss(kernel) + mmd = mmd_loss_module(discretes, prior_samples) + else: + mmd = mmd_loss(discretes, prior_samples, kernel) + mmd.backward() + optimiser.step() + # After training, the encoder should map data points to spin strings that match the samples + # from the prior GRBM. Since the prior GRBM is uniform over spin strings of length 2, the + # encoder should map the four data points to the four spin strings (in any order). + _, discretes, _ = dvae(self.data, n_samples=1) + discretes = discretes.squeeze(1) + discretes_set = {tuple(row.tolist()) for row in discretes} + self.assertEqual(discretes_set, expected_set) + @parameterized.expand([1, 2]) def test_train(self, n_latent_dims): """Test training simple dataset.""" diff --git a/tests/test_functional.py b/tests/test_functional.py new file mode 100755 index 0000000..17f81f4 --- /dev/null +++ b/tests/test_functional.py @@ -0,0 +1,79 @@ +# Copyright 2025 D-Wave +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import torch + +from dwave.plugins.torch.nn.functional import maximum_mean_discrepancy_loss as mmd_loss +from dwave.plugins.torch.nn.modules.kernels import Kernel + + +class TestMaximumMeanDiscrepancyLoss(unittest.TestCase): + def test_mmd_loss_constant(self): + x = torch.tensor([[1.2], [4.1]]) + y = torch.tensor([[0.3], [0.5]]) + + class Constant(Kernel): + def __init__(self): + super().__init__() + self.k = torch.tensor([[10, 4, 0, 1], + [4, 10, 4, 2], + [0, 4, 10, 3], + [1, 2, 3, 10]]).float() + + def _kernel(self, x, y): + return self.k + # The resulting kernel matrix will be constant, so (averages) KXX = KYY = 2KXY + kernel = Constant() + # kxx = (4 + 4)/2 + # kyy = (3 + 3)/2 + # kxy = (0 + 1 + 4 + 2)/4 + # kxx + kyy -2kxy = 4 + 3 - 3.5 = 3.5 + self.assertEqual(3.5, mmd_loss(x, y, kernel)) + + def test_sample_size_error(self): + x = torch.tensor([[1.2], [4.1]]) + y = torch.tensor([[0.3]]) + self.assertRaisesRegex(ValueError, "must be at least two", mmd_loss, x, y, None) + + def test_mmd_loss_dim_mismatch(self): + x = torch.tensor([[1], [4]], dtype=torch.float32) + y = torch.tensor([[0.1, 0.2, 0.3], + [0.4, 0.5, 0.6]]) + self.assertRaisesRegex(ValueError, "Input dimensions must match. You are trying to compute ", mmd_loss, x, y, None) + + def test_mmd_loss_arange(self): + x = torch.tensor([[1.0], [4.0], [5.0]]) + y = torch.tensor([[0.3], [0.4]]) + + class Constant(Kernel): + def _kernel(self, x, y): + return torch.tensor([[150, 22, 39, 34, 28], + [22, 630, 98, 56, 44], + [39, 98, 560, 78, 33], + [-99, -99, -99, 299, 13], + [-99, -99, -99, 13, 970]], dtype=torch.float32) + + mmd_loss(x, y, Constant()) + # NOTE: calculation takes kxy = upper-right corner; no PSD assumption + # kxx = (22+39+98)/3 + # kyy = 13 + # kxy = (34+28+56+44+78+33)/6 + # kxx + kyy - 2*kxy + # kxx + kyy - 2*kxy = -25.0 + self.assertEqual(-25, mmd_loss(x, y, Constant())) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_kernels.py b/tests/test_kernels.py new file mode 100755 index 0000000..278ca53 --- /dev/null +++ b/tests/test_kernels.py @@ -0,0 +1,110 @@ +import unittest + +import torch +from parameterized import parameterized + +from dwave.plugins.torch.nn.modules.kernels import Kernel, GaussianKernel + + +class TestKernel(unittest.TestCase): + def test_forward(self): + class One(Kernel): + def _kernel(self, x, y): + return 1 + one = One() + x = torch.rand((5, 3)) + y = torch.randn((9, 3)) + self.assertEqual(1, one(x, y)) + + @parameterized.expand([(1, 2), (2, 1)]) + def test_sample_size(self, nx, ny): + class One(Kernel): + def _kernel(self, x, y): + return 1 + one = One() + x = torch.rand((nx, 5)) + y = torch.randn((ny, 5)) + self.assertRaisesRegex(ValueError, "must be at least two", one, x, y) + + def test_shape_mismatch(self): + class One(Kernel): + def _kernel(self, x, y): + return 1 + one = One() + x = torch.rand((5, 4)) + y = torch.randn((9, 3)) + self.assertRaisesRegex(ValueError, "Input dimensions must match", one, x, y) + +class TestGaussianKernel(unittest.TestCase): + + def test_has_config(self): + rbf = GaussianKernel(5, 2.1, 0.1) + self.assertDictEqual(dict(rbf.config), dict(module_name="GaussianKernel", + n_kernels=5, factor=2.1, bandwidth=0.1)) + + @parameterized.expand([ + (torch.randn((5, 12)), torch.rand((7, 12))), + (torch.randn((5, 12, 34)), torch.rand((7, 12, 34))), + ]) + def test_shape(self, x, y): + rbf = GaussianKernel(2, 2.1, 0.1) + k = rbf(x, y) + self.assertEqual(tuple(k.shape), (x.shape[0], y.shape[0])) + + def test_get_bandwidth_default(self): + rbf = GaussianKernel(2, 2.1, 0.1) + d = torch.tensor(123) + self.assertEqual(0.1, rbf._get_bandwidth(d)) + + def test_get_bandwidth(self): + rbf = GaussianKernel(2, 2.1, None) + d = torch.tensor([[0.0, 3.4,], [3.4, 0.0]]) + self.assertEqual(3.4, rbf._get_bandwidth(d)) + + def test_get_bandwidth_no_grad(self): + rbf = GaussianKernel(2, 2.1, None) + d = torch.tensor([[0.0, 3.4,], [3.4, 0.0]], requires_grad=True) + self.assertEqual(3.4, rbf._get_bandwidth(d)) + self.assertIsNone(rbf._get_bandwidth(d).grad) + + def test_single_factors(self): + rbf = GaussianKernel(1, 2.1, None) + self.assertListEqual(rbf.factors.tolist(), [1.0]) + + def test_two_factors(self): + rbf = GaussianKernel(2, 2.1, None) + torch.testing.assert_close(torch.tensor([2.1**-1, 1]), rbf.factors) + + def test_three_factors(self): + rbf = GaussianKernel(3, 2.1, None) + torch.testing.assert_close(torch.tensor([2.1**-1, 1, 2.1]), rbf.factors) + + def test_kernel(self): + x = torch.tensor([[1.0, 1.0], + [2.0, 3.0]], requires_grad=True) + y = torch.tensor([[0.0, 1.0], + [-3.0, 5.0], + [1.2, 9.0]], requires_grad=True) + dist = torch.cdist(x, y) + + with self.subTest("Adaptive bandwidth"): + rbf = GaussianKernel(1, 2.1, None) + bandwidths = rbf._get_bandwidth(dist) * rbf.factors + manual = torch.exp(-dist/bandwidths) + torch.testing.assert_close(manual, rbf(x, y)) + + with self.subTest("Simple bandwidth"): + rbf = GaussianKernel(1, 2.1, 12.34) + bandwidths = 12.34 * rbf.factors + manual = torch.exp(-dist/bandwidths) + torch.testing.assert_close(manual, rbf(x, y)) + + with self.subTest("Multiple kernels"): + rbf = GaussianKernel(3, 2.1, 123) + bandwidths = rbf._get_bandwidth(dist) * rbf.factors + manual = torch.exp(-dist/bandwidths.reshape(-1, 1, 1)).sum(0) + torch.testing.assert_close(manual, rbf(x, y)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_loss.py b/tests/test_loss.py new file mode 100755 index 0000000..e59dbe3 --- /dev/null +++ b/tests/test_loss.py @@ -0,0 +1,47 @@ +# Copyright 2025 D-Wave +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import unittest + +import torch +from parameterized import parameterized + +from dwave.plugins.torch.nn.functional import maximum_mean_discrepancy_loss as mmd_loss +from dwave.plugins.torch.nn.modules.kernels import Kernel +from dwave.plugins.torch.nn.modules.loss import MaximumMeanDiscrepancyLoss as MMDLoss + + +class TestMaximumMeanDiscrepancyLoss(unittest.TestCase): + @parameterized.expand([ + (torch.tensor([[1.2], [4.1]]), torch.tensor([[0.3], [0.5]])), + (torch.randn((123, 4, 3, 2)), torch.rand(100, 4, 3, 2)), + ]) + def test_mmd_loss(self, x, y): + class Constant(Kernel): + def __init__(self): + super().__init__() + self.k = torch.tensor([[10, 4, 0, 1], + [4, 10, 4, 2], + [0, 4, 10, 3], + [1, 2, 3, 10]]).float() + + def _kernel(self, x, y): + return self.k + # The resulting kernel matrix will be constant, so (averages) KXX = KYY = 2KXY + kernel = Constant() + compute_mmd = MMDLoss(kernel) + torch.testing.assert_close(mmd_loss(x, y, kernel), compute_mmd(x, y)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_nn.py b/tests/test_nn.py index c84929d..bac40c9 100755 --- a/tests/test_nn.py +++ b/tests/test_nn.py @@ -1,3 +1,16 @@ +# Copyright 2025 D-Wave +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import unittest import torch