diff --git a/docs/source/links.rst b/docs/source/links.rst index 539d2728e74..863ab57bac4 100644 --- a/docs/source/links.rst +++ b/docs/source/links.rst @@ -184,3 +184,6 @@ .. _Deep Image Structure and Texture Similarity: https://arxiv.org/abs/2004.07728 .. _KonIQ-10k: https://database.mmsp-kn.de/koniq-10k-database.html .. _KADID-10k: https://database.mmsp-kn.de/kadid-10k-database.html +.. _Algorithms_for_calculating_variance: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance +.. _online_weighted_variance: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm +.. _online_weighted_covariance: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Weighted_batched_version diff --git a/src/torchmetrics/functional/regression/__init__.py b/src/torchmetrics/functional/regression/__init__.py index 9727d4fdd8f..7bc045f14c0 100644 --- a/src/torchmetrics/functional/regression/__init__.py +++ b/src/torchmetrics/functional/regression/__init__.py @@ -32,6 +32,7 @@ from torchmetrics.functional.regression.spearman import spearman_corrcoef from torchmetrics.functional.regression.symmetric_mape import symmetric_mean_absolute_percentage_error from torchmetrics.functional.regression.tweedie_deviance import tweedie_deviance_score +from torchmetrics.functional.regression.weighted_pearson import weighted_pearson_corrcoef from torchmetrics.functional.regression.wmape import weighted_mean_absolute_percentage_error __all__ = [ @@ -58,4 +59,5 @@ "symmetric_mean_absolute_percentage_error", "tweedie_deviance_score", "weighted_mean_absolute_percentage_error", + "weighted_pearson_corrcoef", ] diff --git a/src/torchmetrics/functional/regression/utils.py b/src/torchmetrics/functional/regression/utils.py index 59612927f26..6e22e9ab17c 100644 --- a/src/torchmetrics/functional/regression/utils.py +++ b/src/torchmetrics/functional/regression/utils.py @@ -41,3 +41,26 @@ def _check_data_shape_to_num_outputs( f"Expected argument 
def _check_data_shape_to_weights(preds: Tensor, weights: Tensor) -> None:
    """Validate that ``weights`` provides exactly one weight per sample of ``preds``.

    The caller is expected to have already verified that predictions and targets share a shape
    and that ``preds`` is 1- or 2-dimensional; inputs of any other rank pass through unchecked.

    Args:
        preds: Prediction tensor of shape ``(N,)`` or ``(N, d)``
        weights: Weight tensor of shape ``(N,)``

    Raises:
        ValueError:
            If ``weights`` is not 1-dimensional, or its length differs from the first dimension of ``preds``.

    """
    if weights.ndim != 1:
        raise ValueError(f"Expected `weights` to be 1-d Tensor, but got {weights.ndim}-dim Tensor.")

    if preds.ndim == 1:
        # single-output case: the full shapes have to coincide
        if preds.shape != weights.shape:
            raise ValueError(
                f"Expected `preds.shape` to equal to `weights.shape`, but got {preds.shape} and {weights.shape}."
            )
    elif preds.ndim == 2:
        # multi-output case: only the sample dimension is compared
        if preds.shape[0] != len(weights):
            raise ValueError(
                f"Expected `preds.shape[0]` to equal to `len(weights)` but got {preds.shape[0]} and {len(weights)}."
            )
def _weighted_pearson_corrcoef_update(
    preds: Tensor,
    target: Tensor,
    weights: Tensor,
    mean_x: Tensor,
    mean_y: Tensor,
    var_x: Tensor,
    var_y: Tensor,
    cov_xy: Tensor,
    weights_prior: Tensor,
    num_outputs: int,
) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
    """Update and return the statistics required to compute the weighted Pearson correlation coefficient.

    Input shapes are validated first. The running statistics are then updated following
    `Algorithms_for_calculating_variance`_, specifically `online_weighted_variance`_ and
    `online_weighted_covariance`_.

    ``var_x``, ``var_y`` and ``cov_xy`` hold weighted sums of (co-)deviations. They are intentionally not divided
    by the sum of weights here; the normalization happens in the ``compute`` step.

    Note:
        ``cov_xy`` is always updated in place, and ``var_x``/``var_y`` are updated in place when prior
        observations exist.

    Args:
        preds: estimated scores
        target: ground truth scores
        weights: weight associated with scores
        mean_x: current mean estimate of x tensor
        mean_y: current mean estimate of y tensor
        var_x: current (unnormalized) variance estimate of x tensor
        var_y: current (unnormalized) variance estimate of y tensor
        cov_xy: current (unnormalized) covariance estimate between x and y tensor
        weights_prior: current sum of weights
        num_outputs: number of outputs in multioutput setting

    Returns:
        Tuple ``(mean_x, mean_y, var_x, var_y, cov_xy, weights_sum)`` with the updated statistics.

    """
    # Data checking
    _check_same_shape(preds, target)
    _check_data_shape_to_num_outputs(preds, target, num_outputs)
    _check_data_shape_to_weights(preds, weights)

    if preds.ndim == 2:
        weights = weights.unsqueeze(1)  # singleton dimension for broadcasting

    weights_sum = weights.sum()
    weights_new = weights_prior + weights_sum

    if weights_prior > 0:  # True if prior observations exist
        mx_new = mean_x + (weights * (preds - mean_x)).sum(0) / weights_new
        my_new = mean_y + (weights * (target - mean_y)).sum(0) / weights_new

        var_x += (weights * (preds - mx_new) * (preds - mean_x)).sum(0)
        var_y += (weights * (target - my_new) * (target - mean_y)).sum(0)

        # The incremental covariance update pairs the *new* x-mean with the *previous* y-mean.
        target_center = mean_y
    else:
        mx_new = ((weights * preds).sum(0) / weights_sum).to(mean_x.dtype)
        my_new = ((weights * target).sum(0) / weights_sum).to(mean_y.dtype)

        var_x = (weights * (preds - mx_new) ** 2).sum(0)
        var_y = (weights * (target - my_new) ** 2).sum(0)

        # Without prior observations the stored `mean_y` carries no information. Because the weighted
        # deviations of `preds` from `mx_new` sum to zero, any constant offset yields the same covariance;
        # centering on the batch mean avoids cancellation error when `target` has a large offset.
        target_center = my_new

    cov_xy += (weights * (preds - mx_new) * (target - target_center)).sum(0)

    return mx_new, my_new, var_x, var_y, cov_xy, weights_new


def _weighted_pearson_corrcoef_compute(
    var_x: Tensor,
    var_y: Tensor,
    cov_xy: Tensor,
    weights_sum: Tensor,
) -> Tensor:
    """Compute the final weighted Pearson correlation based on accumulated statistics.

    Outputs whose (normalized) variance of predictions or target falls below ``sqrt(eps)`` of the working dtype
    are reported as ``nan`` and a warning is emitted; all other outputs are clamped to ``[-1, 1]``.

    Args:
        var_x: (unnormalized) variance estimate of x tensor
        var_y: (unnormalized) variance estimate of y tensor
        cov_xy: (unnormalized) covariance estimate between x and y tensor
        weights_sum: sum of weights

    Returns:
        Tensor with the weighted Pearson correlation coefficient, with singleton dimensions squeezed out.

    """
    # out-of-place division so the accumulated inputs are not overwritten
    var_x = var_x / weights_sum
    var_y = var_y / weights_sum
    cov_xy = cov_xy / weights_sum

    # if var_x, var_y is float16 and on cpu, make it bfloat16 as sqrt is not supported for float16
    # on cpu, remove this after https://github.com/pytorch/pytorch/issues/54774 is fixed
    if var_x.dtype == torch.float16 and var_x.device == torch.device("cpu"):
        var_x = var_x.bfloat16()
        var_y = var_y.bfloat16()

    bound = math.sqrt(torch.finfo(var_x.dtype).eps)
    zero_var_mask = (var_x < bound) | (var_y < bound)
    if zero_var_mask.any():
        # each continuation literal starts with a space so the concatenated message reads correctly
        rank_zero_warn(
            "The variance of predictions or target is close to zero. This can cause instability in Pearson correlation"
            " coefficient, leading to wrong results. Consider re-scaling the input if possible or computing using a"
            f" larger dtype (currently using {var_x.dtype}). Setting the correlation coefficient to nan.",
            UserWarning,
        )

    # start from nan everywhere and only fill in the outputs with a usable variance
    corrcoef = torch.full_like(cov_xy, float("nan"))
    valid_mask = ~zero_var_mask

    if valid_mask.any():
        corrcoef[valid_mask] = (
            (cov_xy[valid_mask] / (var_x[valid_mask] * var_y[valid_mask]).sqrt()).squeeze().to(corrcoef.dtype)
        )
        corrcoef = torch.clamp(corrcoef, -1.0, 1.0)

    return corrcoef.squeeze()


def weighted_pearson_corrcoef(preds: Tensor, target: Tensor, weights: Tensor) -> Tensor:
    """Compute weighted Pearson correlation coefficient.

    Args:
        preds: torch.Tensor of shape (n_samples,) or (n_samples, n_outputs)
            Estimate scores
        target: torch.Tensor of shape (n_samples,) or (n_samples, n_outputs)
            Ground truth scores
        weights: torch.Tensor of shape (n_samples,)
            Sample weights

    Returns:
        Scalar tensor for single-output input, or a tensor with one coefficient per output otherwise.

    Example (single output weighted regression):
        >>> from torchmetrics.functional.regression import weighted_pearson_corrcoef
        >>> target = torch.tensor([3, -0.5, 2, 7])
        >>> preds = torch.tensor([2.5, 0.0, 2, 8])
        >>> weights = torch.tensor([0.1, 0.2, 0.5, 0.1])
        >>> weighted_pearson_corrcoef(preds, target, weights)
        tensor(0.9837)

    Example (multi output weighted regression):
        >>> from torchmetrics.functional.regression import weighted_pearson_corrcoef
        >>> target = torch.tensor([[3, -0.5], [2, 7], [-1, 1.5]])
        >>> preds = torch.tensor([[2.5, 0.0], [2, 8], [0.0, 1.3]])
        >>> weights = torch.tensor([0.3, 0.2, 0.5])
        >>> weighted_pearson_corrcoef(preds, target, weights)
        tensor([0.9992, 0.9902])

    """
    d = preds.shape[1] if preds.ndim == 2 else 1
    # fresh zero-initialized statistics; a single update over the full input yields the final state
    _temp = torch.zeros(d, dtype=preds.dtype, device=preds.device)
    mean_x, mean_y = _temp.clone(), _temp.clone()
    var_x, var_y, cov_xy = _temp.clone(), _temp.clone(), _temp.clone()
    weights_sum = _temp.clone().sum()  # 0-dim accumulator for the sum of weights
    _, _, var_x, var_y, cov_xy, weights_sum = _weighted_pearson_corrcoef_update(
        preds,
        target,
        weights,
        mean_x,
        mean_y,
        var_x,
        var_y,
        cov_xy,
        weights_sum,
        num_outputs=1 if preds.ndim == 1 else preds.shape[-1],
    )
    return _weighted_pearson_corrcoef_compute(var_x, var_y, cov_xy, weights_sum)
if not _MATPLOTLIB_AVAILABLE:
    __doctest_skip__ = ["WeightedPearsonCorrCoef.plot"]


class WeightedPearsonCorrCoef(Metric):
    r"""Compute `Weighted Pearson Correlation Coefficient`_.

    .. math::
        P_{corr}(x,y;w) = \frac{cov(x,y;w)}{\sqrt{cov(x,x;w) cov(y,y;w)}},

    where :math:`cov(x,y;w)` is the weighted covariance,
    :math:`cov(x,x;w)` is the weighted variance of :math:`x`,
    :math:`cov(y,y;w)` is the weighted variance of :math:`y`,
    :math:`y` is a tensor of target values,
    :math:`x` is a tensor of predictions,
    and :math:`w` is a tensor of weights.

    As input to ``forward`` and ``update`` the metric accepts the following input:

    - ``preds`` (:class:`~torch.Tensor`): either single output float tensor with shape ``(N,)``
      or multioutput float tensor of shape ``(N,d)``
    - ``target`` (:class:`~torch.Tensor`): either single output tensor with shape ``(N,)``
      or multioutput tensor of shape ``(N,d)``
    - ``weights`` (:class:`~torch.Tensor`): single tensor with shape ``(N,)``

    As output of ``forward`` and ``compute`` the metric returns the following output:

    - ``pearson`` (:class:`~torch.Tensor`): A tensor with the weighted Pearson Correlation Coefficient

    Args:
        num_outputs: Number of outputs in multioutput setting
        kwargs: Additional keyword arguments, see :ref:`Metric kwargs` for more info.

    Raises:
        ValueError:
            If ``num_outputs`` is smaller than 1.

    Example (single output weighted regression):
        >>> from torchmetrics.regression import WeightedPearsonCorrCoef
        >>> target = torch.tensor([3, -0.5, 2, 7])
        >>> preds = torch.tensor([2.5, 0.0, 2, 8])
        >>> weights = torch.tensor([0.1, 0.2, 0.5, 0.1])
        >>> pearson = WeightedPearsonCorrCoef()
        >>> pearson(preds, target, weights)
        tensor(0.9837)

    Example (multi output weighted regression):
        >>> from torchmetrics.regression import WeightedPearsonCorrCoef
        >>> target = torch.tensor([[3, -0.5], [2, 7], [-1, 1.5]])
        >>> preds = torch.tensor([[2.5, 0.0], [2, 8], [0.0, 1.3]])
        >>> weights = torch.tensor([0.3, 0.2, 0.5])
        >>> pearson = WeightedPearsonCorrCoef(num_outputs=2)
        >>> pearson(preds, target, weights)
        tensor([0.9992, 0.9902])

    """

    is_differentiable: bool = True
    higher_is_better: Optional[bool] = None  # both -1 and 1 are optimal
    full_state_update: bool = True
    plot_lower_bound: float = -1.0
    plot_upper_bound: float = 1.0
    # Running statistics registered through `add_state` in `__init__`.
    mean_x: Tensor
    mean_y: Tensor
    var_x: Tensor
    var_y: Tensor
    cov_xy: Tensor
    weights_sum: Tensor

    def __init__(
        self,
        num_outputs: int = 1,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        if num_outputs < 1:
            raise ValueError(f"Expected argument `num_outputs` to be an `int` larger than 0, but got {num_outputs}.")
        self.num_outputs = num_outputs

        # `dist_reduce_fx=None` keeps the per-process statistics separate; `compute` merges them explicitly.
        self.add_state("mean_x", default=torch.zeros(self.num_outputs), dist_reduce_fx=None)
        self.add_state("mean_y", default=torch.zeros(self.num_outputs), dist_reduce_fx=None)
        self.add_state("var_x", default=torch.zeros(self.num_outputs), dist_reduce_fx=None)
        self.add_state("var_y", default=torch.zeros(self.num_outputs), dist_reduce_fx=None)
        self.add_state("cov_xy", default=torch.zeros(self.num_outputs), dist_reduce_fx=None)
        self.add_state("weights_sum", default=torch.tensor(0.0), dist_reduce_fx=None)

    def update(self, preds: Tensor, target: Tensor, weights: Tensor) -> None:
        """Update state with predictions, targets and sample weights."""
        self.mean_x, self.mean_y, self.var_x, self.var_y, self.cov_xy, self.weights_sum = (
            _weighted_pearson_corrcoef_update(
                preds,
                target,
                weights,
                self.mean_x,
                self.mean_y,
                self.var_x,
                self.var_y,
                self.cov_xy,
                self.weights_sum,
                self.num_outputs,
            )
        )

    def compute(self) -> Tensor:
        """Compute weighted Pearson correlation coefficient over state."""
        if (self.num_outputs == 1 and self.mean_x.numel() > 1) or (self.num_outputs > 1 and self.mean_x.ndim > 1):
            # multiple devices, need further reduction
            # NOTE(review): `_final_aggregation` comes from the unweighted Pearson metric and receives the
            # weight sums where that metric passes sample counts - confirm its merge formulas hold for weights.
            _, _, var_x, var_y, cov_xy, weights_sum = _final_aggregation(
                self.mean_x, self.mean_y, self.var_x, self.var_y, self.cov_xy, self.weights_sum
            )
        else:
            var_x = self.var_x
            var_y = self.var_y
            cov_xy = self.cov_xy
            weights_sum = self.weights_sum

        return _weighted_pearson_corrcoef_compute(var_x, var_y, cov_xy, weights_sum)

    def plot(
        self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
    ) -> _PLOT_OUT_TYPE:
        """Plot a single or multiple values from the metric.

        Args:
            val: Either a single result from calling `metric.forward` or `metric.compute` or a list of these results.
                If no value is provided, will automatically call `metric.compute` and plot that result.
            ax: An matplotlib axis object. If provided will add plot to that axis

        Returns:
            Figure and Axes object

        Raises:
            ModuleNotFoundError:
                If `matplotlib` is not installed

        .. plot::
            :scale: 75

            >>> from torch import rand, randn
            >>> # Example plotting a single value (weights are drawn from `rand` so they are non-negative)
            >>> from torchmetrics.regression import WeightedPearsonCorrCoef
            >>> metric = WeightedPearsonCorrCoef()
            >>> metric.update(randn(10,), randn(10,), rand(10,))
            >>> fig_, ax_ = metric.plot()

        .. plot::
            :scale: 75

            >>> from torch import rand, randn
            >>> # Example plotting multiple values (weights are drawn from `rand` so they are non-negative)
            >>> from torchmetrics.regression import WeightedPearsonCorrCoef
            >>> metric = WeightedPearsonCorrCoef()
            >>> values = []
            >>> for _ in range(10):
            ...     values.append(metric(randn(10,), randn(10,), rand(10,)))
            >>> fig, ax = metric.plot(values)

        """
        return self._plot(val, ax)
""" p_size = preds.shape[0] if isinstance(preds, Tensor) else len(preds) @@ -375,7 +377,7 @@ def _functional_test( extra_kwargs = {k: v[i] if isinstance(v, Tensor) else v for k, v in kwargs_update.items()} tm_result = metric(preds[i], target[i], **extra_kwargs) extra_kwargs = { - k: v.cpu() if isinstance(v, Tensor) else v + k: v[i].cpu() if isinstance(v, Tensor) else v for k, v in (extra_kwargs if fragment_kwargs else kwargs_update).items() } ref_result = _reference_cachier(reference_metric)( @@ -641,6 +643,7 @@ def run_differentiability_test( metric_module: Metric, metric_functional: Optional[Callable] = None, metric_args: Optional[dict] = None, + **kwargs_update: Any, ) -> None: """Test if a metric is differentiable or not. @@ -650,14 +653,17 @@ def run_differentiability_test( metric_module: the metric module to test metric_functional: functional version of the metric metric_args: dict with additional arguments used for class initialization + kwargs_update: dict with additional arguments for metric update """ metric_args = metric_args or {} + # only floating point tensors can require grad metric = metric_module(**metric_args) if preds.is_floating_point(): preds.requires_grad = True - out = metric(preds[0, :2], target[0, :2]) + kwargs_update = {k: v[0, :4] if isinstance(v, Tensor) else v for k, v in kwargs_update.items()} + out = metric(preds[0, :4], target[0, :4], **kwargs_update) # Check if requires_grad matches is_differentiable attribute _assert_requires_grad(metric, out) @@ -665,7 +671,7 @@ def run_differentiability_test( if metric.is_differentiable and metric_functional is not None: # check for numerical correctness assert torch.autograd.gradcheck( - partial(metric_functional, **metric_args), (preds[0, :2].double(), target[0, :2]) + partial(metric_functional, **kwargs_update), (preds[0, :4].double(), target[0, :4]) ) # reset as else it will carry over to other tests diff --git a/tests/unittests/regression/test_pearson.py 
b/tests/unittests/regression/test_pearson.py index f218fad9ce9..0f437680fa6 100644 --- a/tests/unittests/regression/test_pearson.py +++ b/tests/unittests/regression/test_pearson.py @@ -13,12 +13,15 @@ # limitations under the License. from functools import partial +import numpy as np import pytest import torch from scipy.stats import pearsonr from torchmetrics.functional.regression.pearson import pearson_corrcoef +from torchmetrics.functional.regression.weighted_pearson import weighted_pearson_corrcoef from torchmetrics.regression.pearson import PearsonCorrCoef, _final_aggregation +from torchmetrics.regression.weighted_pearson import WeightedPearsonCorrCoef from torchmetrics.utilities.imports import _TORCH_GREATER_EQUAL_2_5 from unittests import BATCH_SIZE, EXTRA_DIM, NUM_BATCHES, _Input from unittests._helpers import seed_all @@ -37,7 +40,6 @@ target=torch.randn(NUM_BATCHES, BATCH_SIZE), ) - _multi_target_inputs1 = _Input( preds=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM), target=torch.rand(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM), @@ -48,13 +50,28 @@ target=torch.randn(NUM_BATCHES, BATCH_SIZE, EXTRA_DIM), ) +_weights = torch.rand(NUM_BATCHES, BATCH_SIZE) + def _reference_scipy_pearson(preds, target): if preds.ndim == 2: - return [pearsonr(t.numpy(), p.numpy())[0] for t, p in zip(target.T, preds.T)] + return [pearsonr(t.numpy(), p.numpy())[0] for t, p in zip(target.T, preds.T, strict=True)] return pearsonr(target.numpy(), preds.numpy())[0] +def _reference_weighted_pearson(preds, target, weights): + if preds.ndim == 2: + return [_reference_weighted_pearson(p, t, weights) for p, t in zip(preds.T, target.T, strict=True)] + + preds, target, weights = preds.numpy(), target.numpy(), weights.numpy() + mx = (weights * preds).sum() / weights.sum() + my = (weights * target).sum() / weights.sum() + var_x = (weights * (preds - mx) ** 2).sum() + var_y = (weights * (target - my) ** 2).sum() + cov_xy = (weights * (preds - mx) * (target - my)).sum() + return cov_xy / np.sqrt(var_x * 
var_y) + + @pytest.mark.parametrize( ("preds", "target"), [ @@ -112,25 +129,140 @@ def test_pearson_corrcoef_half_gpu(self, preds, target): self.run_precision_test_gpu(preds, target, partial(PearsonCorrCoef, num_outputs=num_outputs), pearson_corrcoef) -def test_error_on_different_shape(): +@pytest.mark.parametrize( + ("preds", "target", "weights"), + [ + (_single_target_inputs1.preds, _single_target_inputs1.target, _weights), + (_single_target_inputs2.preds, _single_target_inputs2.target, _weights), + (_multi_target_inputs1.preds, _multi_target_inputs1.target, _weights), + (_multi_target_inputs2.preds, _multi_target_inputs2.target, _weights), + ], +) +class TestWeightedPearsonCorrCoef(MetricTester): + """Test class for `WeightedPearsonCorrCoef` metric.""" + + atol = 1e-3 + + @pytest.mark.parametrize("compute_on_cpu", [True, False]) + @pytest.mark.parametrize("ddp", [pytest.param(True, marks=pytest.mark.DDP), False]) + def test_weighted_pearson_corrcoef(self, preds, target, weights, compute_on_cpu, ddp): + """Test class implementation of metric.""" + num_outputs = EXTRA_DIM if preds.ndim == 3 else 1 + self.run_class_metric_test( + ddp=ddp, + preds=preds, + target=target, + metric_class=WeightedPearsonCorrCoef, + reference_metric=_reference_weighted_pearson, + metric_args={"num_outputs": num_outputs, "compute_on_cpu": compute_on_cpu}, + weights=weights, + ) + + def test_weighted_pearson_corrcoef_functional(self, preds, target, weights): + """Test functional implementation of metric.""" + self.run_functional_metric_test( + preds=preds, + target=target, + metric_functional=weighted_pearson_corrcoef, + reference_metric=_reference_weighted_pearson, + weights=weights, + ) + + def test_weighted_pearson_corrcoef_differentiability(self, preds, target, weights): + """Test the differentiability of the metric, according to its `is_differentiable` attribute.""" + num_outputs = EXTRA_DIM if preds.ndim == 3 else 1 + self.run_differentiability_test( + preds=preds, + target=target, 
+ metric_module=partial(WeightedPearsonCorrCoef, num_outputs=num_outputs), + metric_functional=weighted_pearson_corrcoef, + weights=weights, + ) + + @pytest.mark.skipif(not _TORCH_GREATER_EQUAL_2_5, reason="Requires torch>=2.5.0") + def test_weighted_pearson_corrcoef_half_cpu(self, preds, target, weights): + """Test dtype support of the metric on CPU.""" + num_outputs = EXTRA_DIM if preds.ndim == 3 else 1 + self.run_precision_test_cpu( + preds, + target, + partial(WeightedPearsonCorrCoef, num_outputs=num_outputs), + weighted_pearson_corrcoef, + weights=weights, + ) + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires cuda") + def test_weighted_pearson_corrcoef_half_gpu(self, preds, target, weights): + """Test dtype support of the metric on GPU.""" + num_outputs = EXTRA_DIM if preds.ndim == 3 else 1 + self.run_precision_test_gpu( + preds, + target, + partial(WeightedPearsonCorrCoef, num_outputs=num_outputs), + weighted_pearson_corrcoef, + weights=weights, + ) + + +@pytest.mark.parametrize( + ("metric_class", "metric_args"), + [ + (PearsonCorrCoef, [torch.randn(100), torch.randn(50)]), + (WeightedPearsonCorrCoef, [torch.randn(100), torch.randn(50), torch.randn(50)]), + ], +) +def test_error_on_different_shape(metric_class, metric_args): """Test that error is raised on different shapes of input.""" - metric = PearsonCorrCoef(num_outputs=1) + metric = metric_class(num_outputs=1) with pytest.raises(RuntimeError, match="Predictions and targets are expected to have the same shape"): - metric(torch.randn(100), torch.randn(50)) + metric(*metric_args) - metric = PearsonCorrCoef(num_outputs=5) + +@pytest.mark.parametrize( + ("metric_class", "metric_args"), + [ + (PearsonCorrCoef, [torch.randn(100, 2, 5), torch.randn(100, 2, 5)]), + (WeightedPearsonCorrCoef, [torch.randn(100, 2, 5), torch.randn(100, 2, 5), torch.randn(100)]), + ], +) +def test_error_on_invalid_ndim(metric_class, metric_args): + """Test that error is raised on invalid dimensions.""" + 
metric = metric_class(num_outputs=5) with pytest.raises(ValueError, match="Expected both predictions and target to be either 1- or 2-.*"): - metric(torch.randn(100, 2, 5), torch.randn(100, 2, 5)) + metric(*metric_args) + - metric = PearsonCorrCoef(num_outputs=2) +@pytest.mark.parametrize( + ("metric_class", "metric_args"), + [ + (PearsonCorrCoef, [torch.randn(100, 3), torch.randn(100, 3)]), + (WeightedPearsonCorrCoef, [torch.randn(100, 3), torch.randn(100, 3), torch.randn(100)]), + ], +) +def test_error_on_num_outputs_mismatch(metric_class, metric_args): + """Test that error is raised if `num_outputs` of `preds` or `target` do not match initialization.""" + metric = metric_class(num_outputs=2) with pytest.raises(ValueError, match="Expected argument `num_outputs` to match the second dimension of input.*"): - metric(torch.randn(100, 5), torch.randn(100, 5)) + metric(*metric_args) -def test_1d_input_allowed(): +@pytest.mark.parametrize( + ("metric_functional", "metric_args"), + [ + (pearson_corrcoef, [[torch.randn(10, 1), torch.randn(10, 1)], [torch.randn(10), torch.randn(10)]]), + ( + weighted_pearson_corrcoef, + [ + [torch.randn(10, 1), torch.randn(10, 1), torch.randn(10)], + [torch.randn(10), torch.randn(10), torch.randn(10)], + ], + ), + ], +) +def test_1d_input_allowed(metric_functional, metric_args): """Check that both input of the form [N,] and [N,1] is allowed with default num_outputs argument.""" - assert isinstance(pearson_corrcoef(torch.randn(10, 1), torch.randn(10, 1)), torch.Tensor) - assert isinstance(pearson_corrcoef(torch.randn(10), torch.randn(10)), torch.Tensor) + assert isinstance(metric_functional(*metric_args[0]), torch.Tensor) + assert isinstance(metric_functional(*metric_args[1]), torch.Tensor) @pytest.mark.parametrize("shapes", [(5,), (1, 5), (2, 5)])