File tree Expand file tree Collapse file tree 1 file changed +9
-0
lines changed
Expand file tree Collapse file tree 1 file changed +9
-0
lines changed Original file line number Diff line number Diff line change @@ -4652,6 +4652,15 @@ def _has_mps():
46524652class TestMPSDtype :
46534653 """Tests that MPS-incompatible dtypes (float64) are downcast to float32 in tensor specs."""
46544654
4655+ def test_mps_does_not_support_float64 (self ):
4656+ """Assert that MPS still doesn't support float64.
4657+
4658+ If this test fails, MPS has gained float64 support and the downcasts
4659+ can be removed (e.g., in _default_dtype_and_device)
4660+ """
4661+ with pytest .raises (TypeError , match = "MPS framework doesn't support float64" ):
4662+ torch .ones (2 , dtype = torch .float64 , device = "mps" )
4663+
46554664 def test_unbounded_to_mps_downcasts_float64 (self ):
46564665 """Unbounded.to('mps') downcasts float64 -> float32."""
46574666 spec_cpu = Unbounded (shape = (6 ,), device = "cpu" , dtype = torch .float64 )
You can’t perform that action at this time.
0 commit comments