
Commit 342bd7a

Authored by theo-barfoot, pre-commit-ci[bot], and ericspod
Add CalibrationErrorMetric and CalibrationError handler (Project-MONAI#8707)
## Description

Addresses Project-MONAI#8505

### Overview

This PR adds calibration error metrics and an Ignite handler to MONAI, enabling users to evaluate and monitor model calibration for segmentation and other multi-class probabilistic tasks with shape `(B, C, spatial...)`.

### What's Included

#### 1. Calibration Metrics (`monai/metrics/calibration.py`)

- **`calibration_binning()`**: Core function to compute calibration bins with mean predictions, mean ground truths, and bin counts. Exported to support research workflows where users need per-bin statistics for plotting reliability diagrams.
- **`CalibrationReduction`**: Enum supporting three reduction methods:
  - `EXPECTED` - Expected Calibration Error (ECE): weighted average by bin count
  - `AVERAGE` - Average Calibration Error (ACE): simple average across bins
  - `MAXIMUM` - Maximum Calibration Error (MCE): maximum error across bins
- **`CalibrationErrorMetric`**: A `CumulativeIterationMetric` subclass supporting:
  - Configurable number of bins
  - Background channel exclusion (`include_background`)
  - All standard MONAI metric reductions (`mean`, `sum`, `mean_batch`, etc.)
  - Batched, per-class computation

#### 2. Ignite Handler (`monai/handlers/calibration.py`)

- **`CalibrationError`**: An `IgniteMetricHandler` wrapper that:
  - Attaches to PyTorch Ignite engines for training/validation loops
  - Supports `save_details` for per-sample/per-channel metric details via the metric buffer
  - Integrates with MONAI's existing handler ecosystem

#### 3. Comprehensive Tests

- **`tests/metrics/test_calibration_metric.py`**: Tests covering:
  - Binning function correctness with NaN handling
  - ECE/ACE/MCE reduction modes
  - Background exclusion
  - Cumulative iteration behavior
  - Input validation (shape mismatch, ndim, num_bins)
- **`tests/handlers/test_handler_calibration_error.py`**: Tests covering:
  - Handler attachment and computation via `engine.run()`
  - All calibration reduction modes
  - `save_details` functionality
  - Optional Ignite dependency handling (tests are skipped if Ignite is not installed)

### Public API

Exposes the following via `monai.metrics`:

- `CalibrationErrorMetric`
- `CalibrationReduction`
- `calibration_binning`

Exposes via `monai.handlers`:

- `CalibrationError`

### Implementation Notes

- Uses `scatter_add` + counts instead of `scatter_reduce("mean")` for better PyTorch version compatibility
- Includes input validation with clear error messages
- Clamps bin indices to prevent out-of-range errors with slightly out-of-bound probabilities
- Uses `torch.nan_to_num` instead of in-place operations for cleaner code
- Ignite is treated as an optional dependency in tests (skipped if not installed)

### Related Work

The algorithmic approach follows the calibration metrics from [Average-Calibration-Losses](https://github.com/cai4cai/Average-Calibration-Losses/), with related publications:

- [MICCAI 2024 Paper](https://papers.miccai.org/miccai-2024/091-Paper3075.html)
- [arXiv Paper](https://arxiv.org/abs/2506.03942v1)

### Future Work

As discussed in the issue, calibration losses will be added in a separate PR to keep the changes focused and easier to review.
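For reference, the equal-width binning and the three reduction modes described above can be sketched in plain Python. This is a hypothetical illustration only, not MONAI's implementation (which operates on batched tensors with `scatter_add`); the function name and signature here are invented:

```python
def binned_calibration_errors(probs, labels, num_bins=20):
    """Toy sketch: bin flattened probabilities and return (ECE, ACE, MCE)."""
    sum_p = [0.0] * num_bins  # per-bin sum of predicted probabilities
    sum_y = [0.0] * num_bins  # per-bin sum of ground-truth labels (0/1)
    counts = [0] * num_bins
    for p, y in zip(probs, labels):
        # clamp the index so p == 1.0 (or slightly out-of-bound values)
        # still falls into the last bin, mirroring the clamping noted above
        idx = min(max(int(p * num_bins), 0), num_bins - 1)
        sum_p[idx] += p
        sum_y[idx] += y
        counts[idx] += 1
    errors, weights = [], []
    total = len(probs)
    for b in range(num_bins):
        if counts[b] == 0:
            continue  # empty bins are skipped (the real code handles NaNs)
        # per-bin calibration gap: |mean prediction - mean ground truth|
        errors.append(abs(sum_p[b] / counts[b] - sum_y[b] / counts[b]))
        weights.append(counts[b] / total)
    ece = sum(e * w for e, w in zip(errors, weights))  # "expected"
    ace = sum(errors) / len(errors)                    # "average"
    mce = max(errors)                                  # "maximum"
    return ece, ace, mce
```

With two confident and two low-confidence predictions that are each 0.1 away from their empirical frequency, all three reductions coincide at 0.1, which makes the sketch easy to sanity-check by hand.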
## Checklist

- [x] Code follows MONAI style guidelines (ruff passes)
- [x] All new code has appropriate license headers
- [x] Public API is exported in `__init__.py` files
- [x] Docstrings include examples with proper transforms usage
- [x] Unit tests cover main functionality
- [x] Tests handle the optional Ignite dependency gracefully
- [x] No breaking changes to existing API

## Example Usage

```python
from monai.metrics import CalibrationErrorMetric
from monai.transforms import Activations, AsDiscrete

# Setup transforms
softmax = Activations(softmax=True)
to_onehot = AsDiscrete(to_onehot=num_classes)

# Create metric
metric = CalibrationErrorMetric(
    num_bins=15,
    include_background=False,
    calibration_reduction="expected",  # ECE
)

# In the evaluation loop
# Note: y_pred should be probabilities in [0, 1], y should be one-hot/binarized
for batch_data in dataloader:
    logits = model(batch_data["image"])
    labels = batch_data["label"]
    preds = softmax(logits)
    labels_onehot = to_onehot(labels)
    metric(y_pred=preds, y=labels_onehot)

ece = metric.aggregate()
```

### With Ignite Handler

```python
from monai.handlers import CalibrationError, from_engine

calibration_handler = CalibrationError(
    num_bins=15,
    include_background=False,
    calibration_reduction="expected",
    output_transform=from_engine(["pred", "label"]),
    save_details=True,
)
calibration_handler.attach(evaluator, name="calibration_error")
```

---------

Signed-off-by: Theo Barfoot <theo.barfoot@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com>
1 parent 4b1777f commit 342bd7a

File tree

8 files changed: +1096 −0 lines

docs/source/handlers.rst

Lines changed: 6 additions & 0 deletions

```diff
@@ -83,6 +83,12 @@ Panoptic Quality metrics handler
     :members:


+Calibration Error metrics handler
+---------------------------------
+.. autoclass:: CalibrationError
+    :members:
+
+
 Mean squared error metrics handler
 ----------------------------------
 .. autoclass:: MeanSquaredError
```

docs/source/metrics.rst

Lines changed: 9 additions & 0 deletions

```diff
@@ -185,6 +185,15 @@ Metrics
 .. autoclass:: MetricsReloadedCategorical
     :members:

+`Calibration Error`
+-------------------
+.. autofunction:: calibration_binning
+
+.. autoclass:: CalibrationReduction
+    :members:
+
+.. autoclass:: CalibrationErrorMetric
+    :members:


 Utilities
```

monai/handlers/__init__.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -12,6 +12,7 @@
 from __future__ import annotations

 from .average_precision import AveragePrecision
+from .calibration import CalibrationError
 from .checkpoint_loader import CheckpointLoader
 from .checkpoint_saver import CheckpointSaver
 from .classification_saver import ClassificationSaver
```

monai/handlers/calibration.py

Lines changed: 109 additions & 0 deletions (new file)

```python
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Callable

from monai.handlers.ignite_metric import IgniteMetricHandler
from monai.metrics import CalibrationErrorMetric, CalibrationReduction
from monai.utils import MetricReduction

__all__ = ["CalibrationError"]


class CalibrationError(IgniteMetricHandler):
    """
    Ignite handler to compute Calibration Error during training or evaluation.

    **Why Calibration Matters:**

    A well-calibrated model produces probability estimates that match the true likelihood of correctness.
    For example, predictions with 80% confidence should be correct approximately 80% of the time.
    Modern neural networks often exhibit poor calibration (typically overconfident), which can be
    problematic in medical imaging where probability estimates may inform clinical decisions.

    This handler wraps :py:class:`~monai.metrics.CalibrationErrorMetric` for use with PyTorch Ignite
    engines, automatically computing and aggregating calibration errors across iterations.

    **Supported Calibration Metrics:**

    - **Expected Calibration Error (ECE)**: Weighted average of per-bin errors (most common).
    - **Average Calibration Error (ACE)**: Unweighted average across bins.
    - **Maximum Calibration Error (MCE)**: Worst-case calibration error.

    Args:
        num_bins: Number of equally-spaced bins for calibration computation. Defaults to 20.
        include_background: Whether to include the first channel (index 0) in computation.
            Set to ``False`` to exclude background in segmentation tasks. Defaults to ``True``.
        calibration_reduction: Calibration error reduction mode. Options: ``"expected"`` (ECE),
            ``"average"`` (ACE), ``"maximum"`` (MCE). Defaults to ``"expected"``.
        metric_reduction: Reduction across batch/channel after computing per-sample errors.
            Options: ``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
            ``"mean_channel"``, ``"sum_channel"``. Defaults to ``"mean"``.
        output_transform: Callable to extract ``(y_pred, y)`` from ``engine.state.output``.
            See `Ignite concepts <https://pytorch.org/ignite/concepts.html#state>`_ and
            the batch output transform tutorial in the MONAI tutorials repository.
        save_details: If ``True``, saves per-sample/per-channel metric values to
            ``engine.state.metric_details[name]``. Defaults to ``True``.

    References:
        - Guo, C., et al. "On Calibration of Modern Neural Networks." ICML 2017.
          https://proceedings.mlr.press/v70/guo17a.html
        - Barfoot, T., et al. "Average Calibration Losses for Reliable Uncertainty in
          Medical Image Segmentation." arXiv:2506.03942v3, 2025.
          https://arxiv.org/abs/2506.03942v3

    See Also:
        - :py:class:`~monai.metrics.CalibrationErrorMetric`: The underlying metric class.
        - :py:func:`~monai.metrics.calibration_binning`: Low-level binning for reliability diagrams.

    Example:
        >>> from monai.handlers import CalibrationError, from_engine
        >>> from ignite.engine import Engine
        >>>
        >>> def evaluation_step(engine, batch):
        ...     # Returns dict with "pred" (probabilities) and "label" (one-hot)
        ...     return {"pred": model(batch["image"]), "label": batch["label"]}
        >>>
        >>> evaluator = Engine(evaluation_step)
        >>>
        >>> # Attach calibration error handler
        >>> CalibrationError(
        ...     num_bins=15,
        ...     include_background=False,
        ...     calibration_reduction="expected",
        ...     output_transform=from_engine(["pred", "label"]),
        ... ).attach(evaluator, name="ECE")
        >>>
        >>> # After evaluation, access results
        >>> evaluator.run(val_loader)
        >>> ece = evaluator.state.metrics["ECE"]
        >>> print(f"Expected Calibration Error: {ece:.4f}")
    """

    def __init__(
        self,
        num_bins: int = 20,
        include_background: bool = True,
        calibration_reduction: CalibrationReduction | str = CalibrationReduction.EXPECTED,
        metric_reduction: MetricReduction | str = MetricReduction.MEAN,
        output_transform: Callable = lambda x: x,
        save_details: bool = True,
    ) -> None:
        metric_fn = CalibrationErrorMetric(
            num_bins=num_bins,
            include_background=include_background,
            calibration_reduction=calibration_reduction,
            metric_reduction=metric_reduction,
        )

        super().__init__(metric_fn=metric_fn, output_transform=output_transform, save_details=save_details)
```
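The handler above is a thin delegating wrapper: it constructs a cumulative metric once, feeds it each iteration's `(y_pred, y)`, and aggregates at the end. As a hedged, framework-free illustration of that pattern (all class and method names below are invented; this is not MONAI's or Ignite's API):

```python
class ToyCumulativeMetric:
    """Hypothetical cumulative metric: accumulates one value per iteration."""

    def __init__(self):
        self._values = []

    def __call__(self, y_pred, y):
        # Stand-in per-iteration score: mean absolute prediction/target gap.
        self._values.append(sum(abs(p, ) and abs(p - t) for p, t in zip(y_pred, y)) / len(y))

    def aggregate(self):
        # Average the per-iteration scores accumulated so far.
        return sum(self._values) / len(self._values)


class ToyMetricHandler:
    """Hypothetical handler: delegates to metric_fn, mirroring the wrapper shape above."""

    def __init__(self, metric_fn, output_transform=lambda x: x):
        self.metric_fn = metric_fn
        self.output_transform = output_transform

    def update(self, engine_output):
        # Extract (y_pred, y) from the engine's output, then update the metric.
        y_pred, y = self.output_transform(engine_output)
        self.metric_fn(y_pred, y)

    def compute(self):
        return self.metric_fn.aggregate()
```

The `output_transform` here plays the same role as `from_engine(["pred", "label"])` in the real handler: it adapts whatever the engine step returns into the `(y_pred, y)` pair the metric expects.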
monai/metrics/__init__.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -13,6 +13,7 @@

 from .active_learning_metrics import LabelQualityScore, VarianceMetric, compute_variance, label_quality_score
 from .average_precision import AveragePrecisionMetric, compute_average_precision
+from .calibration import CalibrationErrorMetric, CalibrationReduction, calibration_binning
 from .confusion_matrix import ConfusionMatrixMetric, compute_confusion_matrix_metric, get_confusion_matrix
 from .cumulative_average import CumulativeAverage
 from .f_beta_score import FBetaScore
```
