Skip to content

Commit d95447a

Browse files
rohitgr7lexierule
authored andcommitted
Update deepspeed precision test (#12727)
1 parent 83e0c4a commit d95447a

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

tests/strategies/test_deepspeed_strategy.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,12 +155,13 @@ def test_deepspeed_strategy_env(tmpdir, monkeypatch, deepspeed_config):
155155

156156

157157
@RunIf(deepspeed=True)
158+
@mock.patch("torch.cuda.device_count", return_value=1)
158159
@pytest.mark.parametrize("precision", [16, "mixed"])
159160
@pytest.mark.parametrize(
160161
"amp_backend",
161162
["native", pytest.param("apex", marks=RunIf(amp_apex=True))],
162163
)
163-
def test_deepspeed_precision_choice(amp_backend, precision, tmpdir):
164+
def test_deepspeed_precision_choice(_, amp_backend, precision, tmpdir):
164165
"""Test to ensure precision plugin is also correctly chosen.
165166
166167
DeepSpeed handles precision via Custom DeepSpeedPrecisionPlugin

0 commit comments

Comments
 (0)