Skip to content

Commit 4db4bc4

Browse files
Apply suggestions
Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
1 parent e7ba39c commit 4db4bc4

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tests/tests_pytorch/trainer/connectors/test_accelerator_connector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1095,8 +1095,8 @@ def test_precision_selection_model_parallel(precision, raises, mps_count_0):
10951095
],
10961096
)
10971097
@pytest.mark.parametrize("precision", ["16-mixed", "bf16-mixed"])
1098-
def test_mps_amp_device_selection(accelerator, expected_device, precision):
1098+
def test_mps_amp_device_selection(accelerator, precision):
10991099
"""Test that MPS accelerator with mixed precision correctly sets device to 'mps' instead of 'cuda'."""
11001100
connector = _AcceleratorConnector(accelerator=accelerator, precision=precision)
11011101
assert isinstance(connector.precision_plugin, MixedPrecision)
1102-
assert connector.precision_plugin.device == expected_device
1102+
assert connector.precision_plugin.device == accelerator

0 commit comments

Comments
 (0)