From d0ee5ca16b44ce0ea002830a431b49f67665d74b Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Sat, 4 Oct 2025 20:47:47 +0000
Subject: [PATCH 1/3] Initial plan


From 6a098868eb383222b39b3d2c74ebeb5e06b2341f Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Sat, 4 Oct 2025 21:01:15 +0000
Subject: [PATCH 2/3] Fix race condition in _safe_divide by creating tensor directly on device

Co-authored-by: Borda <6035284+Borda@users.noreply.github.com>
---
 src/torchmetrics/utilities/compute.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/src/torchmetrics/utilities/compute.py b/src/torchmetrics/utilities/compute.py
index 647cfb8371a..9a51f420139 100644
--- a/src/torchmetrics/utilities/compute.py
+++ b/src/torchmetrics/utilities/compute.py
@@ -74,10 +74,7 @@ def _safe_divide(
         if zero_division == "warn" and torch.any(denom == 0):
             rank_zero_warn("Detected zero division in _safe_divide. Setting 0/0 to 0.0")
         zero_division = 0.0 if zero_division == "warn" else zero_division
-        # MPS does not support non blocking
-        zero_division_tensor = torch.tensor(zero_division, dtype=num.dtype).to(
-            num.device, non_blocking=num.device.type != "mps"
-        )
+        zero_division_tensor = torch.tensor(zero_division, dtype=num.dtype, device=num.device)
         return torch.where(denom != 0, num / denom, zero_division_tensor)
 
     return torch.true_divide(num, denom)


From 7b24786b7cc534905fd08dbcaee41fc7159ca23d Mon Sep 17 00:00:00 2001
From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com>
Date: Sat, 4 Oct 2025 21:03:54 +0000
Subject: [PATCH 3/3] Add test for _safe_divide to verify race condition fix

Co-authored-by: Borda <6035284+Borda@users.noreply.github.com>
---
 tests/unittests/utilities/test_utilities.py | 26 +++++++++++++++++++++
 1 file changed, 26 insertions(+)

diff --git a/tests/unittests/utilities/test_utilities.py b/tests/unittests/utilities/test_utilities.py
index 226edf1b724..4b90f4cfddf 100644
--- a/tests/unittests/utilities/test_utilities.py
+++ b/tests/unittests/utilities/test_utilities.py
@@ -242,6 +242,32 @@ def test_half_precision_top_k_cpu_raises_error():
         torch.topk(x, k=3, dim=1)
 
 
+def test_safe_divide():
+    """Test that _safe_divide works correctly and doesn't have race conditions."""
+    from torchmetrics.utilities.compute import _safe_divide
+
+    # Test basic functionality
+    num = torch.tensor([1.0, 2.0, 3.0])
+    denom = torch.tensor([0.0, 1.0, 2.0])
+    result = _safe_divide(num, denom)
+    expected = torch.tensor([0.0, 2.0, 1.5])
+    assert torch.allclose(result, expected)
+
+    # Test custom zero_division value
+    result = _safe_divide(num, denom, zero_division=99.0)
+    expected_custom = torch.tensor([99.0, 2.0, 1.5])
+    assert torch.allclose(result, expected_custom)
+
+    # Test that result is on the same device as input
+    for device in ["cpu"] + (["cuda"] if torch.cuda.is_available() else []):
+        num_dev = torch.tensor([1.0, 2.0, 3.0], device=device)
+        denom_dev = torch.tensor([0.0, 1.0, 2.0], device=device)
+        result = _safe_divide(num_dev, denom_dev)
+        assert result.device == num_dev.device, f"Result not on correct device: {result.device}"
+        expected_dev = torch.tensor([0.0, 2.0, 1.5], device=device)
+        assert torch.allclose(result, expected_dev)
+
+
 def find_free_port(start=START_PORT, end=MAX_PORT):
     """Returns an available localhost port in the given range or returns -1 if no port available."""
     for port in range(start, end + 1):
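
Note (not part of the patches above): a minimal, standalone sketch of the pattern PATCH 2/3 adopts, building the fill value directly on the target device via torch.tensor(..., device=...) rather than creating it on the CPU and moving it with a non-blocking .to() call, whose asynchronous copy appears to be the race the commit message refers to. The safe_divide helper below is an illustrative re-implementation written for this note, not the torchmetrics source.

    import torch


    def safe_divide(num: torch.Tensor, denom: torch.Tensor, zero_division: float = 0.0) -> torch.Tensor:
        """Divide num by denom elementwise, filling positions where denom == 0 with zero_division.

        The fill tensor is created directly on num.device, so no host-to-device copy
        (and no non_blocking transfer that could still be in flight when torch.where
        reads it) is involved.
        """
        fill = torch.tensor(zero_division, dtype=num.dtype, device=num.device)
        return torch.where(denom != 0, num / denom, fill)


    if __name__ == "__main__":
        num = torch.tensor([1.0, 2.0, 3.0])
        denom = torch.tensor([0.0, 1.0, 2.0])
        print(safe_divide(num, denom))        # [0.0, 2.0, 1.5]
        print(safe_divide(num, denom, 99.0))  # [99.0, 2.0, 1.5]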