Skip to content

Commit b2d30f1

Browse files
committed
Refactor MPS mixed precision device selection tests for parameterized input
1 parent 39e7682 commit b2d30f1

File tree

1 file changed

+10
-16
lines changed

1 file changed

+10
-16
lines changed

tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1087,22 +1087,16 @@ def test_precision_selection_model_parallel(precision, raises, mps_count_0):
10871087

10881088

10891089
@RunIf(mps=True)
1090+
@pytest.mark.parametrize(
1091+
("accelerator", "expected_device"),
1092+
[
1093+
("mps", "mps"),
1094+
("cpu", "cpu"),
1095+
],
1096+
)
10901097
@pytest.mark.parametrize("precision", ["16-mixed", "bf16-mixed"])
1091-
def test_mps_amp_device_selection(precision):
1098+
def test_mps_amp_device_selection(accelerator, expected_device, precision):
10921099
"""Test that MPS accelerator with mixed precision correctly sets device to 'mps' instead of 'cuda'."""
1093-
connector = _AcceleratorConnector(accelerator="mps", precision=precision)
1100+
connector = _AcceleratorConnector(accelerator=accelerator, precision=precision)
10941101
assert isinstance(connector.precision_plugin, MixedPrecision)
1095-
# Verify the device parameter is set to "mps" for MPS accelerator
1096-
assert connector.precision_plugin.device == "mps"
1097-
1098-
1099-
@RunIf(mps=True)
1100-
def test_mps_amp_device_selection_vs_cpu():
1101-
"""Test that MPS AMP device selection differs from CPU AMP device selection."""
1102-
# MPS with mixed precision should use "mps" device
1103-
mps_connector = _AcceleratorConnector(accelerator="mps", precision="bf16-mixed")
1104-
assert mps_connector.precision_plugin.device == "mps"
1105-
1106-
# CPU with mixed precision should use "cpu" device
1107-
cpu_connector = _AcceleratorConnector(accelerator="cpu", precision="bf16-mixed")
1108-
assert cpu_connector.precision_plugin.device == "cpu"
1102+
assert connector.precision_plugin.device == expected_device

0 commit comments

Comments
 (0)