Labels: bug / fix (Something isn't working), help wanted (Extra attention is needed)
Description
🐛 Bug
According to the MulticlassPrecisionRecallCurve documentation, setting the thresholds argument should cause MulticlassPrecisionRecallCurve to use a constant amount of memory.
When the metric is on MPS, this is not the case.
To Reproduce
Run the code below. On CPU the metric computes normally; moving the metric and inputs to MPS makes the same call crash with an out-of-memory error.
Code sample
import torch
import torchmetrics

# Correct behavior on CPU
metric = torchmetrics.classification.MulticlassPrecisionRecallCurve(num_classes=3, ignore_index=-1, thresholds=200).to('cpu')
pred = torch.randn(60000, 3).to('cpu')
target = torch.randint(0, 3, (60000,)).to('cpu')
metric = metric(pred, target)
print(metric)

# Crashes with an OOM error on MPS
metric = torchmetrics.classification.MulticlassPrecisionRecallCurve(num_classes=3, ignore_index=-1, thresholds=200).to('mps')
pred = torch.randn(60000, 3).to('mps')
target = torch.randint(0, 3, (60000,)).to('mps')
metric = metric(pred, target)
print(metric)
Environment
- TorchMetrics version (if built from source, add commit SHA): 1.8.1
- Python & PyTorch version (e.g., 1.0): 3.11.7
- Any other relevant information such as OS (e.g., Linux): macOS Tahoe 26.0.1 on a MacBook Pro (M4 Max)
Additional context
Here's the error I get when I run the code on MPS. The metric clearly does not use constant memory: the amount it attempts to allocate changes when I vary the input size.
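For scale: the 643.73 GiB allocation reported in the traceback below matches exactly the size of the comparison mesh that `_bincount` builds on MPS, i.e. `len(unique_mapping) * minlength` elements (a back-of-the-envelope check using the shapes from my repro, assuming 8-byte int64 elements):

```python
# Rough size of the mesh torch.arange(minlength).repeat(len(x), 1)
# built inside torchmetrics' _bincount on MPS (assuming 8 bytes per element).
n_samples = 60000
num_classes = 3
len_t = 200  # thresholds=200

x_len = len_t * num_classes * n_samples  # flattened unique_mapping length
minlength = 4 * num_classes * len_t      # 2x2 confusion cells per class/threshold

gib = x_len * minlength * 8 / 2**30
print(f"{gib:.2f} GiB")  # 643.73 GiB -- scales linearly with n_samples
```

Since the mesh size is proportional to `n_samples`, this explains why the attempted allocation grows with the input instead of staying constant.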
Traceback
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[15], line 5
3 target = torch.randint(0, 3, (60000,)).to('mps')
4 print( pred.shape, target.shape)
----> 5 metric = metric(pred, target)
6 metric
File ~/Programs/marcos/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1773, in Module._wrapped_call_impl(self, *args, **kwargs)
1771 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1772 else:
-> 1773 return self._call_impl(*args, **kwargs)
File ~/Programs/marcos/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py:1784, in Module._call_impl(self, *args, **kwargs)
1779 # If we don't have any hooks, we want to skip the rest of the logic in
1780 # this function, and just call forward.
1781 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1782 or _global_backward_pre_hooks or _global_backward_hooks
1783 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1784 return forward_call(*args, **kwargs)
1786 result = None
1787 called_always_called_hooks = set()
File ~/Programs/marcos/.venv/lib/python3.11/site-packages/torchmetrics/metric.py:315, in Metric.forward(self, *args, **kwargs)
313 self._forward_cache = self._forward_full_state_update(*args, **kwargs)
314 else:
--> 315 self._forward_cache = self._forward_reduce_state_update(*args, **kwargs)
317 return self._forward_cache
File ~/Programs/marcos/.venv/lib/python3.11/site-packages/torchmetrics/metric.py:384, in Metric._forward_reduce_state_update(self, *args, **kwargs)
381 self._enable_grad = True # allow grads for batch computation
383 # calculate batch state and compute batch value
--> 384 self.update(*args, **kwargs)
385 batch_val = self.compute()
387 # reduce batch and global state
File ~/Programs/marcos/.venv/lib/python3.11/site-packages/torchmetrics/metric.py:559, in Metric._wrap_update.<locals>.wrapped_func(*args, **kwargs)
551 if "Expected all tensors to be on" in str(err):
552 raise RuntimeError(
553 "Encountered different devices in metric calculation (see stacktrace for details)."
554 " This could be due to the metric class not being on the same device as input."
(...) 557 " device corresponds to the device of the input."
558 ) from err
--> 559 raise err
561 if self.compute_on_cpu:
562 self._move_list_states_to_cpu()
File ~/Programs/marcos/.venv/lib/python3.11/site-packages/torchmetrics/metric.py:549, in Metric._wrap_update.<locals>.wrapped_func(*args, **kwargs)
547 with torch.set_grad_enabled(self._enable_grad):
548 try:
--> 549 update(*args, **kwargs)
550 except RuntimeError as err:
551 if "Expected all tensors to be on" in str(err):
File ~/Programs/marcos/.venv/lib/python3.11/site-packages/torchmetrics/classification/precision_recall_curve.py:379, in MulticlassPrecisionRecallCurve.update(self, preds, target)
375 _multiclass_precision_recall_curve_tensor_validation(preds, target, self.num_classes, self.ignore_index)
376 preds, target, _ = _multiclass_precision_recall_curve_format(
377 preds, target, self.num_classes, self.thresholds, self.ignore_index, self.average
378 )
--> 379 state = _multiclass_precision_recall_curve_update(
380 preds, target, self.num_classes, self.thresholds, self.average
381 )
382 if isinstance(state, Tensor):
383 self.confmat += state
File ~/Programs/marcos/.venv/lib/python3.11/site-packages/torchmetrics/functional/classification/precision_recall_curve.py:486, in _multiclass_precision_recall_curve_update(preds, target, num_classes, thresholds, average)
484 else:
485 update_fn = _multiclass_precision_recall_curve_update_loop
--> 486 return update_fn(preds, target, num_classes, thresholds)
File ~/Programs/marcos/.venv/lib/python3.11/site-packages/torchmetrics/functional/classification/precision_recall_curve.py:507, in _multiclass_precision_recall_curve_update_vectorized(preds, target, num_classes, thresholds)
505 unique_mapping += 4 * torch.arange(num_classes, device=preds.device).unsqueeze(0).unsqueeze(-1)
506 unique_mapping += 4 * num_classes * torch.arange(len_t, device=preds.device)
--> 507 bins = _bincount(unique_mapping.flatten(), minlength=4 * num_classes * len_t)
508 return bins.reshape(len_t, num_classes, 2, 2)
File ~/Programs/marcos/.venv/lib/python3.11/site-packages/torchmetrics/utilities/data.py:203, in _bincount(x, minlength)
200 minlength = len(torch.unique(x))
202 if torch.are_deterministic_algorithms_enabled() or _XLA_AVAILABLE or x.is_mps:
--> 203 mesh = torch.arange(minlength, device=x.device).repeat(len(x), 1)
204 return torch.eq(x.reshape(-1, 1), mesh).sum(dim=0)
206 return torch.bincount(x, minlength=minlength)
RuntimeError: Invalid buffer size: 643.73 GiB
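One possible mitigation (a sketch only, not a tested patch against torchmetrics) is to build the comparison mesh in fixed-size chunks, so peak memory is bounded by `chunk_size * minlength` instead of `len(x) * minlength`. The hypothetical `chunked_bincount` below applies this idea to the deterministic/MPS branch of `_bincount`:

```python
import torch

def chunked_bincount(x: torch.Tensor, minlength: int, chunk_size: int = 2**16) -> torch.Tensor:
    """Deterministic bincount that compares x against the arange mesh in chunks,
    capping peak memory at roughly chunk_size * minlength elements."""
    mesh = torch.arange(minlength, device=x.device)
    out = torch.zeros(minlength, dtype=torch.long, device=x.device)
    for start in range(0, len(x), chunk_size):
        chunk = x[start:start + chunk_size]
        # (len(chunk), minlength) boolean comparison instead of (len(x), minlength)
        out += torch.eq(chunk.reshape(-1, 1), mesh).sum(dim=0)
    return out
```

With the shapes from my repro (`minlength = 2400`, `chunk_size = 2**16`), each per-chunk comparison is about 65536 × 2400 booleans, well under a GiB, independent of the number of samples.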