From 7e240cdf454ec9207b8e60a0e1cfb5cfdc0d93b3 Mon Sep 17 00:00:00 2001 From: Isalia20 Date: Mon, 7 Apr 2025 00:46:06 +0400 Subject: [PATCH 1/4] dice score add warnings --- src/torchmetrics/functional/segmentation/dice.py | 8 ++++++++ src/torchmetrics/segmentation/dice.py | 8 ++++++++ 2 files changed, 16 insertions(+) diff --git a/src/torchmetrics/functional/segmentation/dice.py b/src/torchmetrics/functional/segmentation/dice.py index 029f5b25dd1..16235efc4b4 100644 --- a/src/torchmetrics/functional/segmentation/dice.py +++ b/src/torchmetrics/functional/segmentation/dice.py @@ -18,6 +18,7 @@ from typing_extensions import Literal from torchmetrics.functional.segmentation.utils import _ignore_background +from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.checks import _check_same_shape from torchmetrics.utilities.compute import _safe_divide @@ -154,6 +155,13 @@ def dice_score( [0.1978, 0.2804, 0.1714, 0.1915, 0.2783]]) """ + if average == "micro": + rank_zero_warn( + "dice_score metric currently defaults to `average=micro`, but will change to" + "`average=macro` in the next release." + " If you've explicitly set this parameter, you can ignore this warning.", + UserWarning, + ) _dice_score_validate_args(num_classes, include_background, average, input_format) numerator, denominator, support = _dice_score_update(preds, target, num_classes, include_background, input_format) return _dice_score_compute(numerator, denominator, average, support=support) diff --git a/src/torchmetrics/segmentation/dice.py b/src/torchmetrics/segmentation/dice.py index 54f2eedf0cf..cf1e7bdbe14 100644 --- a/src/torchmetrics/segmentation/dice.py +++ b/src/torchmetrics/segmentation/dice.py @@ -23,6 +23,7 @@ _dice_score_validate_args, ) from torchmetrics.metric import Metric +from torchmetrics.utilities import rank_zero_warn from torchmetrics.utilities.data import dim_zero_cat from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE @@ -116,6 +117,13 @@ def __init__( **kwargs: Any, ) -> None: super().__init__(**kwargs) + if average == "micro": + rank_zero_warn( + "DiceScore metric currently defaults to `average=micro`, but will change to" + "`average=macro` in the next release." + " If you've explicitly set this parameter, you can ignore this warning.", + UserWarning, + ) _dice_score_validate_args(num_classes, include_background, average, input_format, zero_division) self.num_classes = num_classes self.include_background = include_background From 282fb635ee8c58dce1725536ca5394d46c2b671f Mon Sep 17 00:00:00 2001 From: Jirka Borovec <6035284+Borda@users.noreply.github.com> Date: Mon, 7 Apr 2025 11:50:19 +0200 Subject: [PATCH 2/4] Apply suggestions from code review Co-authored-by: Nicki Skafte Detlefsen --- src/torchmetrics/functional/segmentation/dice.py | 4 ++-- src/torchmetrics/segmentation/dice.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/torchmetrics/functional/segmentation/dice.py b/src/torchmetrics/functional/segmentation/dice.py index 16235efc4b4..c2071638e45 100644 --- a/src/torchmetrics/functional/segmentation/dice.py +++ b/src/torchmetrics/functional/segmentation/dice.py @@ -158,9 +158,9 @@ def dice_score( if average == "micro": rank_zero_warn( "dice_score metric currently defaults to `average=micro`, but will change to" - "`average=macro` in the next release." + "`average=macro` in the v1.9 release." " If you've explicitly set this parameter, you can ignore this warning.", - UserWarning, + DeprecationWarning, ) _dice_score_validate_args(num_classes, include_background, average, input_format) numerator, denominator, support = _dice_score_update(preds, target, num_classes, include_background, input_format) diff --git a/src/torchmetrics/segmentation/dice.py b/src/torchmetrics/segmentation/dice.py index cf1e7bdbe14..af2c4210709 100644 --- a/src/torchmetrics/segmentation/dice.py +++ b/src/torchmetrics/segmentation/dice.py @@ -120,9 +120,9 @@ def __init__( if average == "micro": rank_zero_warn( "DiceScore metric currently defaults to `average=micro`, but will change to" - "`average=macro` in the next release." + "`average=macro` in the v1.9 release." " If you've explicitly set this parameter, you can ignore this warning.", - UserWarning, + DeprecationWarning, ) _dice_score_validate_args(num_classes, include_background, average, input_format, zero_division) self.num_classes = num_classes From 5f4e46cf7bffc12d44383d20eb7d8016ca8176df Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Mon, 7 Apr 2025 16:02:26 +0200 Subject: [PATCH 3/4] Update src/torchmetrics/segmentation/dice.py --- src/torchmetrics/segmentation/dice.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/segmentation/dice.py b/src/torchmetrics/segmentation/dice.py index af2c4210709..174812e6a1f 100644 --- a/src/torchmetrics/segmentation/dice.py +++ b/src/torchmetrics/segmentation/dice.py @@ -122,7 +122,7 @@ def __init__( "DiceScore metric currently defaults to `average=micro`, but will change to" "`average=macro` in the v1.9 release." " If you've explicitly set this parameter, you can ignore this warning.", - DeprecationWarning, + UserWarning, ) _dice_score_validate_args(num_classes, include_background, average, input_format, zero_division) self.num_classes = num_classes From 5351dbe093ab80321f83f5e8ad1d147cc3cc6c1f Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Mon, 7 Apr 2025 16:02:31 +0200 Subject: [PATCH 4/4] Update src/torchmetrics/functional/segmentation/dice.py --- src/torchmetrics/functional/segmentation/dice.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/segmentation/dice.py b/src/torchmetrics/functional/segmentation/dice.py index c2071638e45..715969cefee 100644 --- a/src/torchmetrics/functional/segmentation/dice.py +++ b/src/torchmetrics/functional/segmentation/dice.py @@ -160,7 +160,7 @@ def dice_score( "dice_score metric currently defaults to `average=micro`, but will change to" "`average=macro` in the v1.9 release." " If you've explicitly set this parameter, you can ignore this warning.", - DeprecationWarning, + UserWarning, ) _dice_score_validate_args(num_classes, include_background, average, input_format) numerator, denominator, support = _dice_score_update(preds, target, num_classes, include_background, input_format)