Skip to content

Commit fadc7b8

Browse files
committed
Add test to make sure mps still does not support float64
1 parent 471bfa5 commit fadc7b8

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

test/test_specs.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4652,6 +4652,15 @@ def _has_mps():
46524652
class TestMPSDtype:
46534653
"""Tests that MPS-incompatible dtypes (float64) are downcast to float32 in tensor specs."""
46544654

4655+
def test_mps_does_not_support_float64(self):
4656+
"""Assert that MPS still doesn't support float64.
4657+
4658+
If this test fails, MPS has gained float64 support and the downcasts
4659+
can be removed (e.g., in _default_dtype_and_device)
4660+
"""
4661+
with pytest.raises(TypeError, match="MPS framework doesn't support float64"):
4662+
torch.ones(2, dtype=torch.float64, device="mps")
4663+
46554664
def test_unbounded_to_mps_downcasts_float64(self):
46564665
"""Unbounded.to('mps') downcasts float64 -> float32."""
46574666
spec_cpu = Unbounded(shape=(6,), device="cpu", dtype=torch.float64)

0 commit comments

Comments
 (0)