Skip to content

Commit 92b13a8

Browse files
committed
Support for MPS device for mixed-precision
1 parent 76f0c54 commit 92b13a8

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/lightning/pytorch/trainer/connectors/accelerator_connector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -520,7 +520,7 @@ def _check_and_init_precision(self) -> Precision:
520520
rank_zero_info(
521521
f"Using {'16bit' if self._precision_flag == '16-mixed' else 'bfloat16'} Automatic Mixed Precision (AMP)"
522522
)
523-
device = "cpu" if self._accelerator_flag == "cpu" else "cuda"
523+
device = "cpu" if self._accelerator_flag == "cpu" else "mps" if self._accelerator_flag == "mps" else "cuda"
524524
return MixedPrecision(self._precision_flag, device) # type: ignore[arg-type]
525525

526526
raise RuntimeError("No precision set")

0 commit comments

Comments
 (0)