Skip to content

Commit cf0960f

Browse files
committed
apply suggestions
1 parent 4db4bc4 commit cf0960f

File tree

1 file changed

+1
-7
lines changed

1 file changed

+1
-7
lines changed

tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1087,13 +1087,7 @@ def test_precision_selection_model_parallel(precision, raises, mps_count_0):
10871087

10881088

10891089
@RunIf(mps=True)
1090-
@pytest.mark.parametrize(
1091-
("accelerator", "expected_device"),
1092-
[
1093-
("mps", "mps"),
1094-
("cpu", "cpu"),
1095-
],
1096-
)
1090+
@pytest.mark.parametrize("accelerator", ["mps", "cpu"])
10971091
@pytest.mark.parametrize("precision", ["16-mixed", "bf16-mixed"])
10981092
def test_mps_amp_device_selection(accelerator, precision):
10991093
"""Test that MPS accelerator with mixed precision correctly sets device to 'mps' instead of 'cuda'."""

0 commit comments

Comments
 (0)