Skip to content

Commit 4fb978b

Browse files
committed
feat(metrics): Add MAPEMetric for regression evaluation.
Signed-off-by: Akshat Sinha <[email protected]>
1 parent 57fdd59 commit 4fb978b

File tree

2 files changed

+50
-0
lines changed

2 files changed

+50
-0
lines changed

monai/metrics/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from .panoptic_quality import PanopticQualityMetric, compute_panoptic_quality
2929
from .regression import (
3030
MAEMetric,
31+
MAPEMetric,
3132
MSEMetric,
3233
MultiScaleSSIMMetric,
3334
PSNRMetric,

monai/metrics/regression.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,38 @@ def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor
143143
return compute_mean_error_metrics(y_pred, y, func=self.abs_func)
144144

145145

146+
class MAPEMetric(RegressionMetric):
147+
r"""Compute Mean Absolute Percentage Error between two tensors using function:
148+
149+
.. math::
150+
\operatorname {MAPE}\left(Y, \hat{Y}\right) =\frac {100}{n}\sum _{i=1}^{n}\left|\frac{y_i-\hat{y_i}}{y_i}\right|.
151+
152+
More info: https://en.wikipedia.org/wiki/Mean_absolute_percentage_error
153+
154+
Input `y_pred` is compared with ground truth `y`.
155+
Both `y_pred` and `y` are expected to be real-valued, where `y_pred` is output from a regression model.
156+
Note: Tackling the undefined error, a tiny epsilon value is added to the denominator part.
157+
158+
Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.
159+
Args:
160+
reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values,
161+
available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
162+
``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
163+
get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans).
164+
epsilonDefaults to 1e-7.
165+
166+
"""
167+
168+
def __init__(
169+
self, reduction: MetricReduction | str = MetricReduction.MEAN, get_not_nans: bool = False, epsilon: float = 1e-7
170+
) -> None:
171+
super().__init__(reduction=reduction, get_not_nans=get_not_nans)
172+
self.epsilon = epsilon
173+
174+
def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
175+
return compute_mape_metric(y_pred, y, epsilon=self.epsilon)
176+
177+
146178
class RMSEMetric(RegressionMetric):
147179
r"""Compute Root Mean Squared Error between two tensors using function:
148180
@@ -220,6 +252,23 @@ def compute_mean_error_metrics(y_pred: torch.Tensor, y: torch.Tensor, func: Call
220252
return torch.mean(flt(func(y - y_pred)), dim=-1, keepdim=True)
221253

222254

255+
def compute_mape_metric(y_pred: torch.Tensor, y: torch.Tensor, epsilon: float = 1e-7) -> torch.Tensor:
256+
"""
257+
Compute Mean Absolute Percentage Error.
258+
259+
Args:
260+
y_pred: predicted values
261+
y: ground truth values
262+
epsilon: small value to avoid division by zero
263+
264+
Returns:
265+
MAPE value as percentage
266+
"""
267+
flt = partial(torch.flatten, start_dim=1)
268+
percentage_error = torch.abs((y - y_pred) / (torch.abs(y) + epsilon)) * 100.0
269+
return torch.mean(flt(percentage_error), dim=-1, keepdim=True)
270+
271+
223272
class KernelType(StrEnum):
224273
GAUSSIAN = "gaussian"
225274
UNIFORM = "uniform"

0 commit comments

Comments
 (0)