34 changes: 34 additions & 0 deletions docs/source/pages/lightning.rst
@@ -197,6 +197,40 @@ Common Pitfalls

The following contains a list of pitfalls to be aware of:

* ``self.log`` only supports logging of scalar tensors, and therefore only supports logging metrics that return scalar
  tensors. The vast majority of metrics in TorchMetrics return a scalar tensor, but some metrics return non-scalar
  tensors (often dictionaries or lists of tensors) and therefore need to be logged manually. One recommended pattern is
  to exploit the fact that metrics support *Metric Arithmetic*, which includes indexing and slicing. Combined with
  ``MetricCollection``, this can be used to unpack and repack the metric outputs into a format that is compatible with
  ``self.log`` and ``self.log_dict``. Example:

.. testcode:: python

class MyModule(LightningModule):

        def __init__(self, num_classes):
            super().__init__()
            # initialize a metric that returns a non-scalar, in this case MeanAveragePrecision,
            # which returns a dict of scalar tensors; assume we are only interested in a few of them
            map = torchmetrics.detection.MeanAveragePrecision()

            # index the values we want to log out of the metric and repack them into a MetricCollection;
            # the compute group feature of MetricCollection ensures the underlying computation only runs once
self.metrics = torchmetrics.MetricCollection(
{
"map_50": map["map_50"],
"map_75": map["map_75"],
"map_small": map["map_small"],
},
prefix="val_",
)

def validation_step(self, batch, batch_idx):
x, y = batch
preds = self(x)
...
self.metrics.update(preds, y)
self.log_dict(self.metrics, on_step=False, on_epoch=True)

* Logging a ``MetricCollection`` object directly using ``self.log_dict`` is only supported if all metrics in the
  collection return a scalar tensor. If any metric in the collection returns a non-scalar tensor,
  the logging will fail. This can especially happen when either nesting multiple ``MetricCollection`` objects or when
13 changes: 7 additions & 6 deletions src/torchmetrics/collections.py
@@ -282,7 +282,6 @@ def _merge_compute_groups(self) -> None:

             metric1 = getattr(self, cg_members1[0])
             metric2 = getattr(self, cg_members2[0])
-
             if self._equal_metric_states(metric1, metric2):
                 self._groups[cg_idx1].extend(self._groups.pop(cg_idx2))
                 break
@@ -306,15 +305,17 @@ def _merge_compute_groups(self) -> None:
     def _equal_metric_states(metric1: Metric, metric2: Metric) -> bool:
         """Check if the metric state of two metrics are the same."""
         # empty state
-        if len(metric1._defaults) == 0 or len(metric2._defaults) == 0:
+        metric1_state = metric1.metric_state
+        metric2_state = metric2.metric_state
+        if len(metric1_state) == 0 or len(metric2_state) == 0:
             return False

-        if metric1._defaults.keys() != metric2._defaults.keys():
+        if metric1_state.keys() != metric2_state.keys():
             return False

-        for key in metric1._defaults:
-            state1 = getattr(metric1, key)
-            state2 = getattr(metric2, key)
+        for key in metric1_state:
+            state1 = metric1_state[key]
+            state2 = metric2_state[key]

             if type(state1) != type(state2):  # noqa: E721
                 return False
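
Editorial note (not part of the PR): the switch from the private `_defaults` dict to the public `metric_state` property is what allows compositional metrics to join compute groups. A minimal sketch of the difference, assuming the `metric_state` property that this PR adds to `CompositionalMetric`:

    # A CompositionalMetric created by indexing registers no states of its own,
    # so its `_defaults` dict is empty and the old check always returned False.
    import torchmetrics

    base = torchmetrics.detection.MeanAveragePrecision(iou_type="bbox")
    indexed = base["map_50"]  # CompositionalMetric wrapping `base`

    print(len(indexed._defaults))       # 0 -> old code saw "empty state", never merged
    print(indexed.metric_state.keys())  # with this PR: delegates to `base`'s state keys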
10 changes: 10 additions & 0 deletions src/torchmetrics/metric.py
@@ -1224,6 +1224,16 @@ def __init__(
else:
self.metric_b = metric_b

@property
def metric_state(self) -> dict[str, Union[List[Tensor], Tensor]]:
"""Return the metric state of the compositional metric."""
state = {}
if isinstance(self.metric_a, Metric):
state.update(self.metric_a.metric_state)
if isinstance(self.metric_b, Metric):
state.update(self.metric_b.metric_state)
return state

def _sync_dist(self, dist_sync_fn: Optional[Callable] = None, process_group: Optional[Any] = None) -> None:
"""No syncing required here.

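Editorial note (not part of the PR): a short usage sketch of the new property. When both operands are full Metric objects, the property returns the union of their state dicts (duplicate keys collapse), which is what `MetricCollection._equal_metric_states` above compares when forming compute groups:

    import torch
    from torchmetrics.classification import BinaryAccuracy

    # metric arithmetic produces a CompositionalMetric; update() is forwarded
    # to both operands, and the sketched property merges their state dicts
    acc_sum = BinaryAccuracy() + BinaryAccuracy()
    acc_sum.update(torch.tensor([1, 0, 1]), torch.tensor([1, 1, 1]))

    print(acc_sum.metric_state)  # e.g. {'tp': tensor(2), 'fp': tensor(0), ...}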
55 changes: 55 additions & 0 deletions tests/unittests/bases/test_collections.py
@@ -33,6 +33,7 @@
MultilabelAUROC,
MultilabelAveragePrecision,
)
from torchmetrics.detection import MeanAveragePrecision
from torchmetrics.regression import PearsonCorrCoef
from torchmetrics.text import BLEUScore
from torchmetrics.utilities.checks import _allclose_recursive
@@ -855,3 +856,57 @@ def test_collection_state_being_re_established_after_copy():
assert not m12._state_is_copy
assert m12.m1.mean_x.data_ptr() == m12.m2.mean_x.data_ptr(), "States should point to the same location"
assert m12._equal_metric_states(m12.m1, m12.m2)


def test_indexed_metric_in_collection():
"""Test that a metric can be indexed and recombined in a MetricCollection."""
preds = [
{
"boxes": torch.tensor([[258.0, 41.0, 606.0, 285.0]]),
"scores": torch.tensor([0.536]),
"labels": torch.tensor([0]),
}
]
target = [
{
"boxes": torch.tensor([[214.0, 41.0, 562.0, 285.0]]),
"labels": torch.tensor([0]),
}
]
metric = MeanAveragePrecision(iou_type="bbox")
metric1 = metric["map"]
metric2 = metric["map_50"]
state1 = metric1.metric_state
state2 = metric2.metric_state
assert isinstance(state1, dict)
assert isinstance(state2, dict)

# Create a collection with the indexed metrics
collection = MetricCollection({"mAP": metric1, "mAP_50": metric2})
collection.update(preds, target)

# Compute and verify results
results = collection.compute()
assert "mAP" in results
assert "mAP_50" in results
assert isinstance(results["mAP"], torch.Tensor)
assert isinstance(results["mAP_50"], torch.Tensor)

assert len(collection.compute_groups) == 1, (
f"Expected 1 compute group for indexed metrics from same base, got {len(collection.compute_groups)}"
)
assert set(collection.compute_groups[0]) == {"mAP", "mAP_50"}, (
"Both indexed metrics should be in the same compute group"
)
assert metric1.metric_a is metric2.metric_a, "Both indexed metrics should share the same base metric"

# Check that the states are equal
for key in state1:
if key in state2:
if isinstance(state1[key], list):
assert len(state1[key]) == len(state2[key]), f"State list length mismatch for key {key}"
for s1, s2 in zip(state1[key], state2[key]):
if isinstance(s1, torch.Tensor) and isinstance(s2, torch.Tensor):
assert torch.equal(s1, s2), f"State mismatch for key {key}"
elif isinstance(state1[key], torch.Tensor) and isinstance(state2[key], torch.Tensor):
assert torch.equal(state1[key], state2[key]), f"State mismatch for key {key}"
45 changes: 45 additions & 0 deletions tests/unittests/bases/test_composition.py
@@ -579,3 +579,48 @@ def test_compositional_metrics_update():

assert compos.metric_a._num_updates == 3
assert compos.metric_b._num_updates == 3


def test_compositional_metric_state_property():
"""Test that metric_state property works correctly for CompositionalMetric."""
metric_a = DummyMetric(5)
metric_b = DummyMetric(3)
compos_binary = metric_a + metric_b

assert isinstance(compos_binary, CompositionalMetric)
state = compos_binary.metric_state
assert isinstance(state, dict)
compos_binary.update()
state_after_update = compos_binary.metric_state
assert torch.equal(state_after_update["_num_updates"], tensor(1))

metric_c = DummyMetric(10)
compos_unary = abs(metric_c)

assert isinstance(compos_unary, CompositionalMetric)
assert compos_unary.metric_b is None
state_unary = compos_unary.metric_state
assert isinstance(state_unary, dict)
assert "_num_updates" in state_unary
compos_unary.update()
state_unary_after_update = compos_unary.metric_state
assert torch.equal(state_unary_after_update["_num_updates"], tensor(1))

metric_d = DummyMetric([1, 2, 3])
compos_getitem = metric_d[1]
assert isinstance(compos_getitem, CompositionalMetric)
assert compos_getitem.metric_b is None
state_getitem = compos_getitem.metric_state
assert isinstance(state_getitem, dict)
assert "_num_updates" in state_getitem

metric_e = DummyMetric(5)
compos_scalar = metric_e + 10
assert isinstance(compos_scalar, CompositionalMetric)
assert not isinstance(compos_scalar.metric_b, Metric)
state_scalar = compos_scalar.metric_state
assert isinstance(state_scalar, dict)
assert "_num_updates" in state_scalar
compos_scalar.update()
state_scalar_after = compos_scalar.metric_state
assert torch.equal(state_scalar_after["_num_updates"], tensor(1))