Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/torchmetrics/functional/segmentation/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing_extensions import Literal

from torchmetrics.functional.segmentation.utils import _ignore_background
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.compute import _safe_divide

Expand Down Expand Up @@ -154,6 +155,13 @@ def dice_score(
[0.1978, 0.2804, 0.1714, 0.1915, 0.2783]])

"""
if average == "micro":
rank_zero_warn(
"dice_score metric currently defaults to `average=micro`, but will change to"
"`average=macro` in the v1.9 release."
" If you've explicitly set this parameter, you can ignore this warning.",
DeprecationWarning,
)
_dice_score_validate_args(num_classes, include_background, average, input_format)
numerator, denominator, support = _dice_score_update(preds, target, num_classes, include_background, input_format)
return _dice_score_compute(numerator, denominator, average, support=support)
8 changes: 8 additions & 0 deletions src/torchmetrics/segmentation/dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
_dice_score_validate_args,
)
from torchmetrics.metric import Metric
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.data import dim_zero_cat
from torchmetrics.utilities.imports import _MATPLOTLIB_AVAILABLE
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE
Expand Down Expand Up @@ -116,6 +117,13 @@ def __init__(
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
if average == "micro":
rank_zero_warn(
"DiceScore metric currently defaults to `average=micro`, but will change to"
"`average=macro` in the v1.9 release."
" If you've explicitly set this parameter, you can ignore this warning.",
DeprecationWarning,
)
_dice_score_validate_args(num_classes, include_background, average, input_format, zero_division)
self.num_classes = num_classes
self.include_background = include_background
Expand Down
Loading