Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions test/test_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4639,6 +4639,30 @@ def test_index_select_stacked_not_supported(self):
torch.index_select(stacked_spec, dim=0, index=torch.tensor([0]))


def _has_mps():
if hasattr(torch, "mps") and hasattr(torch.mps, "is_available"):
return torch.mps.is_available()
return (
getattr(torch.backends, "mps", None) is not None
and torch.backends.mps.is_available()
)


@pytest.mark.skipif(not _has_mps(), reason="MPS device not available")
class TestMPSDtype:
"""Tests that MPS-incompatible dtypes (float64) are downcast to float32 in tensor specs."""

def test_unbounded_to_mps_downcasts_float64(self):
"""Unbounded.to('mps') downcasts float64 -> float32."""
spec_cpu = Unbounded(shape=(6,), device="cpu", dtype=torch.float64)
assert spec_cpu.dtype == torch.float64

with pytest.warns(UserWarning, match="MPS device does not support float64"):
spec_mps = spec_cpu.to("mps")
assert spec_mps.dtype == torch.float32
assert spec_mps.device.type == "mps"


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
9 changes: 9 additions & 0 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,15 @@ def _default_dtype_and_device(
device = _make_ordinal_device(torch.device(device))
elif not allow_none_device:
device = torch.zeros(()).device

if device is not None and device.type == "mps" and dtype == torch.float64:
warnings.warn(
"MPS device does not support float64. Downcasting dtype from float64 to float32.",
UserWarning,
stacklevel=2,
)
dtype = torch.float32

return dtype, device


Expand Down