@@ -46,7 +46,7 @@ def test_run_env_vars_defaults(monkeypatch, fake_script):
4646 assert "LT_PRECISION" not in os .environ
4747
4848
49- @pytest .mark .parametrize ("accelerator" , ["cpu" , "gpu" , "cuda" , pytest .param ("mps" , marks = RunIf (mps = True ))])
49+ @pytest .mark .parametrize ("accelerator" , ["cpu" , "gpu" , "cuda" , "auto" , pytest .param ("mps" , marks = RunIf (mps = True ))])
5050@mock .patch .dict (os .environ , os .environ .copy (), clear = True )
5151@mock .patch ("lightning.fabric.accelerators.cuda.num_cuda_devices" , return_value = 2 )
5252def test_run_env_vars_accelerator (_ , accelerator , monkeypatch , fake_script ):
@@ -85,7 +85,7 @@ def test_run_env_vars_unsupported_strategy(strategy, fake_script):
8585 assert f"Invalid value for '--strategy': '{ strategy } '" in ioerr .getvalue ()
8686
8787
88- @pytest .mark .parametrize ("devices" , ["1" , "2" , "0," , "1,0" , "-1" ])
88+ @pytest .mark .parametrize ("devices" , ["1" , "2" , "0," , "1,0" , "-1" , "auto" ])
8989@mock .patch .dict (os .environ , os .environ .copy (), clear = True )
9090@mock .patch ("lightning.fabric.accelerators.cuda.num_cuda_devices" , return_value = 2 )
9191def test_run_env_vars_devices_cuda (_ , devices , monkeypatch , fake_script ):
@@ -97,7 +97,7 @@ def test_run_env_vars_devices_cuda(_, devices, monkeypatch, fake_script):
9797
9898
9999@RunIf (mps = True )
100- @pytest .mark .parametrize ("accelerator" , ["mps" , "gpu" ])
100+ @pytest .mark .parametrize ("accelerator" , ["mps" , "gpu" , "auto" ])
101101@mock .patch .dict (os .environ , os .environ .copy (), clear = True )
102102def test_run_env_vars_devices_mps (accelerator , monkeypatch , fake_script ):
103103 monkeypatch .setitem (sys .modules , "torch.distributed.run" , Mock ())
0 commit comments