From 3a5444c4924d1abe8b16f87b72effe36c8abffbc Mon Sep 17 00:00:00 2001 From: VijayVignesh1 Date: Mon, 6 Oct 2025 17:32:25 -0400 Subject: [PATCH 1/7] Initial Commit --- .../functional/timeseries/__init__.py | 17 ++ .../functional/timeseries/softdtw.py | 116 ++++++++++++++ src/torchmetrics/timeseries/__init__.py | 16 ++ src/torchmetrics/timeseries/softdtw.py | 146 ++++++++++++++++++ 4 files changed, 295 insertions(+) create mode 100644 src/torchmetrics/functional/timeseries/__init__.py create mode 100644 src/torchmetrics/functional/timeseries/softdtw.py create mode 100644 src/torchmetrics/timeseries/__init__.py create mode 100644 src/torchmetrics/timeseries/softdtw.py diff --git a/src/torchmetrics/functional/timeseries/__init__.py b/src/torchmetrics/functional/timeseries/__init__.py new file mode 100644 index 00000000000..5f140537f78 --- /dev/null +++ b/src/torchmetrics/functional/timeseries/__init__.py @@ -0,0 +1,17 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from torchmetrics.functional.timeseries.softdtw import soft_dtw + +__all__ = ["soft_dtw"] + diff --git a/src/torchmetrics/functional/timeseries/softdtw.py b/src/torchmetrics/functional/timeseries/softdtw.py new file mode 100644 index 00000000000..e783c35dee2 --- /dev/null +++ b/src/torchmetrics/functional/timeseries/softdtw.py @@ -0,0 +1,116 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import Optional, Callable +import math + +import torch +from torch import Tensor + +def _soft_dtw_validate_args(x: Tensor, y: Tensor, gamma: float) -> None: + """Validate the input arguments for the soft_dtw function.""" + if x.ndim != 3 or y.ndim != 3: + raise ValueError("Inputs x and y must be 3-dimensional tensors of shape [B, N, D] and [B, M, D].") + if x.shape[0] != y.shape[0]: + raise ValueError("Batch size of x and y must be the same.") + if x.shape[2] != y.shape[2]: + raise ValueError("Feature dimension of x and y must be the same.") + if not isinstance(gamma, float) or gamma <= 0: + raise ValueError("Gamma must be a positive float.") + +def _soft_dtw_compute(x: Tensor, y: Tensor, gamma: float, distance_fn: Optional[Callable] = None) -> Tensor: + """Compute the Soft-DTW distance between two batched sequences.""" + B, N, D = x.shape + _, M, _ = y.shape + device, dtype = x.device, x.dtype + + if distance_fn is None: + def distance_fn(a, b): + return torch.cdist(a, b, p=2).pow(2) + + D = distance_fn(x, y) # [B, N, M] + + R = torch.ones((B, N + 2, M + 2), device=device, dtype=dtype) * math.inf + R[:, 0, 0] = 0.0 + + def softmin(a, b, c, gamma): + vals = torch.stack([a, b, c], dim=-1) + return -gamma * torch.logsumexp(-vals / gamma, dim=-1) + + for i in range(1, N + 1): + for j in range(1, M + 1): + r1 = R[:, i-1, j-1] + r2 = R[:, i-1, j] + r3 = R[:, i, j-1] + R[:, i, j] = D[:, i-1, j-1] + softmin(r1, r2, r3, gamma) + + return R[:, N, M] + +def soft_dtw( + x: Tensor, + y: Tensor, + gamma: float = 1.0, + distance_fn=None, +) -> Tensor: + r""" + Compute the **Soft Dynamic Time Warping (Soft-DTW)** distance between two batched sequences. + + This is a differentiable relaxation of the classic Dynamic Time Warping (DTW) algorithm, introduced by + Marco Cuturi and Mathieu Blondel (2017). + It replaces the hard minimum in DTW recursion with a *soft-minimum* using a log-sum-exp formulation: + + .. math:: + \text{softmin}_\gamma(a,b,c) = -\gamma \log \left( e^{-a/\gamma} + e^{-b/\gamma} + e^{-c/\gamma} \right) + + The Soft-DTW recurrence is then defined as: + + .. math:: + R_{i,j} = D_{i,j} + \text{softmin}_\gamma(R_{i-1,j}, R_{i,j-1}, R_{i-1,j-1}) + + where :math:`D_{i,j}` is the pairwise distance between sequence elements :math:`x_i` and :math:`y_j`. + + The final Soft-DTW distance is :math:`R_{N,M}`. + + Args: + x: Tensor of shape ``[B, N, D]`` — batch of input sequences. + y: Tensor of shape ``[B, M, D]`` — batch of target sequences. + gamma: Smoothing parameter (:math:`\gamma > 0`). Smaller values make the loss closer to standard DTW (hard minimum), + while larger values produce a smoother and more differentiable surface. + distance_fn: Optional callable ``(x, y) -> [B, N, M]`` defining the pairwise distance matrix. + If ``None``, defaults to **squared Euclidean distance**. + + Returns: + A tensor of shape ``[B]`` containing the Soft-DTW distance for each sequence pair in the batch. + + Example:: + >>> import torch + >>> from torchmetrics.functional.timeseries import soft_dtw + >>> + >>> x = torch.tensor([[[0.0], [1.0], [2.0]]]) # [B, N, D] + >>> y = torch.tensor([[[0.0], [2.0], [3.0]]]) # [B, M, D] + >>> soft_dtw(x, y, gamma=0.1) + tensor([0.4003]) + + Example (custom distance function):: + >>> def cosine_dist(a, b): + ... a = torch.nn.functional.normalize(a, dim=-1) + ... b = torch.nn.functional.normalize(b, dim=-1) + ... 
return 1 - torch.bmm(a, b.transpose(1, 2)) + >>> + >>> x = torch.randn(2, 5, 3) + >>> y = torch.randn(2, 6, 3) + >>> soft_dtw(x, y, gamma=0.5, distance_fn=cosine_dist) + tensor([2.8301, 3.0128]) + """ + _soft_dtw_validate_args(x, y, gamma) + return _soft_dtw_compute(x, y, gamma, distance_fn) \ No newline at end of file diff --git a/src/torchmetrics/timeseries/__init__.py b/src/torchmetrics/timeseries/__init__.py new file mode 100644 index 00000000000..298826554a6 --- /dev/null +++ b/src/torchmetrics/timeseries/__init__.py @@ -0,0 +1,16 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from torchmetrics.timeseries.softdtw import SoftDTW + +__all__ = ["SoftDTW"] diff --git a/src/torchmetrics/timeseries/softdtw.py b/src/torchmetrics/timeseries/softdtw.py new file mode 100644 index 00000000000..ee3bad9f66b --- /dev/null +++ b/src/torchmetrics/timeseries/softdtw.py @@ -0,0 +1,146 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections.abc import Sequence +from typing import Any, Optional, Union, Callable, List + +import torch +from torch import Tensor +from typing_extensions import Literal + +from torchmetrics import Metric +from torchmetrics.functional.timeseries.softdtw import soft_dtw +from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE +from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE + +if not _MATPLOTLIB_AVAILABLE: + __doctest_skip__ = ["SoftDTW.plot"] + + +class SoftDTW(Metric): + r"""Compute the **Soft Dynamic Time Warping (Soft-DTW)** distance between two batched sequences. + + Compute the **Soft Dynamic Time Warping (Soft-DTW)** distance between two batched sequences. + + This is a differentiable relaxation of the classic Dynamic Time Warping (DTW) algorithm, introduced by + Marco Cuturi and Mathieu Blondel (2017). + It replaces the hard minimum in DTW recursion with a *soft-minimum* using a log-sum-exp formulation: + + .. math:: + \text{softmin}_\gamma(a,b,c) = -\gamma \log \left( e^{-a/\gamma} + e^{-b/\gamma} + e^{-c/\gamma} \right) + + The Soft-DTW recurrence is then defined as: + + .. math:: + R_{i,j} = D_{i,j} + \text{softmin}_\gamma(R_{i-1,j}, R_{i,j-1}, R_{i-1,j-1}) + + where :math:`D_{i,j}` is the pairwise distance between sequence elements :math:`x_i` and :math:`y_j`. + + The final Soft-DTW distance is :math:`R_{N,M}`. + + Args: + gamma: Smoothing parameter (:math:`\gamma > 0`). 
Smaller values make the loss closer to standard DTW (hard minimum),
+            while larger values produce a smoother and more differentiable surface.
+        distance_fn: Optional callable ``(x, y) -> [B, N, M]`` defining the pairwise distance matrix.
+            If ``None``, defaults to **squared Euclidean distance**.
+
+    Raises:
+        ValueError:
+            If ``gamma`` is not a positive float.
+            If input tensors to ``update`` are not 3-dimensional with the same batch size and feature dimension.
+
+    Example:
+        >>> from torch import randn
+        >>> from torchmetrics.timeseries import SoftDTW
+        >>> metric = SoftDTW(gamma=0.1)
+        >>> x = randn(10, 50, 2)
+        >>> y = randn(10, 60, 2)
+        >>> metric(x, y)
+        tensor(43.2051)
+    """
+
+    full_state_update: bool = False
+    is_differentiable: bool = False
+    higher_is_better: bool = False
+    plot_lower_bound: float = 0.0
+    plot_upper_bound: float = 1.0
+
+    pred_list: List[Tensor]
+    gt_list: List[Tensor]
+
+    def __init__(self,
+                 distance_fn: Optional[Callable] = None,
+                 gamma: float = 1.0,
+                 **kwargs: Any) -> None:
+        super().__init__(**kwargs)
+        self.distance_fn = distance_fn
+        if gamma <= 0:
+            raise ValueError(f"Argument `gamma` must be a positive float, got {gamma}")
+        self.gamma = gamma
+
+        self.add_state("pred_list", default=[], dist_reduce_fx="cat")
+        self.add_state("gt_list", default=[], dist_reduce_fx="cat")
+
+    def update(self, x: torch.Tensor, y: torch.Tensor) -> None:
+        """Update the metric state with a new batch of input and target sequences."""
+        self.pred_list.append(x)
+        self.gt_list.append(y)
+
+    def compute(self) -> torch.Tensor:
+        """Compute the Soft-DTW distance over all accumulated sequences."""
+        return soft_dtw(
+            torch.cat(self.pred_list, dim=0),
+            torch.cat(self.gt_list, dim=0),
+            gamma=self.gamma,
+            distance_fn=self.distance_fn,
+        ).mean()
+
+    def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE:
+        """Plot a single or multiple values from the metric.
+
+        Args:
+            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
+                If no value is provided, will automatically call `metric.compute` and plot that result.
+            ax: A matplotlib axis object. If provided will add plot to that axis
+
+        Returns:
+            Figure and Axes object
+
+        Raises:
+            ModuleNotFoundError:
+                If `matplotlib` is not installed
+
+        .. plot::
+            :scale: 75
+
+            >>> # Example plotting a single value
+            >>> import torch
+            >>> from torchmetrics.timeseries import SoftDTW
+            >>> metric = SoftDTW()
+            >>> metric.update(torch.randn(10, 100, 2), torch.randn(10, 50, 2))
+            >>> fig_, ax_ = metric.plot()
+
+        .. plot::
+            :scale: 75
+
+            >>> # Example plotting multiple values
+            >>> import torch
+            >>> from torchmetrics.timeseries import SoftDTW
+            >>> metric = SoftDTW()
+            >>> values = []
+            >>> for _ in range(10):
+            ... 
values.append(metric(torch.randn(10, 100, 2), torch.randn(10, 50, 2)))
+            >>> fig_, ax_ = metric.plot(values)
+
+        """
+        return self._plot(val, ax)

From 131c611d1b2cdbff1cbeebf39d1b75139222ad1c Mon Sep 17 00:00:00 2001
From: VijayVignesh1
Date: Tue, 7 Oct 2025 17:22:34 -0400
Subject: [PATCH 2/7] Modifying the implementation and adding initial testcases

---
 docs/source/index.rst                      |  8 ++
 docs/source/links.rst                      |  1 +
 docs/source/timeseries/softdtw.rst         | 22 +++++
 .../functional/timeseries/softdtw.py       | 64 ++++++++-----
 src/torchmetrics/timeseries/softdtw.py     |  4 +-
 tests/unittests/timeseries/test_softdtw.py | 91 +++++++++++++++++++
 tests/unittests/utilities/test_plot.py     |  7 ++
 7 files changed, 173 insertions(+), 24 deletions(-)
 create mode 100644 docs/source/timeseries/softdtw.rst
 create mode 100644 tests/unittests/timeseries/test_softdtw.py

diff --git a/docs/source/index.rst b/docs/source/index.rst
index 7a8b1f5ff7d..bf766861875 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -253,6 +253,14 @@ Or directly from conda
    :caption: Video
    :glob:
 
    video/*
+
+.. toctree::
+   :maxdepth: 2
+   :name: timeseries
+   :caption: Time Series
+   :glob:
+
+   timeseries/*
 
 .. toctree::
diff --git a/docs/source/links.rst b/docs/source/links.rst
index 539d2728e74..de6701a57a9 100644
--- a/docs/source/links.rst
+++ b/docs/source/links.rst
@@ -184,3 +184,4 @@
 .. _Deep Image Structure and Texture Similarity: https://arxiv.org/abs/2004.07728
 .. _KonIQ-10k: https://database.mmsp-kn.de/koniq-10k-database.html
 .. _KADID-10k: https://database.mmsp-kn.de/kadid-10k-database.html
+.. _SoftDTW: https://arxiv.org/abs/1703.01541
\ No newline at end of file
diff --git a/docs/source/timeseries/softdtw.rst b/docs/source/timeseries/softdtw.rst
new file mode 100644
index 00000000000..d8be653ba90
--- /dev/null
+++ b/docs/source/timeseries/softdtw.rst
@@ -0,0 +1,22 @@
+.. customcarditem::
+    :header: Soft Dynamic Time Warping
+    :image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/tabular_classification.svg
+    :tags: timeseries
+
+.. include:: ../links.rst
+
+####################
+Soft Dynamic Time Warping
+####################
+
+Module Interface
+________________
+
+.. autoclass:: torchmetrics.timeseries.SoftDTW
+    :exclude-members: update, compute
+
+
+Functional Interface
+____________________
+
+.. 
autofunction:: torchmetrics.functional.timeseries.soft_dtw diff --git a/src/torchmetrics/functional/timeseries/softdtw.py b/src/torchmetrics/functional/timeseries/softdtw.py index e783c35dee2..2e0928e2d9c 100644 --- a/src/torchmetrics/functional/timeseries/softdtw.py +++ b/src/torchmetrics/functional/timeseries/softdtw.py @@ -17,28 +17,31 @@ import torch from torch import Tensor -def _soft_dtw_validate_args(x: Tensor, y: Tensor, gamma: float) -> None: +def _soft_dtw_validate_args(preds: Tensor, target: Tensor, gamma: float) -> None: """Validate the input arguments for the soft_dtw function.""" - if x.ndim != 3 or y.ndim != 3: - raise ValueError("Inputs x and y must be 3-dimensional tensors of shape [B, N, D] and [B, M, D].") - if x.shape[0] != y.shape[0]: - raise ValueError("Batch size of x and y must be the same.") - if x.shape[2] != y.shape[2]: - raise ValueError("Feature dimension of x and y must be the same.") + if preds.ndim != 3 or target.ndim != 3: + raise ValueError("Inputs preds and target must be 3-dimensional tensors of shape [B, N, D] and [B, M, D].") + if preds.shape[0] != target.shape[0]: + raise ValueError("Batch size of preds and target must be the same.") + if preds.shape[2] != target.shape[2]: + raise ValueError("Feature dimension of preds and target must be the same.") if not isinstance(gamma, float) or gamma <= 0: raise ValueError("Gamma must be a positive float.") -def _soft_dtw_compute(x: Tensor, y: Tensor, gamma: float, distance_fn: Optional[Callable] = None) -> Tensor: +def _soft_dtw_compute(preds: Tensor, target: Tensor, gamma: float, distance_fn: Optional[Callable] = None) -> Tensor: """Compute the Soft-DTW distance between two batched sequences.""" - B, N, D = x.shape - _, M, _ = y.shape - device, dtype = x.device, x.dtype + + B, N, D = preds.shape + _, M, _ = target.shape + device, dtype = target.device, target.dtype + if preds.dtype != target.dtype: + target = target.to(preds.dtype) if distance_fn is None: def distance_fn(a, b): return torch.cdist(a, b, p=2).pow(2) - D = distance_fn(x, y) # [B, N, M] + D = distance_fn(preds, target) # [B, N, M] R = torch.ones((B, N + 2, M + 2), device=device, dtype=dtype) * math.inf R[:, 0, 0] = 0.0 @@ -47,18 +50,35 @@ def softmin(a, b, c, gamma): vals = torch.stack([a, b, c], dim=-1) return -gamma * torch.logsumexp(-vals / gamma, dim=-1) - for i in range(1, N + 1): - for j in range(1, M + 1): - r1 = R[:, i-1, j-1] - r2 = R[:, i-1, j] - r3 = R[:, i, j-1] - R[:, i, j] = D[:, i-1, j-1] + softmin(r1, r2, r3, gamma) + # Loop based implementation + # for i in range(1, N + 1): + # for j in range(1, M + 1): + # r1 = R[:, i-1, j-1] + # r2 = R[:, i-1, j] + # r3 = R[:, i, j-1] + # R[:, i, j] = D[:, i-1, j-1] + softmin(r1, r2, r3, gamma) + + # Anti-diagonal implementation + for k in range(2, N+M+1): + i_vals = torch.arange(1, N+1, device=device) + j_vals = k - i_vals + mask = (j_vals >= 1) & (j_vals <= M) + i_vals = i_vals[mask] + j_vals = j_vals[mask] + + if len(i_vals) == 0: + continue + + r1 = R[:, i_vals-1, j_vals-1] + r2 = R[:, i_vals-1, j_vals] + r3 = R[:, i_vals, j_vals-1] + R[:, i_vals, j_vals] = D[:, i_vals-1, j_vals-1] + softmin(r1, r2, r3, gamma) return R[:, N, M] def soft_dtw( - x: Tensor, - y: Tensor, + preds: Tensor, + target: Tensor, gamma: float = 1.0, distance_fn=None, ) -> Tensor: @@ -112,5 +132,5 @@ def soft_dtw( >>> soft_dtw(x, y, gamma=0.5, distance_fn=cosine_dist) tensor([2.8301, 3.0128]) """ - _soft_dtw_validate_args(x, y, gamma) - return _soft_dtw_compute(x, y, gamma, distance_fn) \ No newline at end of file + 
_soft_dtw_validate_args(preds, target, gamma) + return _soft_dtw_compute(preds, target, gamma, distance_fn) \ No newline at end of file diff --git a/src/torchmetrics/timeseries/softdtw.py b/src/torchmetrics/timeseries/softdtw.py index ee3bad9f66b..6a46d587c08 100644 --- a/src/torchmetrics/timeseries/softdtw.py +++ b/src/torchmetrics/timeseries/softdtw.py @@ -70,7 +70,7 @@ class SoftDTW(Metric): """ full_state_update: bool = False - is_differentiable: bool = False + is_differentiable: bool = True higher_is_better: bool = False plot_lower_bound: float = 0.0 plot_upper_bound: float = 1.0 @@ -103,7 +103,7 @@ def compute(self) -> torch.Tensor: torch.cat(self.gt_list, dim=0), gamma=self.gamma, distance_fn=self.distance_fn, - ).mean() + ) def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: """Plot a single or multiple values from the metric. diff --git a/tests/unittests/timeseries/test_softdtw.py b/tests/unittests/timeseries/test_softdtw.py new file mode 100644 index 00000000000..3c002e63070 --- /dev/null +++ b/tests/unittests/timeseries/test_softdtw.py @@ -0,0 +1,91 @@ +# Copyright The Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from functools import partial +import pytest +import torch +import pysdtw +from torchmetrics.functional.timeseries.softdtw import soft_dtw +from torchmetrics.timeseries.softdtw import SoftDTW +from unittests import BATCH_SIZE, NUM_BATCHES, _Input +from unittests._helpers import seed_all +from unittests._helpers.testers import MetricTester + + +seed_all(42) + +num_batches = 1 +batch_size = 1 +_inputs = _Input( + preds=torch.randn(num_batches, batch_size, 15, 3, dtype=torch.float64), + target=torch.randn(num_batches, batch_size, 14, 3, dtype=torch.float64), +) + +def _reference_softdtw(preds: torch.Tensor, target: torch.Tensor, gamma: float = 1.0, distance_fn=None) -> torch.Tensor: + """Reference implementation using tslearn's soft-DTW.""" + preds = preds.to("cuda" if torch.cuda.is_available() else "cpu") + target = target.to("cuda" if torch.cuda.is_available() else "cpu") + sdtw = pysdtw.SoftDTW(gamma=gamma, dist_func=distance_fn, use_cuda=True if torch.cuda.is_available() else False) + return sdtw(preds, target) + +def euclidean_distance(x, y): + return torch.cdist(x, y, p=2) + +def manhattan_distance(x, y): + return torch.cdist(x, y, p=1) + +def cosine_distance(x, y): + x_norm = x / x.norm(dim=-1, keepdim=True) + y_norm = y / y.norm(dim=-1, keepdim=True) + return 1 - torch.matmul(x_norm, y_norm.transpose(-1, -2)) + +@pytest.mark.parametrize("preds, target", [(_inputs.preds, _inputs.target)]) +class TestSoftDTW(MetricTester): + """Test class for `SoftDTW` metric.""" + + @pytest.mark.parametrize("gamma", [0.1, 0.5, 1.0]) + @pytest.mark.parametrize("distance_fn", [euclidean_distance, manhattan_distance, cosine_distance]) + @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) + def test_softdtw_class(self, gamma, preds, target, distance_fn, ddp): + """Test class 
implementation of SoftDTW.""" + self.run_class_metric_test( + ddp, + preds, + target, + SoftDTW, + partial(_reference_softdtw, gamma=gamma, distance_fn=distance_fn), + metric_args={"gamma": gamma, "distance_fn": distance_fn}, + ) + + @pytest.mark.parametrize("gamma", [0.1, 0.5, 1.0]) + @pytest.mark.parametrize("distance_fn", [euclidean_distance, manhattan_distance, cosine_distance]) + def test_softdtw_functional(self, preds, target, gamma, distance_fn): + """Test functional implementation of SoftDTW.""" + self.run_functional_metric_test( + preds, + target, + metric_functional=soft_dtw, + reference_metric=partial(_reference_softdtw, gamma=gamma, distance_fn=distance_fn), + metric_args={"gamma": gamma, "distance_fn": distance_fn}, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + def test_softdtw_differentiability(self, preds, target): + """Test the differentiability of the metric, according to its `is_differentiable` attribute.""" + self.run_differentiability_test( + preds=preds, + target=target, + metric_module=SoftDTW, + metric_functional=soft_dtw, + metric_args={"gamma": 0.1}, + ) diff --git a/tests/unittests/utilities/test_plot.py b/tests/unittests/utilities/test_plot.py index 4d973a790c9..8deeca9322a 100644 --- a/tests/unittests/utilities/test_plot.py +++ b/tests/unittests/utilities/test_plot.py @@ -183,6 +183,7 @@ WordInfoLost, WordInfoPreserved, ) +from torchmetrics.timeseries import SoftDTW from torchmetrics.utilities.plot import _get_col_row_split from torchmetrics.wrappers import ( BootStrapper, @@ -692,6 +693,12 @@ lambda: torch.randn(10, 100, 3), id="lip vertex error", ), + pytest.param( + SoftDTW, + lambda: torch.randn(10, 4, 3), + lambda: torch.randn(10, 6, 3), + id="soft dtw", + ), ], ) @pytest.mark.parametrize("num_vals", [1, 3]) From 81be46d44474e39f5617b8bd07259cd481a1e3e7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 7 Oct 2025 21:31:19 +0000 Subject: [PATCH 3/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- README.md | 12 +++++---- docs/source/links.rst | 2 +- .../functional/timeseries/__init__.py | 1 - .../functional/timeseries/softdtw.py | 27 ++++++++++--------- src/torchmetrics/timeseries/softdtw.py | 9 +++---- tests/unittests/timeseries/test_softdtw.py | 16 +++++++---- 6 files changed, 37 insertions(+), 30 deletions(-) diff --git a/README.md b/README.md index fd00bcaa6bd..c77320656fd 100644 --- a/README.md +++ b/README.md @@ -39,13 +39,15 @@ ______________________________________________________________________ # Looking for GPUs? -Over 340,000 developers use [Lightning Cloud](https://lightning.ai/?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme) - purpose-built for PyTorch and PyTorch Lightning. -- [GPUs](https://lightning.ai/pricing?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme) from $0.19. -- [Clusters](https://lightning.ai/clusters?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme): frontier-grade training/inference clusters. + +Over 340,000 developers use [Lightning Cloud](https://lightning.ai/?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme) - purpose-built for PyTorch and PyTorch Lightning. + +- [GPUs](https://lightning.ai/pricing?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme) from $0.19. 
+- [Clusters](https://lightning.ai/clusters?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme): frontier-grade training/inference clusters. - [AI Studio (vibe train)](https://lightning.ai/studios?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme): workspaces where AI helps you debug, tune and vibe train. -- [AI Studio (vibe deploy)](https://lightning.ai/studios?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme): workspaces where AI helps you optimize, and deploy models. +- [AI Studio (vibe deploy)](https://lightning.ai/studios?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme): workspaces where AI helps you optimize, and deploy models. - [Notebooks](https://lightning.ai/notebooks?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme): Persistent GPU workspaces where AI helps you code and analyze. -- [Inference](https://lightning.ai/deploy?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme): Deploy models as inference APIs. +- [Inference](https://lightning.ai/deploy?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme): Deploy models as inference APIs. # Installation diff --git a/docs/source/links.rst b/docs/source/links.rst index de6701a57a9..3b89f00cd02 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -184,4 +184,4 @@ .. _Deep Image Structure and Texture Similarity: https://arxiv.org/abs/2004.07728 .. _KonIQ-10k: https://database.mmsp-kn.de/koniq-10k-database.html .. _KADID-10k: https://database.mmsp-kn.de/kadid-10k-database.html -.. _SoftDTW: https://arxiv.org/abs/1703.01541 \ No newline at end of file +.. _SoftDTW: https://arxiv.org/abs/1703.01541 diff --git a/src/torchmetrics/functional/timeseries/__init__.py b/src/torchmetrics/functional/timeseries/__init__.py index 5f140537f78..81ef5187653 100644 --- a/src/torchmetrics/functional/timeseries/__init__.py +++ b/src/torchmetrics/functional/timeseries/__init__.py @@ -14,4 +14,3 @@ from torchmetrics.functional.timeseries.softdtw import soft_dtw __all__ = ["soft_dtw"] - diff --git a/src/torchmetrics/functional/timeseries/softdtw.py b/src/torchmetrics/functional/timeseries/softdtw.py index 2e0928e2d9c..3d4d8c5ac46 100644 --- a/src/torchmetrics/functional/timeseries/softdtw.py +++ b/src/torchmetrics/functional/timeseries/softdtw.py @@ -11,12 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Optional, Callable import math +from typing import Callable, Optional import torch from torch import Tensor + def _soft_dtw_validate_args(preds: Tensor, target: Tensor, gamma: float) -> None: """Validate the input arguments for the soft_dtw function.""" if preds.ndim != 3 or target.ndim != 3: @@ -28,9 +29,9 @@ def _soft_dtw_validate_args(preds: Tensor, target: Tensor, gamma: float) -> None if not isinstance(gamma, float) or gamma <= 0: raise ValueError("Gamma must be a positive float.") + def _soft_dtw_compute(preds: Tensor, target: Tensor, gamma: float, distance_fn: Optional[Callable] = None) -> Tensor: """Compute the Soft-DTW distance between two batched sequences.""" - B, N, D = preds.shape _, M, _ = target.shape device, dtype = target.device, target.dtype @@ -38,6 +39,7 @@ def _soft_dtw_compute(preds: Tensor, target: Tensor, gamma: float, distance_fn: target = target.to(preds.dtype) if distance_fn is None: + def distance_fn(a, b): return torch.cdist(a, b, p=2).pow(2) @@ -59,8 +61,8 @@ def softmin(a, b, c, gamma): # R[:, i, j] = D[:, i-1, j-1] + softmin(r1, r2, r3, gamma) # Anti-diagonal implementation - for k in range(2, N+M+1): - i_vals = torch.arange(1, N+1, device=device) + for k in range(2, N + M + 1): + i_vals = torch.arange(1, N + 1, device=device) j_vals = k - i_vals mask = (j_vals >= 1) & (j_vals <= M) i_vals = i_vals[mask] @@ -69,21 +71,21 @@ def softmin(a, b, c, gamma): if len(i_vals) == 0: continue - r1 = R[:, i_vals-1, j_vals-1] - r2 = R[:, i_vals-1, j_vals] - r3 = R[:, i_vals, j_vals-1] - R[:, i_vals, j_vals] = D[:, i_vals-1, j_vals-1] + softmin(r1, r2, r3, gamma) + r1 = R[:, i_vals - 1, j_vals - 1] + r2 = R[:, i_vals - 1, j_vals] + r3 = R[:, i_vals, j_vals - 1] + R[:, i_vals, j_vals] = D[:, i_vals - 1, j_vals - 1] + softmin(r1, r2, r3, gamma) return R[:, N, M] + def soft_dtw( preds: Tensor, target: Tensor, gamma: float = 1.0, distance_fn=None, ) -> Tensor: - r""" - Compute the **Soft Dynamic Time Warping (Soft-DTW)** distance between two batched sequences. + r"""Compute the **Soft Dynamic Time Warping (Soft-DTW)** distance between two batched sequences. This is a differentiable relaxation of the classic Dynamic Time Warping (DTW) algorithm, introduced by Marco Cuturi and Mathieu Blondel (2017). @@ -120,7 +122,7 @@ def soft_dtw( >>> y = torch.tensor([[[0.0], [2.0], [3.0]]]) # [B, M, D] >>> soft_dtw(x, y, gamma=0.1) tensor([0.4003]) - + Example (custom distance function):: >>> def cosine_dist(a, b): ... a = torch.nn.functional.normalize(a, dim=-1) @@ -131,6 +133,7 @@ def soft_dtw( >>> y = torch.randn(2, 6, 3) >>> soft_dtw(x, y, gamma=0.5, distance_fn=cosine_dist) tensor([2.8301, 3.0128]) + """ _soft_dtw_validate_args(preds, target, gamma) - return _soft_dtw_compute(preds, target, gamma, distance_fn) \ No newline at end of file + return _soft_dtw_compute(preds, target, gamma, distance_fn) diff --git a/src/torchmetrics/timeseries/softdtw.py b/src/torchmetrics/timeseries/softdtw.py index 6a46d587c08..6416a0cb258 100644 --- a/src/torchmetrics/timeseries/softdtw.py +++ b/src/torchmetrics/timeseries/softdtw.py @@ -12,11 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from collections.abc import Sequence -from typing import Any, Optional, Union, Callable, List +from typing import Any, Callable, List, Optional, Union import torch from torch import Tensor -from typing_extensions import Literal from torchmetrics import Metric from torchmetrics.functional.timeseries.softdtw import soft_dtw @@ -67,6 +66,7 @@ class SoftDTW(Metric): >>> y = randn(10, 60, 2) >>> metric(x, y) tensor(43.2051) + """ full_state_update: bool = False @@ -78,10 +78,7 @@ class SoftDTW(Metric): pred_list: List[Tensor] gt_list: List[Tensor] - def __init__(self, - distance_fn: Optional[Callable] = None, - gamma: float = 1.0, - **kwargs: Any) -> None: + def __init__(self, distance_fn: Optional[Callable] = None, gamma: float = 1.0, **kwargs: Any) -> None: super().__init__(**kwargs) self.distance_fn = distance_fn if gamma <= 0: diff --git a/tests/unittests/timeseries/test_softdtw.py b/tests/unittests/timeseries/test_softdtw.py index 3c002e63070..e17789dd5bc 100644 --- a/tests/unittests/timeseries/test_softdtw.py +++ b/tests/unittests/timeseries/test_softdtw.py @@ -12,15 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import partial + +import pysdtw import pytest import torch -import pysdtw -from torchmetrics.functional.timeseries.softdtw import soft_dtw -from torchmetrics.timeseries.softdtw import SoftDTW -from unittests import BATCH_SIZE, NUM_BATCHES, _Input +from unittests import _Input from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester +from torchmetrics.functional.timeseries.softdtw import soft_dtw +from torchmetrics.timeseries.softdtw import SoftDTW seed_all(42) @@ -31,6 +32,7 @@ target=torch.randn(num_batches, batch_size, 14, 3, dtype=torch.float64), ) + def _reference_softdtw(preds: torch.Tensor, target: torch.Tensor, gamma: float = 1.0, distance_fn=None) -> torch.Tensor: """Reference implementation using tslearn's soft-DTW.""" preds = preds.to("cuda" if torch.cuda.is_available() else "cpu") @@ -38,21 +40,25 @@ def _reference_softdtw(preds: torch.Tensor, target: torch.Tensor, gamma: float = sdtw = pysdtw.SoftDTW(gamma=gamma, dist_func=distance_fn, use_cuda=True if torch.cuda.is_available() else False) return sdtw(preds, target) + def euclidean_distance(x, y): return torch.cdist(x, y, p=2) + def manhattan_distance(x, y): return torch.cdist(x, y, p=1) + def cosine_distance(x, y): x_norm = x / x.norm(dim=-1, keepdim=True) y_norm = y / y.norm(dim=-1, keepdim=True) return 1 - torch.matmul(x_norm, y_norm.transpose(-1, -2)) + @pytest.mark.parametrize("preds, target", [(_inputs.preds, _inputs.target)]) class TestSoftDTW(MetricTester): """Test class for `SoftDTW` metric.""" - + @pytest.mark.parametrize("gamma", [0.1, 0.5, 1.0]) @pytest.mark.parametrize("distance_fn", [euclidean_distance, manhattan_distance, cosine_distance]) @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) From c34c862f11b360814d37fd0b2e27a99807b240df Mon Sep 17 00:00:00 2001 From: VijayVignesh1 Date: Wed, 8 Oct 2025 13:19:41 -0400 Subject: [PATCH 4/7] Resolving pre-commit errors --- requirements/timeseries_test.txt | 1 + .../functional/timeseries/softdtw.py | 49 +++++++++++-------- src/torchmetrics/timeseries/softdtw.py | 3 +- tests/unittests/timeseries/test_softdtw.py | 7 ++- 4 files changed, 37 insertions(+), 23 deletions(-) create mode 100644 requirements/timeseries_test.txt diff --git a/requirements/timeseries_test.txt b/requirements/timeseries_test.txt new file mode 
100644 index 00000000000..62d6031ab0f --- /dev/null +++ b/requirements/timeseries_test.txt @@ -0,0 +1 @@ +pysdtw==0.0.5 diff --git a/src/torchmetrics/functional/timeseries/softdtw.py b/src/torchmetrics/functional/timeseries/softdtw.py index 3d4d8c5ac46..ce5ecefc8ce 100644 --- a/src/torchmetrics/functional/timeseries/softdtw.py +++ b/src/torchmetrics/functional/timeseries/softdtw.py @@ -13,11 +13,13 @@ # limitations under the License. import math from typing import Callable, Optional +from typing import Callable, Optional import torch from torch import Tensor + def _soft_dtw_validate_args(preds: Tensor, target: Tensor, gamma: float) -> None: """Validate the input arguments for the soft_dtw function.""" if preds.ndim != 3 or target.ndim != 3: @@ -30,25 +32,28 @@ def _soft_dtw_validate_args(preds: Tensor, target: Tensor, gamma: float) -> None raise ValueError("Gamma must be a positive float.") + def _soft_dtw_compute(preds: Tensor, target: Tensor, gamma: float, distance_fn: Optional[Callable] = None) -> Tensor: """Compute the Soft-DTW distance between two batched sequences.""" - B, N, D = preds.shape - _, M, _ = target.shape + b, n, d = preds.shape + _, m, _ = target.shape device, dtype = target.device, target.dtype if preds.dtype != target.dtype: target = target.to(preds.dtype) if distance_fn is None: - def distance_fn(a, b): - return torch.cdist(a, b, p=2).pow(2) + def distance_fn(x: Tensor, y: Tensor) -> Tensor: + """Default to squared Euclidean distance.""" + return torch.cdist(x, y, p=2).pow(2) - D = distance_fn(preds, target) # [B, N, M] + distances = distance_fn(preds, target) # [B, N, M] - R = torch.ones((B, N + 2, M + 2), device=device, dtype=dtype) * math.inf - R[:, 0, 0] = 0.0 + r = torch.ones((b, n + 2, m + 2), device=device, dtype=dtype) * math.inf + r[:, 0, 0] = 0.0 - def softmin(a, b, c, gamma): + def softmin(a: Tensor, b: Tensor, c: Tensor, gamma: float) -> Tensor: + """Compute the soft minimum of three tensors.""" vals = torch.stack([a, b, c], dim=-1) return -gamma * torch.logsumexp(-vals / gamma, dim=-1) @@ -61,29 +66,29 @@ def softmin(a, b, c, gamma): # R[:, i, j] = D[:, i-1, j-1] + softmin(r1, r2, r3, gamma) # Anti-diagonal implementation - for k in range(2, N + M + 1): - i_vals = torch.arange(1, N + 1, device=device) + for k in range(2, n + m + 1): + i_vals = torch.arange(1, n + 1, device=device) j_vals = k - i_vals - mask = (j_vals >= 1) & (j_vals <= M) + mask = (j_vals >= 1) & (j_vals <= m) i_vals = i_vals[mask] j_vals = j_vals[mask] if len(i_vals) == 0: continue - r1 = R[:, i_vals - 1, j_vals - 1] - r2 = R[:, i_vals - 1, j_vals] - r3 = R[:, i_vals, j_vals - 1] - R[:, i_vals, j_vals] = D[:, i_vals - 1, j_vals - 1] + softmin(r1, r2, r3, gamma) + r1 = r[:, i_vals - 1, j_vals - 1] + r2 = r[:, i_vals - 1, j_vals] + r3 = r[:, i_vals, j_vals - 1] + r[:, i_vals, j_vals] = distances[:, i_vals - 1, j_vals - 1] + softmin(r1, r2, r3, gamma) - return R[:, N, M] + return r[:, n, m] def soft_dtw( preds: Tensor, target: Tensor, gamma: float = 1.0, - distance_fn=None, + distance_fn: Optional[Callable] = None, ) -> Tensor: r"""Compute the **Soft Dynamic Time Warping (Soft-DTW)** distance between two batched sequences. @@ -104,9 +109,10 @@ def soft_dtw( The final Soft-DTW distance is :math:`R_{N,M}`. Args: - x: Tensor of shape ``[B, N, D]`` — batch of input sequences. - y: Tensor of shape ``[B, M, D]`` — batch of target sequences. - gamma: Smoothing parameter (:math:`\gamma > 0`). 
Smaller values make the loss closer to standard DTW (hard minimum),
+        preds: Tensor of shape ``[B, N, D]`` — batch of input sequences.
+        target: Tensor of shape ``[B, M, D]`` — batch of target sequences.
+        gamma: Smoothing parameter (:math:`\gamma > 0`).
+            Smaller values make the loss closer to standard DTW (hard minimum),
             while larger values produce a smoother and more differentiable surface.
         distance_fn: Optional callable ``(x, y) -> [B, N, M]`` defining the pairwise distance matrix.
             If ``None``, defaults to **squared Euclidean distance**.
@@ -123,6 +129,7 @@ def soft_dtw(
         >>> soft_dtw(x, y, gamma=0.1)
         tensor([0.4003])
 
+
     Example (custom distance function)::
         >>> def cosine_dist(a, b):
         ...     a = torch.nn.functional.normalize(a, dim=-1)
@@ -134,6 +141,8 @@ def soft_dtw(
         >>> soft_dtw(x, y, gamma=0.5, distance_fn=cosine_dist)
         tensor([2.8301, 3.0128])
 
+
     """
     _soft_dtw_validate_args(preds, target, gamma)
     return _soft_dtw_compute(preds, target, gamma, distance_fn)
+
diff --git a/src/torchmetrics/timeseries/softdtw.py b/src/torchmetrics/timeseries/softdtw.py
index 6416a0cb258..2055ce5dcb9 100644
--- a/src/torchmetrics/timeseries/softdtw.py
+++ b/src/torchmetrics/timeseries/softdtw.py
@@ -48,7 +48,8 @@ class SoftDTW(Metric):
     The final Soft-DTW distance is :math:`R_{N,M}`.
 
     Args:
-        gamma: Smoothing parameter (:math:`\gamma > 0`). Smaller values make the loss closer to standard DTW (hard minimum),
+        gamma: Smoothing parameter (:math:`\gamma > 0`).
+            Smaller values make the loss closer to standard DTW (hard minimum),
             while larger values produce a smoother and more differentiable surface.
         distance_fn: Optional callable ``(x, y) -> [B, N, M]`` defining the pairwise distance matrix.
             If ``None``, defaults to **squared Euclidean distance**.
diff --git a/tests/unittests/timeseries/test_softdtw.py b/tests/unittests/timeseries/test_softdtw.py
index e17789dd5bc..83075ff2d7c 100644
--- a/tests/unittests/timeseries/test_softdtw.py
+++ b/tests/unittests/timeseries/test_softdtw.py
@@ -37,25 +37,28 @@ def _reference_softdtw(preds: torch.Tensor, target: torch.Tensor, gamma: float =
     """Reference implementation using tslearn's soft-DTW."""
     preds = preds.to("cuda" if torch.cuda.is_available() else "cpu")
     target = target.to("cuda" if torch.cuda.is_available() else "cpu")
-    sdtw = pysdtw.SoftDTW(gamma=gamma, dist_func=distance_fn, use_cuda=True if torch.cuda.is_available() else False)
+    sdtw = pysdtw.SoftDTW(gamma=gamma, dist_func=distance_fn, use_cuda=bool(torch.cuda.is_available()))
     return sdtw(preds, target)
 
 
 def euclidean_distance(x, y):
+    """Euclidean (L2) distance."""
     return torch.cdist(x, y, p=2)
 
 
 def manhattan_distance(x, y):
+    """L1 (Manhattan) distance."""
     return torch.cdist(x, y, p=1)
 
 
 def cosine_distance(x, y):
+    """Cosine distance."""
     x_norm = x / x.norm(dim=-1, keepdim=True)
     y_norm = y / y.norm(dim=-1, keepdim=True)
     return 1 - torch.matmul(x_norm, y_norm.transpose(-1, -2))
 
 
-@pytest.mark.parametrize("preds, target", [(_inputs.preds, _inputs.target)])
+@pytest.mark.parametrize(("preds", "target"), [(_inputs.preds, _inputs.target)])
 class TestSoftDTW(MetricTester):
     """Test class for `SoftDTW` metric."""

From 89aee8b77f0b50d6be0928b33a160a2f29f17744 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Wed, 8 Oct 2025 17:22:01 +0000
Subject: [PATCH 5/7] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/torchmetrics/functional/timeseries/softdtw.py | 5 -----
 1 file changed, 5 deletions(-)

diff --git 
a/src/torchmetrics/functional/timeseries/softdtw.py b/src/torchmetrics/functional/timeseries/softdtw.py index ce5ecefc8ce..f52f45d4d9e 100644 --- a/src/torchmetrics/functional/timeseries/softdtw.py +++ b/src/torchmetrics/functional/timeseries/softdtw.py @@ -13,13 +13,11 @@ # limitations under the License. import math from typing import Callable, Optional -from typing import Callable, Optional import torch from torch import Tensor - def _soft_dtw_validate_args(preds: Tensor, target: Tensor, gamma: float) -> None: """Validate the input arguments for the soft_dtw function.""" if preds.ndim != 3 or target.ndim != 3: @@ -32,7 +30,6 @@ def _soft_dtw_validate_args(preds: Tensor, target: Tensor, gamma: float) -> None raise ValueError("Gamma must be a positive float.") - def _soft_dtw_compute(preds: Tensor, target: Tensor, gamma: float, distance_fn: Optional[Callable] = None) -> Tensor: """Compute the Soft-DTW distance between two batched sequences.""" b, n, d = preds.shape @@ -141,8 +138,6 @@ def soft_dtw( >>> soft_dtw(x, y, gamma=0.5, distance_fn=cosine_dist) tensor([2.8301, 3.0128]) - """ _soft_dtw_validate_args(preds, target, gamma) return _soft_dtw_compute(preds, target, gamma, distance_fn) - From 80a63fabdae329fc97435a2b350758e14803e23e Mon Sep 17 00:00:00 2001 From: VijayVignesh1 Date: Thu, 9 Oct 2025 11:34:19 -0400 Subject: [PATCH 6/7] Adding more tests and cleaning up the code --- docs/source/timeseries/softdtw.rst | 4 +- .../functional/timeseries/softdtw.py | 10 +--- src/torchmetrics/timeseries/softdtw.py | 4 ++ tests/unittests/timeseries/test_softdtw.py | 47 ++++++++++++++++--- 4 files changed, 47 insertions(+), 18 deletions(-) diff --git a/docs/source/timeseries/softdtw.rst b/docs/source/timeseries/softdtw.rst index d8be653ba90..d2d99958645 100644 --- a/docs/source/timeseries/softdtw.rst +++ b/docs/source/timeseries/softdtw.rst @@ -5,9 +5,9 @@ .. include:: ../links.rst -#################### +######################### Soft Dynamic Time Warping -#################### +######################### Module Interface ________________ diff --git a/src/torchmetrics/functional/timeseries/softdtw.py b/src/torchmetrics/functional/timeseries/softdtw.py index f52f45d4d9e..b6a61892998 100644 --- a/src/torchmetrics/functional/timeseries/softdtw.py +++ b/src/torchmetrics/functional/timeseries/softdtw.py @@ -54,15 +54,7 @@ def softmin(a: Tensor, b: Tensor, c: Tensor, gamma: float) -> Tensor: vals = torch.stack([a, b, c], dim=-1) return -gamma * torch.logsumexp(-vals / gamma, dim=-1) - # Loop based implementation - # for i in range(1, N + 1): - # for j in range(1, M + 1): - # r1 = R[:, i-1, j-1] - # r2 = R[:, i-1, j] - # r3 = R[:, i, j-1] - # R[:, i, j] = D[:, i-1, j-1] + softmin(r1, r2, r3, gamma) - - # Anti-diagonal implementation + # Anti-diagonal approach inspired from https://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=8400444 for k in range(2, n + m + 1): i_vals = torch.arange(1, n + 1, device=device) j_vals = k - i_vals diff --git a/src/torchmetrics/timeseries/softdtw.py b/src/torchmetrics/timeseries/softdtw.py index 2055ce5dcb9..96603e42fc4 100644 --- a/src/torchmetrics/timeseries/softdtw.py +++ b/src/torchmetrics/timeseries/softdtw.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import warnings from collections.abc import Sequence from typing import Any, Callable, List, Optional, Union @@ -86,6 +87,9 @@ def __init__(self, distance_fn: Optional[Callable] = None, gamma: float = 1.0, * raise ValueError(f"Argument `gamma` must be a positive float, got {gamma}") self.gamma = gamma + if self.device.type == "cpu": # warn on cpu + warnings.warn("SoftDTW is slow on CPU. Consider using a GPU.", stacklevel=2) + self.add_state("pred_list", default=[], dist_reduce_fx="cat") self.add_state("gt_list", default=[], dist_reduce_fx="cat") diff --git a/tests/unittests/timeseries/test_softdtw.py b/tests/unittests/timeseries/test_softdtw.py index 83075ff2d7c..8b9a5bdc9fc 100644 --- a/tests/unittests/timeseries/test_softdtw.py +++ b/tests/unittests/timeseries/test_softdtw.py @@ -16,7 +16,7 @@ import pysdtw import pytest import torch -from unittests import _Input +from unittests import BATCH_SIZE, NUM_BATCHES, _Input from unittests._helpers import seed_all from unittests._helpers.testers import MetricTester @@ -25,11 +25,9 @@ seed_all(42) -num_batches = 1 -batch_size = 1 _inputs = _Input( - preds=torch.randn(num_batches, batch_size, 15, 3, dtype=torch.float64), - target=torch.randn(num_batches, batch_size, 14, 3, dtype=torch.float64), + preds=torch.randn(NUM_BATCHES, BATCH_SIZE, 20, 3, dtype=torch.float64), + target=torch.randn(NUM_BATCHES, BATCH_SIZE, 30, 3, dtype=torch.float64), ) @@ -88,7 +86,7 @@ def test_softdtw_functional(self, preds, target, gamma, distance_fn): metric_args={"gamma": gamma, "distance_fn": distance_fn}, ) - @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test is too slow without gpu") def test_softdtw_differentiability(self, preds, target): """Test the differentiability of the metric, according to its `is_differentiable` attribute.""" self.run_differentiability_test( @@ -96,5 +94,40 @@ def test_softdtw_differentiability(self, preds, target): target=target, metric_module=SoftDTW, metric_functional=soft_dtw, - metric_args={"gamma": 0.1}, + metric_args={"gamma": 1.0}, ) + + +def test_wrong_dimensions(): + """Test that an error is raised if input tensors have wrong dimensions.""" + metric = SoftDTW() + with pytest.raises(ValueError, match="Inputs preds and target must be 3-dimensional tensors of shape*"): + metric(torch.randn(10, 100), torch.randn(10, 100, 3)) + + +def test_mismatched_dimensions(): + """Test that an error is raised if input dimensions don't match.""" + metric = SoftDTW() + with pytest.raises(ValueError, match="Batch size of preds and target must be the same.*"): + metric(torch.randn(10, 80, 3), torch.randn(12, 100, 3)) + + +def test_mismatched_feature_dimensions(): + """Test that an error is raised if input feature dimensions don't match.""" + metric = SoftDTW() + with pytest.raises(ValueError, match="Feature dimension of preds and target must be the same.*"): + metric(torch.randn(10, 80, 3), torch.randn(10, 100, 4)) + + +def test_invalid_gamma(): + """Test that an error is raised if gamma is not a positive float.""" + with pytest.raises(ValueError, match="Argument `gamma` must be a positive float, got -1.0*"): + SoftDTW(gamma=-1.0) + + +def test_warning_on_cpu(): + """Test that a warning is raised if SoftDTW is used on CPU.""" + if torch.cuda.is_available(): + pytest.skip("Test only runs on CPU.") + with pytest.warns(UserWarning, match="SoftDTW is slow on CPU. 
Consider using a GPU.*"): + SoftDTW() From f22fbca6b613ae1af4156de75e31ea454c56ead0 Mon Sep 17 00:00:00 2001 From: VijayVignesh1 Date: Thu, 9 Oct 2025 15:05:01 -0400 Subject: [PATCH 7/7] Adding reduction parameter over batch dimension --- .../functional/timeseries/softdtw.py | 46 +++++++++++++++---- src/torchmetrics/timeseries/softdtw.py | 33 +++++++++---- tests/unittests/timeseries/test_softdtw.py | 28 ++++++++--- 3 files changed, 82 insertions(+), 25 deletions(-) diff --git a/src/torchmetrics/functional/timeseries/softdtw.py b/src/torchmetrics/functional/timeseries/softdtw.py index b6a61892998..74964e154d1 100644 --- a/src/torchmetrics/functional/timeseries/softdtw.py +++ b/src/torchmetrics/functional/timeseries/softdtw.py @@ -12,14 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. import math -from typing import Callable, Optional +from typing import Callable, Literal, Optional import torch from torch import Tensor -def _soft_dtw_validate_args(preds: Tensor, target: Tensor, gamma: float) -> None: +def _soft_dtw_validate_args( + preds: Tensor, target: Tensor, gamma: float, reduction: Literal["mean", "sum", "none"] +) -> None: """Validate the input arguments for the soft_dtw function.""" + valid_reduction = ("mean", "sum", "none") + if reduction not in valid_reduction: + raise ValueError(f"Argument `reduction` must be one of {valid_reduction}, but got {reduction}") if preds.ndim != 3 or target.ndim != 3: raise ValueError("Inputs preds and target must be 3-dimensional tensors of shape [B, N, D] and [B, M, D].") if preds.shape[0] != target.shape[0]: @@ -30,7 +35,7 @@ def _soft_dtw_validate_args(preds: Tensor, target: Tensor, gamma: float) -> None raise ValueError("Gamma must be a positive float.") -def _soft_dtw_compute(preds: Tensor, target: Tensor, gamma: float, distance_fn: Optional[Callable] = None) -> Tensor: +def _soft_dtw_update(preds: Tensor, target: Tensor, gamma: float, distance_fn: Optional[Callable] = None) -> Tensor: """Compute the Soft-DTW distance between two batched sequences.""" b, n, d = preds.shape _, m, _ = target.shape @@ -73,17 +78,27 @@ def softmin(a: Tensor, b: Tensor, c: Tensor, gamma: float) -> Tensor: return r[:, n, m] +def _soft_dtw_compute(scores: Tensor, reduction: Literal["sum", "mean", "none"] = "mean") -> Tensor: + """Aggregate the computed Soft-DTW distances based on the specified reduction method.""" + if reduction == "none": + return scores + if reduction == "mean": + return scores.mean() + return scores.sum() + + def soft_dtw( preds: Tensor, target: Tensor, gamma: float = 1.0, distance_fn: Optional[Callable] = None, + reduction: Literal["sum", "mean", "none"] = "mean", ) -> Tensor: - r"""Compute the **Soft Dynamic Time Warping (Soft-DTW)** distance between two batched sequences. + r"""Compute the Soft Dynamic Time Warping (Soft-DTW) distance between two batched sequences. This is a differentiable relaxation of the classic Dynamic Time Warping (DTW) algorithm, introduced by Marco Cuturi and Mathieu Blondel (2017). - It replaces the hard minimum in DTW recursion with a *soft-minimum* using a log-sum-exp formulation: + It replaces the hard minimum in DTW recursion with a soft-minimum using a log-sum-exp formulation: .. math:: \text{softmin}_\gamma(a,b,c) = -\gamma \log \left( e^{-a/\gamma} + e^{-b/\gamma} + e^{-c/\gamma} \right) @@ -93,7 +108,8 @@ def soft_dtw( .. 
math:: R_{i,j} = D_{i,j} + \text{softmin}_\gamma(R_{i-1,j}, R_{i,j-1}, R_{i-1,j-1}) - where :math:`D_{i,j}` is the pairwise distance between sequence elements :math:`x_i` and :math:`y_j`. + where :math:`D_{i,j}` is the pairwise distance between sequence elements :math:`x_i` and :math:`y_j`. It could be + computed using any differentiable distance function, such as squared Euclidean distance or cosine distance. The final Soft-DTW distance is :math:`R_{N,M}`. @@ -104,11 +120,22 @@ def soft_dtw( Smaller values make the loss closer to standard DTW (hard minimum), while larger values produce a smoother and more differentiable surface. distance_fn: Optional callable ``(x, y) -> [B, N, M]`` defining the pairwise distance matrix. - If ``None``, defaults to **squared Euclidean distance**. + If ``None``, defaults to squared Euclidean distance. + reduction: indicates how to reduce over the batch dimension. Choose between [``sum``, ``mean``, ``none``]. + Defaults to ``mean``. Returns: A tensor of shape ``[B]`` containing the Soft-DTW distance for each sequence pair in the batch. + Raises: + ValueError: + If ``reduction`` is not one of [``sum``, ``mean``, ``none``]. + ValueError: + If ``gamma`` is not a positive float. + ValueError: + If input tensors to ``preds`` and ``target`` are not 3-dimensional + with the same batch size and feature dimension. + Example:: >>> import torch >>> from torchmetrics.functional.timeseries import soft_dtw @@ -131,5 +158,6 @@ def soft_dtw( tensor([2.8301, 3.0128]) """ - _soft_dtw_validate_args(preds, target, gamma) - return _soft_dtw_compute(preds, target, gamma, distance_fn) + _soft_dtw_validate_args(preds, target, gamma, reduction) + scores = _soft_dtw_update(preds, target, gamma, distance_fn) + return _soft_dtw_compute(scores, reduction) diff --git a/src/torchmetrics/timeseries/softdtw.py b/src/torchmetrics/timeseries/softdtw.py index 96603e42fc4..7878528663d 100644 --- a/src/torchmetrics/timeseries/softdtw.py +++ b/src/torchmetrics/timeseries/softdtw.py @@ -13,7 +13,7 @@ # limitations under the License. import warnings from collections.abc import Sequence -from typing import Any, Callable, List, Optional, Union +from typing import Any, Callable, List, Literal, Optional, Union import torch from torch import Tensor @@ -28,13 +28,11 @@ class SoftDTW(Metric): - r"""Compute the **Soft Dynamic Time Warping (Soft-DTW)** distance between two batched sequences. - - Compute the **Soft Dynamic Time Warping (Soft-DTW)** distance between two batched sequences. + r"""Compute the Soft Dynamic Time Warping (Soft-DTW) distance between two batched sequences. This is a differentiable relaxation of the classic Dynamic Time Warping (DTW) algorithm, introduced by Marco Cuturi and Mathieu Blondel (2017). - It replaces the hard minimum in DTW recursion with a *soft-minimum* using a log-sum-exp formulation: + It replaces the hard minimum in DTW recursion with a soft-minimum using a log-sum-exp formulation: .. math:: \text{softmin}_\gamma(a,b,c) = -\gamma \log \left( e^{-a/\gamma} + e^{-b/\gamma} + e^{-c/\gamma} \right) @@ -44,7 +42,8 @@ class SoftDTW(Metric): .. math:: R_{i,j} = D_{i,j} + \text{softmin}_\gamma(R_{i-1,j}, R_{i,j-1}, R_{i-1,j-1}) - where :math:`D_{i,j}` is the pairwise distance between sequence elements :math:`x_i` and :math:`y_j`. + where :math:`D_{i,j}` is the pairwise distance between sequence elements :math:`x_i` and :math:`y_j`. It could be + computed using any differentiable distance function, such as squared Euclidean distance or cosine distance. 
The final Soft-DTW distance is :math:`R_{N,M}`. @@ -53,12 +52,17 @@ class SoftDTW(Metric): Smaller values make the loss closer to standard DTW (hard minimum), while larger values produce a smoother and more differentiable surface. distance_fn: Optional callable ``(x, y) -> [B, N, M]`` defining the pairwise distance matrix. - If ``None``, defaults to **squared Euclidean distance**. + If ``None``, defaults to squared Euclidean distance. + reduction: indicates how to reduce over the batch dimension. Choose between [``sum``, ``mean``, ``none``]. Raises: + ValueError: + If ``reduction`` is not one of [``sum``, ``mean``, ``none``]. ValueError: If ``gamma`` is not a positive float. - If input tensors to ``update`` are not 3-dimensional with the same batch size and feature dimension. + ValueError: + If input tensors to ``update`` are not 3-dimensional + with the same batch size and feature dimension. Example: >>> from torch import randn @@ -80,12 +84,22 @@ class SoftDTW(Metric): pred_list: List[Tensor] gt_list: List[Tensor] - def __init__(self, distance_fn: Optional[Callable] = None, gamma: float = 1.0, **kwargs: Any) -> None: + def __init__( + self, + distance_fn: Optional[Callable] = None, + gamma: float = 1.0, + reduction: Literal["sum", "mean", "none"] = "mean", + **kwargs: Any, + ) -> None: super().__init__(**kwargs) self.distance_fn = distance_fn + valid_reduction = ("mean", "sum", "none") + if reduction not in valid_reduction: + raise ValueError(f"Argument `reduction` must be one of {valid_reduction}, but got {reduction}") if gamma <= 0: raise ValueError(f"Argument `gamma` must be a positive float, got {gamma}") self.gamma = gamma + self.reduction = reduction if self.device.type == "cpu": # warn on cpu warnings.warn("SoftDTW is slow on CPU. Consider using a GPU.", stacklevel=2) @@ -105,6 +119,7 @@ def compute(self) -> torch.Tensor: torch.cat(self.gt_list, dim=0), gamma=self.gamma, distance_fn=self.distance_fn, + reduction=self.reduction, ) def plot(self, val: Union[Tensor, Sequence[Tensor], None] = None, ax: Optional[_AX_TYPE] = None) -> _PLOT_OUT_TYPE: diff --git a/tests/unittests/timeseries/test_softdtw.py b/tests/unittests/timeseries/test_softdtw.py index 8b9a5bdc9fc..9e26ac0a7ee 100644 --- a/tests/unittests/timeseries/test_softdtw.py +++ b/tests/unittests/timeseries/test_softdtw.py @@ -31,11 +31,17 @@ ) -def _reference_softdtw(preds: torch.Tensor, target: torch.Tensor, gamma: float = 1.0, distance_fn=None) -> torch.Tensor: +def _reference_softdtw( + preds: torch.Tensor, target: torch.Tensor, gamma: float = 1.0, distance_fn=None, reduction: str = "mean" +) -> torch.Tensor: """Reference implementation using tslearn's soft-DTW.""" preds = preds.to("cuda" if torch.cuda.is_available() else "cpu") target = target.to("cuda" if torch.cuda.is_available() else "cpu") sdtw = pysdtw.SoftDTW(gamma=gamma, dist_func=distance_fn, use_cuda=bool(torch.cuda.is_available())) + if reduction == "mean": + return sdtw(preds, target).mean() + if reduction == "sum": + return sdtw(preds, target).sum() return sdtw(preds, target) @@ -62,28 +68,30 @@ class TestSoftDTW(MetricTester): @pytest.mark.parametrize("gamma", [0.1, 0.5, 1.0]) @pytest.mark.parametrize("distance_fn", [euclidean_distance, manhattan_distance, cosine_distance]) + @pytest.mark.parametrize("reduction", ["mean", "sum", "none"]) @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) - def test_softdtw_class(self, gamma, preds, target, distance_fn, ddp): + def test_softdtw_class(self, gamma, preds, target, distance_fn, 
reduction, ddp): """Test class implementation of SoftDTW.""" self.run_class_metric_test( ddp, preds, target, SoftDTW, - partial(_reference_softdtw, gamma=gamma, distance_fn=distance_fn), - metric_args={"gamma": gamma, "distance_fn": distance_fn}, + partial(_reference_softdtw, gamma=gamma, distance_fn=distance_fn, reduction=reduction), + metric_args={"gamma": gamma, "distance_fn": distance_fn, "reduction": reduction}, ) @pytest.mark.parametrize("gamma", [0.1, 0.5, 1.0]) @pytest.mark.parametrize("distance_fn", [euclidean_distance, manhattan_distance, cosine_distance]) - def test_softdtw_functional(self, preds, target, gamma, distance_fn): + @pytest.mark.parametrize("reduction", ["mean", "sum", "none"]) + def test_softdtw_functional(self, preds, target, gamma, distance_fn, reduction): """Test functional implementation of SoftDTW.""" self.run_functional_metric_test( preds, target, metric_functional=soft_dtw, - reference_metric=partial(_reference_softdtw, gamma=gamma, distance_fn=distance_fn), - metric_args={"gamma": gamma, "distance_fn": distance_fn}, + reference_metric=partial(_reference_softdtw, gamma=gamma, distance_fn=distance_fn, reduction=reduction), + metric_args={"gamma": gamma, "distance_fn": distance_fn, "reduction": reduction}, ) @pytest.mark.skipif(not torch.cuda.is_available(), reason="test is too slow without gpu") @@ -131,3 +139,9 @@ def test_warning_on_cpu(): pytest.skip("Test only runs on CPU.") with pytest.warns(UserWarning, match="SoftDTW is slow on CPU. Consider using a GPU.*"): SoftDTW() + + +def test_invalid_reduction(): + """Test that an error is raised if reduction is not one of [``sum``, ``mean``, ``none``].""" + with pytest.raises(ValueError, match="Argument `reduction` must be one of .*"): + SoftDTW(reduction="invalid")
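
A minimal usage sketch of the API this series introduces, assuming the final state after PATCH 7/7 is applied. The tensor shapes follow the docstrings above; the batch sizes and gamma value here are arbitrary examples, and no expected outputs are shown.

    import torch

    from torchmetrics.functional.timeseries import soft_dtw
    from torchmetrics.timeseries import SoftDTW

    preds = torch.randn(4, 50, 2)   # [B, N, D] input sequences
    target = torch.randn(4, 60, 2)  # [B, M, D] target sequences; M may differ from N

    # Functional API: reduction="none" keeps one distance per sequence pair,
    # while "mean" (the default) and "sum" reduce over the batch dimension.
    per_pair = soft_dtw(preds, target, gamma=0.5, reduction="none")  # shape [B]
    batch_mean = soft_dtw(preds, target, gamma=0.5)                  # scalar

    # Module API: states accumulate across update() calls; compute() concatenates
    # the stored batches and applies the configured reduction.
    metric = SoftDTW(gamma=0.5, reduction="mean")
    metric.update(preds, target)
    metric.update(torch.randn(4, 50, 2), torch.randn(4, 60, 2))
    overall = metric.compute()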