diff --git a/README.md b/README.md
index fd00bcaa6bd..c77320656fd 100644
--- a/README.md
+++ b/README.md
@@ -39,13 +39,15 @@ ______________________________________________________________________
 
 # Looking for GPUs?
-Over 340,000 developers use [Lightning Cloud](https://lightning.ai/?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme) - purpose-built for PyTorch and PyTorch Lightning.
-- [GPUs](https://lightning.ai/pricing?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme) from $0.19.
-- [Clusters](https://lightning.ai/clusters?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme): frontier-grade training/inference clusters.
+
+Over 340,000 developers use [Lightning Cloud](https://lightning.ai/?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme) - purpose-built for PyTorch and PyTorch Lightning.
+
+- [GPUs](https://lightning.ai/pricing?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme) from $0.19.
+- [Clusters](https://lightning.ai/clusters?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme): frontier-grade training/inference clusters.
 - [AI Studio (vibe train)](https://lightning.ai/studios?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme): workspaces where AI helps you debug, tune and vibe train.
-- [AI Studio (vibe deploy)](https://lightning.ai/studios?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme): workspaces where AI helps you optimize, and deploy models.
+- [AI Studio (vibe deploy)](https://lightning.ai/studios?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme): workspaces where AI helps you optimize and deploy models.
 - [Notebooks](https://lightning.ai/notebooks?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme): Persistent GPU workspaces where AI helps you code and analyze.
-- [Inference](https://lightning.ai/deploy?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme): Deploy models as inference APIs.
+- [Inference](https://lightning.ai/deploy?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme): Deploy models as inference APIs.
 
 # Installation
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 7a8b1f5ff7d..bf766861875 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -253,6 +253,14 @@ Or directly from conda
    :caption: Video
    :glob:
 
    video/*
 
+.. toctree::
+   :maxdepth: 2
+   :name: timeseries
+   :caption: Time Series
+   :glob:
+
+   timeseries/*
+
 .. toctree::
diff --git a/docs/source/links.rst b/docs/source/links.rst
index 539d2728e74..3b89f00cd02 100644
--- a/docs/source/links.rst
+++ b/docs/source/links.rst
@@ -184,3 +184,4 @@
 .. _Deep Image Structure and Texture Similarity: https://arxiv.org/abs/2004.07728
 .. _KonIQ-10k: https://database.mmsp-kn.de/koniq-10k-database.html
 .. _KADID-10k: https://database.mmsp-kn.de/kadid-10k-database.html
+.. _SoftDTW: https://arxiv.org/abs/1703.01541
diff --git a/docs/source/timeseries/softdtw.rst b/docs/source/timeseries/softdtw.rst
new file mode 100644
index 00000000000..d2d99958645
--- /dev/null
+++ b/docs/source/timeseries/softdtw.rst
@@ -0,0 +1,22 @@
+.. customcarditem::
+   :header: Soft Dynamic Time Warping
+   :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg
+   :tags: timeseries
+
+.. include:: ../links.rst
+
+#########################
+Soft Dynamic Time Warping
+#########################
+
+Module Interface
+________________
+
+.. autoclass:: torchmetrics.timeseries.SoftDTW
+    :exclude-members: update, compute
+
+
+Functional Interface
+____________________
+
+.. autofunction:: torchmetrics.functional.timeseries.soft_dtw
diff --git a/requirements/timeseries_test.txt b/requirements/timeseries_test.txt
new file mode 100644
index 00000000000..62d6031ab0f
--- /dev/null
+++ b/requirements/timeseries_test.txt
@@ -0,0 +1 @@
+pysdtw==0.0.5
diff --git a/src/torchmetrics/functional/timeseries/__init__.py b/src/torchmetrics/functional/timeseries/__init__.py
new file mode 100644
index 00000000000..81ef5187653
--- /dev/null
+++ b/src/torchmetrics/functional/timeseries/__init__.py
@@ -0,0 +1,16 @@
+# Copyright The Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from torchmetrics.functional.timeseries.softdtw import soft_dtw
+
+__all__ = ["soft_dtw"]
diff --git a/src/torchmetrics/functional/timeseries/softdtw.py b/src/torchmetrics/functional/timeseries/softdtw.py
new file mode 100644
index 00000000000..74964e154d1
--- /dev/null
+++ b/src/torchmetrics/functional/timeseries/softdtw.py
@@ -0,0 +1,163 @@
+# Copyright The Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+from typing import Callable, Literal, Optional
+
+import torch
+from torch import Tensor
+
+
+def _soft_dtw_validate_args(
+    preds: Tensor, target: Tensor, gamma: float, reduction: Literal["mean", "sum", "none"]
+) -> None:
+    """Validate the input arguments for the soft_dtw function."""
+    valid_reduction = ("mean", "sum", "none")
+    if reduction not in valid_reduction:
+        raise ValueError(f"Argument `reduction` must be one of {valid_reduction}, but got {reduction}")
+    if preds.ndim != 3 or target.ndim != 3:
+        raise ValueError("Inputs preds and target must be 3-dimensional tensors of shape [B, N, D] and [B, M, D].")
+    if preds.shape[0] != target.shape[0]:
+        raise ValueError("Batch size of preds and target must be the same.")
+    if preds.shape[2] != target.shape[2]:
+        raise ValueError("Feature dimension of preds and target must be the same.")
+    if not isinstance(gamma, float) or gamma <= 0:
+        raise ValueError(f"Argument `gamma` must be a positive float, got {gamma}")
+
+
+def _soft_dtw_update(preds: Tensor, target: Tensor, gamma: float, distance_fn: Optional[Callable] = None) -> Tensor:
+    """Compute the Soft-DTW distance between two batched sequences."""
+    b, n, _ = preds.shape
+    _, m, _ = target.shape
+    # Use the device/dtype of `preds` for the cost matrix, since `target` is cast to `preds.dtype` below.
+    device, dtype = preds.device, preds.dtype
+    if preds.dtype != target.dtype:
+        target = target.to(preds.dtype)
+
+    if distance_fn is None:
+
+        def distance_fn(x: Tensor, y: Tensor) -> Tensor:
+            """Default to squared Euclidean distance."""
+            return torch.cdist(x, y, p=2).pow(2)
+
+    distances = distance_fn(preds, target)  # [B, N, M]
+
+    r = torch.ones((b, n + 2, m + 2), device=device, dtype=dtype) * math.inf
+    r[:, 0, 0] = 0.0
+
+    def softmin(a: Tensor, b: Tensor, c: Tensor, gamma: float) -> Tensor:
+        """Compute the soft minimum of three tensors."""
+        vals = torch.stack([a, b, c], dim=-1)
+        return -gamma * torch.logsumexp(-vals / gamma, dim=-1)
+
+    # Anti-diagonal approach inspired by https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=8400444
+    for k in range(2, n + m + 1):
+        i_vals = torch.arange(1, n + 1, device=device)
+        j_vals = k - i_vals
+        mask = (j_vals >= 1) & (j_vals <= m)
+        i_vals = i_vals[mask]
+        j_vals = j_vals[mask]
+
+        if len(i_vals) == 0:
+            continue
+
+        r1 = r[:, i_vals - 1, j_vals - 1]
+        r2 = r[:, i_vals - 1, j_vals]
+        r3 = r[:, i_vals, j_vals - 1]
+        r[:, i_vals, j_vals] = distances[:, i_vals - 1, j_vals - 1] + softmin(r1, r2, r3, gamma)
+
+    return r[:, n, m]
+
+
+def _soft_dtw_compute(scores: Tensor, reduction: Literal["sum", "mean", "none"] = "mean") -> Tensor:
+    """Aggregate the computed Soft-DTW distances based on the specified reduction method."""
+    if reduction == "none":
+        return scores
+    if reduction == "mean":
+        return scores.mean()
+    return scores.sum()
+
+
+def soft_dtw(
+    preds: Tensor,
+    target: Tensor,
+    gamma: float = 1.0,
+    distance_fn: Optional[Callable] = None,
+    reduction: Literal["sum", "mean", "none"] = "mean",
+) -> Tensor:
+    r"""Compute the Soft Dynamic Time Warping (Soft-DTW) distance between two batched sequences.
+
+    This is a differentiable relaxation of the classic Dynamic Time Warping (DTW) algorithm, introduced by
+    Marco Cuturi and Mathieu Blondel (2017).
+    It replaces the hard minimum in the DTW recursion with a soft minimum using a log-sum-exp formulation:
+
+    .. math::
+        \text{softmin}_\gamma(a,b,c) = -\gamma \log \left( e^{-a/\gamma} + e^{-b/\gamma} + e^{-c/\gamma} \right)
+
+    The Soft-DTW recurrence is then defined as:
+
+    .. math::
+        R_{i,j} = D_{i,j} + \text{softmin}_\gamma(R_{i-1,j}, R_{i,j-1}, R_{i-1,j-1})
+
+    where :math:`D_{i,j}` is the pairwise distance between sequence elements :math:`x_i` and :math:`y_j`. It can be
+    computed with any differentiable distance function, such as the squared Euclidean or cosine distance.
+
+    The final Soft-DTW distance is :math:`R_{N,M}`.
+
+    Args:
+        preds: Tensor of shape ``[B, N, D]`` — batch of input sequences.
+        target: Tensor of shape ``[B, M, D]`` — batch of target sequences.
+        gamma: Smoothing parameter (:math:`\gamma > 0`).
+            Smaller values make the loss closer to standard DTW (hard minimum),
+            while larger values produce a smoother and more differentiable surface.
+        distance_fn: Optional callable ``(x, y) -> [B, N, M]`` defining the pairwise distance matrix.
+            If ``None``, defaults to squared Euclidean distance.
+        reduction: indicates how to reduce over the batch dimension. Choose between [``sum``, ``mean``, ``none``].
+            Defaults to ``mean``.
+
+    Returns:
+        If ``reduction`` is ``"none"``, a tensor of shape ``[B]`` with the Soft-DTW distance for each sequence pair
+        in the batch; otherwise, a scalar tensor with the mean or sum over the batch.
+
+    Raises:
+        ValueError:
+            If ``reduction`` is not one of [``sum``, ``mean``, ``none``].
+        ValueError:
+            If ``gamma`` is not a positive float.
+        ValueError:
+            If ``preds`` and ``target`` are not 3-dimensional tensors
+            with the same batch size and feature dimension.
+
+    Example::
+        >>> import torch
+        >>> from torchmetrics.functional.timeseries import soft_dtw
+        >>>
+        >>> x = torch.tensor([[[0.0], [1.0], [2.0]]])  # [B, N, D]
+        >>> y = torch.tensor([[[0.0], [2.0], [3.0]]])  # [B, M, D]
+        >>> soft_dtw(x, y, gamma=0.1)
+        tensor(1.8901)
+
+    Example (custom distance function)::
+        >>> def cosine_dist(a, b):
+        ...     a = torch.nn.functional.normalize(a, dim=-1)
+        ...     b = torch.nn.functional.normalize(b, dim=-1)
+        ...     return 1 - torch.bmm(a, b.transpose(1, 2))
+        >>>
+        >>> x = torch.randn(2, 5, 3)
+        >>> y = torch.randn(2, 6, 3)
+        >>> soft_dtw(x, y, gamma=0.5, distance_fn=cosine_dist, reduction="none")
+        tensor([2.8301, 3.0128])
+
+    """
+    _soft_dtw_validate_args(preds, target, gamma, reduction)
+    scores = _soft_dtw_update(preds, target, gamma, distance_fn)
+    return _soft_dtw_compute(scores, reduction)
diff --git a/src/torchmetrics/timeseries/__init__.py b/src/torchmetrics/timeseries/__init__.py
new file mode 100644
index 00000000000..298826554a6
--- /dev/null
+++ b/src/torchmetrics/timeseries/__init__.py
@@ -0,0 +1,16 @@
+# Copyright The Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from torchmetrics.timeseries.softdtw import SoftDTW
+
+__all__ = ["SoftDTW"]
diff --git a/src/torchmetrics/timeseries/softdtw.py b/src/torchmetrics/timeseries/softdtw.py
new file mode 100644
index 00000000000..7878528663d
--- /dev/null
+++ b/src/torchmetrics/timeseries/softdtw.py
@@ -0,0 +1,163 @@
+# Copyright The Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import warnings
+from collections.abc import Sequence
+from typing import Any, Callable, List, Literal, Optional, Union
+
+import torch
+from torch import Tensor
+
+from torchmetrics import Metric
+from torchmetrics.functional.timeseries.softdtw import soft_dtw
+from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
+from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
+
+if not _MATPLOTLIB_AVAILABLE:
+    __doctest_skip__ = ["SoftDTW.plot"]
+
+
+class SoftDTW(Metric):
+    r"""Compute the Soft Dynamic Time Warping (Soft-DTW) distance between two batched sequences.
+
+    This is a differentiable relaxation of the classic Dynamic Time Warping (DTW) algorithm, introduced by
+    Marco Cuturi and Mathieu Blondel (2017).
+    It replaces the hard minimum in the DTW recursion with a soft minimum using a log-sum-exp formulation:
+
+    .. math::
+        \text{softmin}_\gamma(a,b,c) = -\gamma \log \left( e^{-a/\gamma} + e^{-b/\gamma} + e^{-c/\gamma} \right)
+
+    The Soft-DTW recurrence is then defined as:
+
+    .. math::
+        R_{i,j} = D_{i,j} + \text{softmin}_\gamma(R_{i-1,j}, R_{i,j-1}, R_{i-1,j-1})
+
+    where :math:`D_{i,j}` is the pairwise distance between sequence elements :math:`x_i` and :math:`y_j`. It can be
+    computed with any differentiable distance function, such as the squared Euclidean or cosine distance.
+
+    The final Soft-DTW distance is :math:`R_{N,M}`.
+
+    Args:
+        gamma: Smoothing parameter (:math:`\gamma > 0`).
+            Smaller values make the loss closer to standard DTW (hard minimum),
+            while larger values produce a smoother and more differentiable surface.
+        distance_fn: Optional callable ``(x, y) -> [B, N, M]`` defining the pairwise distance matrix.
+            If ``None``, defaults to squared Euclidean distance.
+        reduction: indicates how to reduce over the batch dimension. Choose between [``sum``, ``mean``, ``none``].
+        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.
+
+    Raises:
+        ValueError:
+            If ``reduction`` is not one of [``sum``, ``mean``, ``none``].
+        ValueError:
+            If ``gamma`` is not a positive float.
+        ValueError:
+            If input tensors to ``update`` are not 3-dimensional
+            with the same batch size and feature dimension.
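+
+    .. note::
+        The recurrence is evaluated over the ``N + M - 1`` anti-diagonals of the cost matrix, with each sweep
+        vectorized over the batch, so runtime scales as :math:`O(N \cdot M)`. On CPU this can be slow for long
+        sequences, which is why a warning is emitted at construction time.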
+
+    Example:
+        >>> from torch import randn
+        >>> from torchmetrics.timeseries import SoftDTW
+        >>> metric = SoftDTW(gamma=0.1)
+        >>> x = randn(10, 50, 2)
+        >>> y = randn(10, 60, 2)
+        >>> metric(x, y)
+        tensor(43.2051)
+
+    """
+
+    full_state_update: bool = False
+    is_differentiable: bool = True
+    higher_is_better: bool = False
+    plot_lower_bound: float = 0.0
+
+    pred_list: List[Tensor]
+    gt_list: List[Tensor]
+
+    def __init__(
+        self,
+        distance_fn: Optional[Callable] = None,
+        gamma: float = 1.0,
+        reduction: Literal["sum", "mean", "none"] = "mean",
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.distance_fn = distance_fn
+        valid_reduction = ("mean", "sum", "none")
+        if reduction not in valid_reduction:
+            raise ValueError(f"Argument `reduction` must be one of {valid_reduction}, but got {reduction}")
+        if gamma <= 0:
+            raise ValueError(f"Argument `gamma` must be a positive float, got {gamma}")
+        self.gamma = gamma
+        self.reduction = reduction
+
+        if self.device.type == "cpu":  # warn on cpu
+            warnings.warn("SoftDTW is slow on CPU. Consider using a GPU.", stacklevel=2)
+
+        self.add_state("pred_list", default=[], dist_reduce_fx="cat")
+        self.add_state("gt_list", default=[], dist_reduce_fx="cat")
+
+    def update(self, x: torch.Tensor, y: torch.Tensor) -> None:
+        """Update internal states with a new batch of input and target sequences."""
+        self.pred_list.append(x)
+        self.gt_list.append(y)
+
+    def compute(self) -> torch.Tensor:
+        """Compute the Soft-DTW distance over all collected sequences."""
+        return soft_dtw(
+            torch.cat(self.pred_list, dim=0),
+            torch.cat(self.gt_list, dim=0),
+            gamma=self.gamma,
+            distance_fn=self.distance_fn,
+            reduction=self.reduction,
+        )
+
+    def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
+        """Plot a single or multiple values from the metric.
+
+        Args:
+            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
+                If no value is provided, will automatically call `metric.compute` and plot that result.
+            ax: A matplotlib axis object. If provided, will add the plot to that axis.
+
+        Returns:
+            Figure and Axes object
+
+        Raises:
+            ModuleNotFoundError:
+                If `matplotlib` is not installed
+
+        .. plot::
+            :scale: 75
+
+            >>> # Example plotting a single value
+            >>> import torch
+            >>> from torchmetrics.timeseries import SoftDTW
+            >>> metric = SoftDTW()
+            >>> metric.update(torch.randn(10, 100, 2), torch.randn(10, 50, 2))
+            >>> fig_, ax_ = metric.plot()
+
+        .. plot::
+            :scale: 75
+
+            >>> # Example plotting multiple values
+            >>> import torch
+            >>> from torchmetrics.timeseries import SoftDTW
+            >>> metric = SoftDTW()
+            >>> values = []
+            >>> for _ in range(10):
+            ...     values.append(metric(torch.randn(10, 100, 2), torch.randn(10, 50, 2)))
+            >>> fig_, ax_ = metric.plot(values)
+
+        """
+        return self._plot(val, ax)
diff --git a/tests/unittests/timeseries/test_softdtw.py b/tests/unittests/timeseries/test_softdtw.py
new file mode 100644
index 00000000000..9e26ac0a7ee
--- /dev/null
+++ b/tests/unittests/timeseries/test_softdtw.py
@@ -0,0 +1,147 @@
+# Copyright The Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from functools import partial
+
+import pysdtw
+import pytest
+import torch
+from unittests import BATCH_SIZE, NUM_BATCHES, _Input
+from unittests._helpers import seed_all
+from unittests._helpers.testers import MetricTester
+
+from torchmetrics.functional.timeseries.softdtw import soft_dtw
+from torchmetrics.timeseries.softdtw import SoftDTW
+
+seed_all(42)
+
+_inputs = _Input(
+    preds=torch.randn(NUM_BATCHES, BATCH_SIZE, 20, 3, dtype=torch.float64),
+    target=torch.randn(NUM_BATCHES, BATCH_SIZE, 30, 3, dtype=torch.float64),
+)
+
+
+def _reference_softdtw(
+    preds: torch.Tensor, target: torch.Tensor, gamma: float = 1.0, distance_fn=None, reduction: str = "mean"
+) -> torch.Tensor:
+    """Reference implementation based on the ``pysdtw`` package."""
+    preds = preds.to("cuda" if torch.cuda.is_available() else "cpu")
+    target = target.to("cuda" if torch.cuda.is_available() else "cpu")
+    sdtw = pysdtw.SoftDTW(gamma=gamma, dist_func=distance_fn, use_cuda=bool(torch.cuda.is_available()))
+    if reduction == "mean":
+        return sdtw(preds, target).mean()
+    if reduction == "sum":
+        return sdtw(preds, target).sum()
+    return sdtw(preds, target)
+
+
+def euclidean_distance(x, y):
+    """Euclidean (L2) distance."""
+    return torch.cdist(x, y, p=2)
+
+
+def manhattan_distance(x, y):
+    """L1 (Manhattan) distance."""
+    return torch.cdist(x, y, p=1)
+
+
+def cosine_distance(x, y):
+    """Cosine distance."""
+    x_norm = x / x.norm(dim=-1, keepdim=True)
+    y_norm = y / y.norm(dim=-1, keepdim=True)
+    return 1 - torch.matmul(x_norm, y_norm.transpose(-1, -2))
+
+
+@pytest.mark.parametrize(("preds", "target"), [(_inputs.preds, _inputs.target)])
+class TestSoftDTW(MetricTester):
+    """Test class for `SoftDTW` metric."""
+
+    @pytest.mark.parametrize("gamma", [0.1, 0.5, 1.0])
+    @pytest.mark.parametrize("distance_fn", [euclidean_distance, manhattan_distance, cosine_distance])
+    @pytest.mark.parametrize("reduction", ["mean", "sum", "none"])
+    @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False])
+    def test_softdtw_class(self, gamma, preds, target, distance_fn, reduction, ddp):
+        """Test class implementation of SoftDTW."""
+        self.run_class_metric_test(
+            ddp,
+            preds,
+            target,
+            SoftDTW,
+            partial(_reference_softdtw, gamma=gamma, distance_fn=distance_fn, reduction=reduction),
+            metric_args={"gamma": gamma, "distance_fn": distance_fn, "reduction": reduction},
+        )
+
+    @pytest.mark.parametrize("gamma", [0.1, 0.5, 1.0])
+    @pytest.mark.parametrize("distance_fn", [euclidean_distance, manhattan_distance, cosine_distance])
+    @pytest.mark.parametrize("reduction", ["mean", "sum", "none"])
+    def test_softdtw_functional(self, preds, target, gamma, distance_fn, reduction):
+        """Test functional implementation of SoftDTW."""
+        self.run_functional_metric_test(
+            preds,
+            target,
+            metric_functional=soft_dtw,
+            reference_metric=partial(_reference_softdtw, gamma=gamma, distance_fn=distance_fn, reduction=reduction),
+            metric_args={"gamma": gamma, "distance_fn": distance_fn, "reduction": reduction},
+        )
+
+    @pytest.mark.skipif(not torch.cuda.is_available(), reason="test is too slow without gpu")
+    def test_softdtw_differentiability(self, preds, target):
+        """Test the differentiability of the metric, according to its `is_differentiable` attribute."""
+        self.run_differentiability_test(
+            preds=preds,
+            target=target,
+            metric_module=SoftDTW,
+            metric_functional=soft_dtw,
+            metric_args={"gamma": 1.0},
+        )
+
+
+def test_wrong_dimensions():
+    """Test that an error is raised if input tensors have wrong dimensions."""
+    metric = SoftDTW()
+    with pytest.raises(ValueError, match="Inputs preds and target must be 3-dimensional tensors of shape.*"):
+        metric(torch.randn(10, 100), torch.randn(10, 100, 3))
+
+
+def test_mismatched_dimensions():
+    """Test that an error is raised if input batch dimensions don't match."""
+    metric = SoftDTW()
+    with pytest.raises(ValueError, match="Batch size of preds and target must be the same.*"):
+        metric(torch.randn(10, 80, 3), torch.randn(12, 100, 3))
+
+
+def test_mismatched_feature_dimensions():
+    """Test that an error is raised if input feature dimensions don't match."""
+    metric = SoftDTW()
+    with pytest.raises(ValueError, match="Feature dimension of preds and target must be the same.*"):
+        metric(torch.randn(10, 80, 3), torch.randn(10, 100, 4))
+
+
+def test_invalid_gamma():
+    """Test that an error is raised if gamma is not a positive float."""
+    with pytest.raises(ValueError, match="Argument `gamma` must be a positive float, got -1.0.*"):
+        SoftDTW(gamma=-1.0)
+
+
+def test_warning_on_cpu():
+    """Test that a warning is raised if SoftDTW is used on CPU."""
+    if torch.cuda.is_available():
+        pytest.skip("Test only runs on CPU.")
+    with pytest.warns(UserWarning, match="SoftDTW is slow on CPU. Consider using a GPU.*"):
+        SoftDTW()
+
+
+def test_invalid_reduction():
+    """Test that an error is raised if reduction is not one of [``sum``, ``mean``, ``none``]."""
+    with pytest.raises(ValueError, match="Argument `reduction` must be one of .*"):
+        SoftDTW(reduction="invalid")
diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py
index 4d973a790c9..8deeca9322a 100644
--- a/tests/unittests/utilities/test_plot.py
+++ b/tests/unittests/utilities/test_plot.py
@@ -183,6 +183,7 @@
     WordInfoLost,
     WordInfoPreserved,
 )
+from torchmetrics.timeseries import SoftDTW
 from torchmetrics.utilities.plot import _get_col_row_split
 from torchmetrics.wrappers import (
     BootStrapper,
@@ -692,6 +693,12 @@
         lambda: torch.randn(10, 100, 3),
         id="lip vertex error",
     ),
+    pytest.param(
+        SoftDTW,
+        lambda: torch.randn(10, 4, 3),
+        lambda: torch.randn(10, 6, 3),
+        id="soft dtw",
+    ),
 ],
 )
 @pytest.mark.parametrize("num_vals", [1, 3])
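Reviewer note (illustrative, not part of the diff): the vectorized anti-diagonal sweep in `_soft_dtw_update` is the trickiest part of this PR to verify by eye. Below is a minimal sketch of the same recurrence written with explicit loops, for a single sequence pair and the default squared-Euclidean cost; `naive_soft_dtw` is a hypothetical helper name used only here. With a batch of one, it should agree with the new `soft_dtw` up to numerical precision.

```python
import math

import torch


def naive_soft_dtw(x: torch.Tensor, y: torch.Tensor, gamma: float = 1.0) -> torch.Tensor:
    """Soft-DTW for a single pair of sequences x: [N, D] and y: [M, D]."""
    n, m = x.shape[0], y.shape[0]
    # Squared Euclidean cost matrix, matching the PR's default distance_fn.
    dist = torch.cdist(x.unsqueeze(0), y.unsqueeze(0), p=2).pow(2).squeeze(0)  # [N, M]
    # R has an extra border of +inf so the recurrence is well defined at i=1 / j=1.
    r = torch.full((n + 1, m + 1), math.inf, dtype=x.dtype)
    r[0, 0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            # softmin over the three predecessors, as in the docstring recurrence.
            prev = torch.stack([r[i - 1, j - 1], r[i - 1, j], r[i, j - 1]])
            r[i, j] = dist[i - 1, j - 1] - gamma * torch.logsumexp(-prev / gamma, dim=0)
    return r[n, m]


# Cross-check against the batched implementation added in this PR:
# from torchmetrics.functional.timeseries import soft_dtw
# x, y = torch.randn(7, 3), torch.randn(9, 3)
# assert torch.allclose(naive_soft_dtw(x, y, gamma=0.5), soft_dtw(x[None], y[None], gamma=0.5))
```

The anti-diagonal version computes exactly these cells, but updates all cells with constant `i + j` in one vectorized step, since none of them depend on each other.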