diff --git a/src/torchmetrics/utilities/compute.py b/src/torchmetrics/utilities/compute.py
index 647cfb8371a..9a51f420139 100644
--- a/src/torchmetrics/utilities/compute.py
+++ b/src/torchmetrics/utilities/compute.py
@@ -74,10 +74,7 @@ def _safe_divide(
         if zero_division == "warn" and torch.any(denom == 0):
             rank_zero_warn("Detected zero division in _safe_divide. Setting 0/0 to 0.0")
         zero_division = 0.0 if zero_division == "warn" else zero_division
-        # MPS does not support non blocking
-        zero_division_tensor = torch.tensor(zero_division, dtype=num.dtype).to(
-            num.device, non_blocking=num.device.type != "mps"
-        )
+        zero_division_tensor = torch.tensor(zero_division, dtype=num.dtype, device=num.device)
         return torch.where(denom != 0, num / denom, zero_division_tensor)
 
     return torch.true_divide(num, denom)
diff --git a/tests/unittests/utilities/test_utilities.py b/tests/unittests/utilities/test_utilities.py
index 226edf1b724..4b90f4cfddf 100644
--- a/tests/unittests/utilities/test_utilities.py
+++ b/tests/unittests/utilities/test_utilities.py
@@ -242,6 +242,32 @@ def test_half_precision_top_k_cpu_raises_error():
         torch.topk(x, k=3, dim=1)
 
 
+def test_safe_divide():
+    """Test that _safe_divide works correctly and doesn't have race conditions."""
+    from torchmetrics.utilities.compute import _safe_divide
+
+    # Test basic functionality
+    num = torch.tensor([1.0, 2.0, 3.0])
+    denom = torch.tensor([0.0, 1.0, 2.0])
+    result = _safe_divide(num, denom)
+    expected = torch.tensor([0.0, 2.0, 1.5])
+    assert torch.allclose(result, expected)
+
+    # Test custom zero_division value
+    result = _safe_divide(num, denom, zero_division=99.0)
+    expected_custom = torch.tensor([99.0, 2.0, 1.5])
+    assert torch.allclose(result, expected_custom)
+
+    # Test that result is on the same device as input
+    for device in ["cpu"] + (["cuda"] if torch.cuda.is_available() else []):
+        num_dev = torch.tensor([1.0, 2.0, 3.0], device=device)
+        denom_dev = torch.tensor([0.0, 1.0, 2.0], device=device)
+        result = _safe_divide(num_dev, denom_dev)
+        assert result.device == torch.device(device), f"Result not on correct device: {result.device}"
+        expected_dev = torch.tensor([0.0, 2.0, 1.5], device=device)
+        assert torch.allclose(result, expected_dev)
+
+
 def find_free_port(start=START_PORT, end=MAX_PORT):
     """Returns an available localhost port in the given range or returns -1 if no port available."""
     for port in range(start, end + 1):