Skip to content

Commit 39e7682

Browse files
committed
Add tests for MPS accelerator mixed precision device selection
1 parent 92f0103 commit 39e7682

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1084,3 +1084,25 @@ def test_precision_selection_model_parallel(precision, raises, mps_count_0):
10841084
error_context = pytest.raises(ValueError, match=f"does not support .*{precision}") if raises else nullcontext()
10851085
with error_context:
10861086
_AcceleratorConnector(precision=precision, strategy=ModelParallelStrategy())
1087+
1088+
1089+
@RunIf(mps=True)
1090+
@pytest.mark.parametrize("precision", ["16-mixed", "bf16-mixed"])
1091+
def test_mps_amp_device_selection(precision):
1092+
"""Test that MPS accelerator with mixed precision correctly sets device to 'mps' instead of 'cuda'."""
1093+
connector = _AcceleratorConnector(accelerator="mps", precision=precision)
1094+
assert isinstance(connector.precision_plugin, MixedPrecision)
1095+
# Verify the device parameter is set to "mps" for MPS accelerator
1096+
assert connector.precision_plugin.device == "mps"
1097+
1098+
1099+
@RunIf(mps=True)
1100+
def test_mps_amp_device_selection_vs_cpu():
1101+
"""Test that MPS AMP device selection differs from CPU AMP device selection."""
1102+
# MPS with mixed precision should use "mps" device
1103+
mps_connector = _AcceleratorConnector(accelerator="mps", precision="bf16-mixed")
1104+
assert mps_connector.precision_plugin.device == "mps"
1105+
1106+
# CPU with mixed precision should use "cpu" device
1107+
cpu_connector = _AcceleratorConnector(accelerator="cpu", precision="bf16-mixed")
1108+
assert cpu_connector.precision_plugin.device == "cpu"

0 commit comments

Comments
 (0)