MulticlassPrecisionRecallCurve with thresholds=[integer] does not use constant memory on MPS #3299

@Snailed

Description

🐛 Bug

According to the MulticlassPrecisionRecallCurve documentation, setting the thresholds argument to an integer should make the metric use a constant amount of memory, independent of the number of samples.
When the metric is on the MPS backend, this is not the case.
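For reference, with thresholds=200 the metric is only supposed to accumulate a single confusion-matrix state of shape (thresholds, num_classes, 2, 2): the confmat tensor that the traceback further down updates via self.confmat += state. A quick sanity check on CPU (just a sketch; it assumes the internal state is exposed as metric.confmat, as the traceback suggests):

import torch
import torchmetrics

metric = torchmetrics.classification.MulticlassPrecisionRecallCurve(num_classes=3, ignore_index=-1, thresholds=200)
metric.update(torch.randn(100, 3), torch.randint(0, 3, (100,)))
# The state stays at a fixed size no matter how many samples have been seen
print(metric.confmat.shape)  # expected: torch.Size([200, 3, 2, 2])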

To Reproduce

The snippet below runs in constant memory on CPU but crashes with an out-of-memory error on MPS.

Code sample
import torch
import torchmetrics

# Correct behavior: constant memory on CPU
metric = torchmetrics.classification.MulticlassPrecisionRecallCurve(num_classes=3, ignore_index=-1, thresholds=200).to('cpu')
pred = torch.randn(60000, 3).to('cpu')
target = torch.randint(0, 3, (60000,)).to('cpu')
metric = metric(pred, target)
print(metric)


# Crashes with an out-of-memory error on MPS
metric = torchmetrics.classification.MulticlassPrecisionRecallCurve(num_classes=3, ignore_index=-1, thresholds=200).to('mps')
pred = torch.randn(60000, 3).to('mps')
target = torch.randint(0, 3, (60000,)).to('mps')
metric = metric(pred, target)
print(metric)
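
As a stopgap (just a sketch, not a proper fix), running the metric update on CPU while the rest of the pipeline stays on MPS avoids the huge allocation, at the cost of a device transfer per step:

import torch
import torchmetrics

# Hypothetical workaround: keep the metric on CPU and move the inputs over for the update
metric = torchmetrics.classification.MulticlassPrecisionRecallCurve(num_classes=3, ignore_index=-1, thresholds=200)
pred = torch.randn(60000, 3).to('mps')
target = torch.randint(0, 3, (60000,)).to('mps')
result = metric(pred.cpu(), target.cpu())  # constant memory, but adds a device copy per batch
print(result)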
Environment
  • TorchMetrics version (if built from source, add commit SHA): 1.8.1
  • Python & PyTorch Version (e.g., 1.0): Python 3.11.7
  • Any other relevant information such as OS (e.g., Linux): macOS Tahoe 26.0.1 on a MacBook Pro M4 Max

Additional context

Here is the error I get when I run the code on MPS. The metric clearly does not use constant memory: the amount it attempts to allocate changes when I vary the input size.

Traceback
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[15], line 5
      3 target = torch.randint(0, 3, (60000,)).to('mps')
      4 print( pred.shape, target.shape)
----> 5 metric = metric(pred, target)
      6 metric

File ~/Programs/marcos/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1773, in Module._wrapped_call_impl(self, *args, **kwargs)
   1771     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1772 else:
-> 1773     return self._call_impl(*args, **kwargs)

File ~/Programs/marcos/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1784, in Module._call_impl(self, *args, **kwargs)
   1779 # If we don't have any hooks, we want to skip the rest of the logic in
   1780 # this function, and just call forward.
   1781 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1782         or _global_backward_pre_hooks or _global_backward_hooks
   1783         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1784     return forward_call(*args, **kwargs)
   1786 result = None
   1787 called_always_called_hooks = set()

File ~/Programs/marcos/.venv/lib/python3.11/site-packages/torchmetrics/metric.py:315, in Metric.forward(self, *args, **kwargs)
    313     self._forward_cache = self._forward_full_state_update(*args, **kwargs)
    314 else:
--> 315     self._forward_cache = self._forward_reduce_state_update(*args, **kwargs)
    317 return self._forward_cache

File ~/Programs/marcos/.venv/lib/python3.11/site-packages/torchmetrics/metric.py:384, in Metric._forward_reduce_state_update(self, *args, **kwargs)
    381 self._enable_grad = True  # allow grads for batch computation
    383 # calculate batch state and compute batch value
--> 384 self.update(*args, **kwargs)
    385 batch_val = self.compute()
    387 # reduce batch and global state

File ~/Programs/marcos/.venv/lib/python3.11/site-packages/torchmetrics/metric.py:559, in Metric._wrap_update.<locals>.wrapped_func(*args, **kwargs)
    551         if "Expected all tensors to be on" in str(err):
    552             raise RuntimeError(
    553                 "Encountered different devices in metric calculation (see stacktrace for details)."
    554                 " This could be due to the metric class not being on the same device as input."
   (...)    557                 " device corresponds to the device of the input."
    558             ) from err
--> 559         raise err
    561 if self.compute_on_cpu:
    562     self._move_list_states_to_cpu()

File ~/Programs/marcos/.venv/lib/python3.11/site-packages/torchmetrics/metric.py:549, in Metric._wrap_update.<locals>.wrapped_func(*args, **kwargs)
    547 with torch.set_grad_enabled(self._enable_grad):
    548     try:
--> 549         update(*args, **kwargs)
    550     except RuntimeError as err:
    551         if "Expected all tensors to be on" in str(err):

File ~/Programs/marcos/.venv/lib/python3.11/site-packages/torchmetrics/classification/precision_recall_curve.py:379, in MulticlassPrecisionRecallCurve.update(self, preds, target)
    375     _multiclass_precision_recall_curve_tensor_validation(preds, target, self.num_classes, self.ignore_index)
    376 preds, target, _ = _multiclass_precision_recall_curve_format(
    377     preds, target, self.num_classes, self.thresholds, self.ignore_index, self.average
    378 )
--> 379 state = _multiclass_precision_recall_curve_update(
    380     preds, target, self.num_classes, self.thresholds, self.average
    381 )
    382 if isinstance(state, Tensor):
    383     self.confmat += state

File ~/Programs/marcos/.venv/lib/python3.11/site-packages/torchmetrics/functional/classification/precision_recall_curve.py:486, in _multiclass_precision_recall_curve_update(preds, target, num_classes, thresholds, average)
    484 else:
    485     update_fn = _multiclass_precision_recall_curve_update_loop
--> 486 return update_fn(preds, target, num_classes, thresholds)

File ~/Programs/marcos/.venv/lib/python3.11/site-packages/torchmetrics/functional/classification/precision_recall_curve.py:507, in _multiclass_precision_recall_curve_update_vectorized(preds, target, num_classes, thresholds)
    505 unique_mapping += 4 * torch.arange(num_classes, device=preds.device).unsqueeze(0).unsqueeze(-1)
    506 unique_mapping += 4 * num_classes * torch.arange(len_t, device=preds.device)
--> 507 bins = _bincount(unique_mapping.flatten(), minlength=4 * num_classes * len_t)
    508 return bins.reshape(len_t, num_classes, 2, 2)

File ~/Programs/marcos/.venv/lib/python3.11/site-packages/torchmetrics/utilities/data.py:203, in _bincount(x, minlength)
    200     minlength = len(torch.unique(x))
    202 if torch.are_deterministic_algorithms_enabled() or _XLA_AVAILABLE or x.is_mps:
--> 203     mesh = torch.arange(minlength, device=x.device).repeat(len(x), 1)
    204     return torch.eq(x.reshape(-1, 1), mesh).sum(dim=0)
    206 return torch.bincount(x, minlength=minlength)

RuntimeError: Invalid buffer size: 643.73 GiB
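
If I read _bincount correctly, the 643.73 GiB is exactly the mesh it tries to build on MPS: a (len(x), minlength) int64 tensor, where len(x) has one entry per (sample, class, threshold) and minlength = 4 * num_classes * thresholds. The allocation therefore scales linearly with the number of samples instead of staying constant. A back-of-the-envelope check:

# Rough check of the allocation attempted by the MPS fallback in _bincount
n_samples, num_classes, len_t = 60_000, 3, 200
len_x = n_samples * num_classes * len_t   # elements in unique_mapping.flatten()
minlength = 4 * num_classes * len_t       # 4 confusion-matrix cells per (class, threshold)
mesh_bytes = len_x * minlength * 8        # int64 mesh of shape (len(x), minlength)
print(mesh_bytes / 2**30)                 # ~643.7 GiB, matching the error above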
