diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index b3ed00611c021..863a3a4a7e939 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -30,6 +30,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Default to `RichProgressBar` and `RichModelSummary` if the rich package is available. Fallback to TQDMProgressBar and ModelSummary otherwise ([#20896](https://github.com/Lightning-AI/pytorch-lightning/pull/20896))
 
+- Add MPS accelerator support for mixed precision ([#21209](https://github.com/Lightning-AI/pytorch-lightning/pull/21209))
+
+
 ### Removed
 
 -
diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py
index 7f44de0589938..f42af42edf42f 100644
--- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py
+++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py
@@ -515,7 +515,7 @@ def _check_and_init_precision(self) -> Precision:
             rank_zero_info(
                 f"Using {'16bit' if self._precision_flag == '16-mixed' else 'bfloat16'} Automatic Mixed Precision (AMP)"
             )
-            device = "cpu" if self._accelerator_flag == "cpu" else "cuda"
+            device = self._accelerator_flag if self._accelerator_flag in ("cpu", "mps") else "cuda"
             return MixedPrecision(self._precision_flag, device)  # type: ignore[arg-type]
 
         raise RuntimeError("No precision set")
diff --git a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py
index f3d98cf444c36..22641244f2e94 100644
--- a/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py
+++ b/tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py
@@ -1084,3 +1084,13 @@ def test_precision_selection_model_parallel(precision, raises, mps_count_0):
     error_context = pytest.raises(ValueError, match=f"does not support .*{precision}") if raises else nullcontext()
     with error_context:
         _AcceleratorConnector(precision=precision, strategy=ModelParallelStrategy())
+
+
+@RunIf(mps=True)
+@pytest.mark.parametrize("accelerator", ["mps", "cpu"])
+@pytest.mark.parametrize("precision", ["16-mixed", "bf16-mixed"])
+def test_mps_amp_device_selection(accelerator, precision):
+    """Test that CPU and MPS accelerators with mixed precision set the autocast device to the accelerator, not 'cuda'."""
+    connector = _AcceleratorConnector(accelerator=accelerator, precision=precision)
+    assert isinstance(connector.precision_plugin, MixedPrecision)
+    assert connector.precision_plugin.device == accelerator
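
For context, a minimal sketch of how the changed branch in `_check_and_init_precision` is exercised from user code. `Trainer(accelerator=..., precision=...)` and the `precision_plugin.device` attribute are the surfaces shown in the diff and its test; the guard on `torch.backends.mps.is_available()` is an assumption about the runtime environment (an Apple Silicon machine), not part of the PR.

```python
import torch
from lightning.pytorch import Trainer

# Usage sketch (not part of the diff): with this change, requesting mixed
# precision on Apple Silicon routes autocast to the "mps" device instead of
# unconditionally falling back to "cuda".
if torch.backends.mps.is_available():
    trainer = Trainer(accelerator="mps", precision="16-mixed")
    # The accelerator connector now builds MixedPrecision("16-mixed", "mps").
    assert trainer.strategy.precision_plugin.device == "mps"
```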