Skip to content

Commit b6326f7

Browse files
authored
[Bugfix] Add MPS float64->float32 downcast (#3548)
1 parent 1ed0d1e commit b6326f7

File tree

2 files changed

+42
-0
lines changed

2 files changed

+42
-0
lines changed

test/test_specs.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4639,6 +4639,39 @@ def test_index_select_stacked_not_supported(self):
46394639
torch.index_select(stacked_spec, dim=0, index=torch.tensor([0]))
46404640

46414641

4642+
def _has_mps():
4643+
if hasattr(torch, "mps") and hasattr(torch.mps, "is_available"):
4644+
return torch.mps.is_available()
4645+
return (
4646+
getattr(torch.backends, "mps", None) is not None
4647+
and torch.backends.mps.is_available()
4648+
)
4649+
4650+
4651+
@pytest.mark.skipif(not _has_mps(), reason="MPS device not available")
4652+
class TestMPSDtype:
4653+
"""Tests that MPS-incompatible dtypes (float64) are downcast to float32 in tensor specs."""
4654+
4655+
def test_mps_does_not_support_float64(self):
4656+
"""Assert that MPS still doesn't support float64.
4657+
4658+
If this test fails, MPS has gained float64 support and the downcasts
4659+
can be removed (e.g., in _default_dtype_and_device)
4660+
"""
4661+
with pytest.raises(TypeError, match="MPS framework doesn't support float64"):
4662+
torch.ones(2, dtype=torch.float64, device="mps")
4663+
4664+
def test_unbounded_to_mps_downcasts_float64(self):
4665+
"""Unbounded.to('mps') downcasts float64 -> float32."""
4666+
spec_cpu = Unbounded(shape=(6,), device="cpu", dtype=torch.float64)
4667+
assert spec_cpu.dtype == torch.float64
4668+
4669+
with pytest.warns(UserWarning, match="MPS device does not support float64"):
4670+
spec_mps = spec_cpu.to("mps")
4671+
assert spec_mps.dtype == torch.float32
4672+
assert spec_mps.device.type == "mps"
4673+
4674+
46424675
if __name__ == "__main__":
46434676
args, unknown = argparse.ArgumentParser().parse_known_args()
46444677
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

torchrl/data/tensor_specs.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,15 @@ def _default_dtype_and_device(
113113
device = _make_ordinal_device(torch.device(device))
114114
elif not allow_none_device:
115115
device = torch.zeros(()).device
116+
117+
if device is not None and device.type == "mps" and dtype == torch.float64:
118+
warnings.warn(
119+
"MPS device does not support float64. Downcasting dtype from float64 to float32.",
120+
UserWarning,
121+
stacklevel=2,
122+
)
123+
dtype = torch.float32
124+
116125
return dtype, device
117126

118127

0 commit comments

Comments
 (0)