@@ -1087,22 +1087,16 @@ def test_precision_selection_model_parallel(precision, raises, mps_count_0):
1087
1087
1088
1088
1089
1089
@RunIf (mps = True )
1090
+ @pytest .mark .parametrize (
1091
+ ("accelerator" , "expected_device" ),
1092
+ [
1093
+ ("mps" , "mps" ),
1094
+ ("cpu" , "cpu" ),
1095
+ ],
1096
+ )
1090
1097
@pytest .mark .parametrize ("precision" , ["16-mixed" , "bf16-mixed" ])
1091
- def test_mps_amp_device_selection (precision ):
1098
+ def test_mps_amp_device_selection (accelerator , expected_device , precision ):
1092
1099
"""Test that MPS accelerator with mixed precision correctly sets device to 'mps' instead of 'cuda'."""
1093
- connector = _AcceleratorConnector (accelerator = "mps" , precision = precision )
1100
+ connector = _AcceleratorConnector (accelerator = accelerator , precision = precision )
1094
1101
assert isinstance (connector .precision_plugin , MixedPrecision )
1095
- # Verify the device parameter is set to "mps" for MPS accelerator
1096
- assert connector .precision_plugin .device == "mps"
1097
-
1098
-
1099
- @RunIf (mps = True )
1100
- def test_mps_amp_device_selection_vs_cpu ():
1101
- """Test that MPS AMP device selection differs from CPU AMP device selection."""
1102
- # MPS with mixed precision should use "mps" device
1103
- mps_connector = _AcceleratorConnector (accelerator = "mps" , precision = "bf16-mixed" )
1104
- assert mps_connector .precision_plugin .device == "mps"
1105
-
1106
- # CPU with mixed precision should use "cpu" device
1107
- cpu_connector = _AcceleratorConnector (accelerator = "cpu" , precision = "bf16-mixed" )
1108
- assert cpu_connector .precision_plugin .device == "cpu"
1102
+ assert connector .precision_plugin .device == expected_device
0 commit comments