@@ -4639,6 +4639,39 @@ def test_index_select_stacked_not_supported(self):
46394639 torch .index_select (stacked_spec , dim = 0 , index = torch .tensor ([0 ]))
46404640
46414641
4642+ def _has_mps ():
4643+ if hasattr (torch , "mps" ) and hasattr (torch .mps , "is_available" ):
4644+ return torch .mps .is_available ()
4645+ return (
4646+ getattr (torch .backends , "mps" , None ) is not None
4647+ and torch .backends .mps .is_available ()
4648+ )
4649+
4650+
@pytest.mark.skipif(not _has_mps(), reason="MPS device not available")
class TestMPSDtype:
    """Checks that dtypes MPS cannot represent (float64) are downcast to float32 in tensor specs."""

    def test_mps_does_not_support_float64(self):
        """Canary: MPS still lacks float64 support.

        Should this test ever fail, MPS has gained float64 support and the
        downcasts can be removed (e.g., in _default_dtype_and_device)
        """
        # Allocating a float64 tensor directly on MPS must raise.
        with pytest.raises(TypeError, match="MPS framework doesn't support float64"):
            torch.ones(2, dtype=torch.float64, device="mps")

    def test_unbounded_to_mps_downcasts_float64(self):
        """Moving an Unbounded float64 spec to 'mps' yields float32."""
        cpu_spec = Unbounded(shape=(6,), device="cpu", dtype=torch.float64)
        assert cpu_spec.dtype == torch.float64

        # The conversion warns about the unsupported dtype before downcasting.
        with pytest.warns(UserWarning, match="MPS device does not support float64"):
            mps_spec = cpu_spec.to("mps")
        assert mps_spec.dtype == torch.float32
        assert mps_spec.device.type == "mps"
4674+
if __name__ == "__main__":
    # Allow running this test file directly; unrecognized CLI flags are
    # forwarded to pytest untouched.
    args, unknown = argparse.ArgumentParser().parse_known_args()
    pytest.main([__file__, "--capture", "no", "--exitfirst", *unknown])
0 commit comments