5 changes: 1 addition & 4 deletions src/torchmetrics/utilities/compute.py
@@ -74,10 +74,7 @@ def _safe_divide(
         if zero_division == "warn" and torch.any(denom == 0):
             rank_zero_warn("Detected zero division in _safe_divide. Setting 0/0 to 0.0")
         zero_division = 0.0 if zero_division == "warn" else zero_division
-        # MPS does not support non blocking
-        zero_division_tensor = torch.tensor(zero_division, dtype=num.dtype).to(
-            num.device, non_blocking=num.device.type != "mps"
-        )
+        zero_division_tensor = torch.tensor(zero_division, dtype=num.dtype, device=num.device)
         return torch.where(denom != 0, num / denom, zero_division_tensor)
     return torch.true_divide(num, denom)

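Note on the change: the old code built the zero-division fill value on the host and then moved it with `.to(num.device, non_blocking=...)`, with the non-blocking copy disabled only on MPS. The new code creates the tensor directly with the numerator's dtype and device, so there is no explicit non-blocking copy that could race with the subsequent `torch.where`. A minimal sketch of the simplified pattern, assuming only `torch` is installed; the function name and values are illustrative, not the library's API:

import torch

def safe_divide_sketch(num: torch.Tensor, denom: torch.Tensor, zero_division: float = 0.0) -> torch.Tensor:
    # Create the fill value directly with the numerator's dtype and device,
    # instead of allocating on the host and issuing a separate copy.
    fill = torch.tensor(zero_division, dtype=num.dtype, device=num.device)
    # Divide where the denominator is non-zero; substitute the fill value elsewhere.
    return torch.where(denom != 0, num / denom, fill)

num = torch.tensor([1.0, 2.0, 3.0])
denom = torch.tensor([0.0, 1.0, 2.0])
print(safe_divide_sketch(num, denom))  # tensor([0.0000, 2.0000, 1.5000])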
26 changes: 26 additions & 0 deletions tests/unittests/utilities/test_utilities.py
@@ -242,6 +242,32 @@ def test_half_precision_top_k_cpu_raises_error():
         torch.topk(x, k=3, dim=1)
 
 
+def test_safe_divide():
+    """Test that _safe_divide works correctly and doesn't have race conditions."""
+    from torchmetrics.utilities.compute import _safe_divide
+
+    # Test basic functionality
+    num = torch.tensor([1.0, 2.0, 3.0])
+    denom = torch.tensor([0.0, 1.0, 2.0])
+    result = _safe_divide(num, denom)
+    expected = torch.tensor([0.0, 2.0, 1.5])
+    assert torch.allclose(result, expected)
+
+    # Test custom zero_division value
+    result = _safe_divide(num, denom, zero_division=99.0)
+    expected_custom = torch.tensor([99.0, 2.0, 1.5])
+    assert torch.allclose(result, expected_custom)
+
+    # Test that result is on the same device as input
+    for device in ["cpu"] + (["cuda"] if torch.cuda.is_available() else []):
+        num_dev = torch.tensor([1.0, 2.0, 3.0], device=device)
+        denom_dev = torch.tensor([0.0, 1.0, 2.0], device=device)
+        result = _safe_divide(num_dev, denom_dev)
+        # Compare against the input's device (a bare "cuda" device has no index, so
+        # comparing with torch.device(device) would fail on CUDA where results live on cuda:0).
+        assert result.device == num_dev.device, f"Result not on correct device: {result.device}"
+        expected_dev = torch.tensor([0.0, 2.0, 1.5], device=device)
+        assert torch.allclose(result, expected_dev)
+
+
 def find_free_port(start=START_PORT, end=MAX_PORT):
     """Returns an available localhost port in the given range or returns -1 if no port available."""
     for port in range(start, end + 1):
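The new test can be run on its own, for example with: pytest tests/unittests/utilities/test_utilities.py::test_safe_divide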