@@ -515,7 +515,7 @@ def _check_and_init_precision(self) -> Precision:
             rank_zero_info(
                 f"Using {'16bit' if self._precision_flag == '16-mixed' else 'bfloat16'} Automatic Mixed Precision (AMP)"
             )
-            device = "cpu" if self._accelerator_flag == "cpu" else "cuda"
+            device = self._accelerator_flag if self._accelerator_flag in ("cpu", "mps") else "cuda"
             return MixedPrecision(self._precision_flag, device)  # type: ignore[arg-type]

         raise RuntimeError("No precision set")
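
The change above stops hard-coding the AMP device to "cuda" for every non-CPU accelerator: the "mps" flag is now passed through, so MixedPrecision is built for Apple's Metal backend rather than a CUDA device that may not exist. A minimal standalone sketch of the selection rule (the helper name _amp_device is illustrative only, not part of Lightning):

def _amp_device(accelerator_flag: str) -> str:
    # "cpu" and "mps" map to themselves; any other flag (e.g. "gpu"/"cuda") falls back to "cuda".
    return accelerator_flag if accelerator_flag in ("cpu", "mps") else "cuda"

assert _amp_device("mps") == "mps"   # the old line resolved this case to "cuda"
assert _amp_device("cpu") == "cpu"
assert _amp_device("gpu") == "cuda"
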
@@ -1084,3 +1084,19 @@ def test_precision_selection_model_parallel(precision, raises, mps_count_0):
     error_context = pytest.raises(ValueError, match=f"does not support .*{precision}") if raises else nullcontext()
     with error_context:
         _AcceleratorConnector(precision=precision, strategy=ModelParallelStrategy())
+
+
+@RunIf(mps=True)
+@pytest.mark.parametrize(
+    ("accelerator", "expected_device"),
+    [
+        ("mps", "mps"),
+        ("cpu", "cpu"),
+    ],
+)
+@pytest.mark.parametrize("precision", ["16-mixed", "bf16-mixed"])
+def test_mps_amp_device_selection(accelerator, expected_device, precision):
+    """Test that MPS accelerator with mixed precision correctly sets device to 'mps' instead of 'cuda'."""
+    connector = _AcceleratorConnector(accelerator=accelerator, precision=precision)
+    assert isinstance(connector.precision_plugin, MixedPrecision)
+    assert connector.precision_plugin.device == expected_device
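
The new test exercises both the MPS and CPU accelerator flags under both mixed-precision modes. For context, the end-user effect on an MPS-capable machine would look roughly like the sketch below; this assumes Lightning's public Trainer API and the lightning.pytorch import path, and is not part of the diff:

from lightning.pytorch import Trainer

# With the fix, requesting mixed precision together with the MPS accelerator yields a
# MixedPrecision plugin whose device is "mps" instead of "cuda".
trainer = Trainer(accelerator="mps", devices=1, precision="16-mixed")
assert trainer.precision_plugin.device == "mps"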