
Commit 9d40bc8

Authored by bhimrazy, Borda, and SkafteNicki

Add MPS accelerator support for mixed precision (#21209)

* Update device assignment logic to support 'mps' accelerator
* Add tests for MPS accelerator mixed precision device selection
* Refactor MPS mixed precision device selection tests for parameterized input
* chlog
* Apply suggestions (Co-authored-by: Nicki Skafte Detlefsen <[email protected]>)
* apply suggestions
* Empty-Commit
* Empty Commit
* Empty-Commit

Co-authored-by: jirka <[email protected]>
Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
1 parent e088694 commit 9d40bc8

File tree

3 files changed: +14 -1 lines changed

src/lightning/pytorch/CHANGELOG.md (3 additions, 0 deletions)

@@ -30,6 +30,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Default to `RichProgressBar` and `RichModelSummary` if the rich package is available. Fallback to TQDMProgressBar and ModelSummary otherwise ([#20896](https://github.com/Lightning-AI/pytorch-lightning/pull/20896))
 
 
+- Add MPS accelerator support for mixed precision ([#21209](https://github.com/Lightning-AI/pytorch-lightning/pull/21209))
+
+
 ### Removed
 
 -

src/lightning/pytorch/trainer/connectors/accelerator_connector.py (1 addition, 1 deletion)

@@ -515,7 +515,7 @@ def _check_and_init_precision(self) -> Precision:
             rank_zero_info(
                 f"Using {'16bit' if self._precision_flag == '16-mixed' else 'bfloat16'} Automatic Mixed Precision (AMP)"
             )
-            device = "cpu" if self._accelerator_flag == "cpu" else "cuda"
+            device = self._accelerator_flag if self._accelerator_flag in ("cpu", "mps") else "cuda"
             return MixedPrecision(self._precision_flag, device)  # type: ignore[arg-type]
 
         raise RuntimeError("No precision set")
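The one-line change above can be sketched in isolation. This is a minimal illustration, not Lightning's actual API: `select_amp_device` is a hypothetical helper name, since in the real code the logic is a single inline expression inside `_check_and_init_precision`.

```python
def select_amp_device(accelerator_flag: str) -> str:
    """Sketch of the device-selection logic changed in this commit.

    Before the commit, "cpu" stayed "cpu" and every other accelerator flag
    was mapped to "cuda", which broke AMP on Apple Silicon's "mps" backend.
    After the commit, "mps" (like "cpu") is passed through unchanged.
    """
    return accelerator_flag if accelerator_flag in ("cpu", "mps") else "cuda"


print(select_amp_device("mps"))  # mps
print(select_amp_device("cpu"))  # cpu
print(select_amp_device("gpu"))  # cuda
```

The passed-through value becomes the `device` argument of the `MixedPrecision` plugin, which is why the new test below can assert `connector.precision_plugin.device == accelerator`.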

tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py (10 additions, 0 deletions)

@@ -1084,3 +1084,13 @@ def test_precision_selection_model_parallel(precision, raises, mps_count_0):
     error_context = pytest.raises(ValueError, match=f"does not support .*{precision}") if raises else nullcontext()
     with error_context:
         _AcceleratorConnector(precision=precision, strategy=ModelParallelStrategy())
+
+
+@RunIf(mps=True)
+@pytest.mark.parametrize("accelerator", ["mps", "cpu"])
+@pytest.mark.parametrize("precision", ["16-mixed", "bf16-mixed"])
+def test_mps_amp_device_selection(accelerator, precision):
+    """Test that MPS accelerator with mixed precision correctly sets device to 'mps' instead of 'cuda'."""
+    connector = _AcceleratorConnector(accelerator=accelerator, precision=precision)
+    assert isinstance(connector.precision_plugin, MixedPrecision)
+    assert connector.precision_plugin.device == accelerator

0 commit comments
