2222import pytest
2323from jsonargparse import Namespace
2424
25- from lightning .fabric .cli import FabricCLI , _get_supported_strategies , main as _run_main
25+ from lightning .fabric .cli import FabricCLI , _get_supported_strategies
26+ from lightning .fabric .cli import main as _run_main
2627from lightning .fabric .utilities .consolidate_checkpoint import main as _consolidate_main
2728from tests_fabric .helpers .runif import RunIf
2829
@@ -38,7 +39,17 @@ def fake_script(tmp_path):
3839def test_run_env_vars_defaults (monkeypatch , fake_script ):
3940 monkeypatch .setitem (sys .modules , "torch.distributed.run" , Mock ())
4041 with pytest .raises (SystemExit ) as e :
41- args = Namespace (script = fake_script , accelerator = None , strategy = None , devices = "1" , num_nodes = 1 , node_rank = 0 , main_address = "127.0.0.1" , main_port = 29400 , precision = None )
42+ args = Namespace (
43+ script = fake_script ,
44+ accelerator = None ,
45+ strategy = None ,
46+ devices = "1" ,
47+ num_nodes = 1 ,
48+ node_rank = 0 ,
49+ main_address = "127.0.0.1" ,
50+ main_port = 29400 ,
51+ precision = None ,
52+ )
4253 _run_main (args )
4354 assert e .value .code == 0
4455 assert os .environ ["LT_CLI_USED" ] == "1"
@@ -55,7 +66,17 @@ def test_run_env_vars_defaults(monkeypatch, fake_script):
5566def test_run_env_vars_accelerator (_ , accelerator , monkeypatch , fake_script ):
5667 monkeypatch .setitem (sys .modules , "torch.distributed.run" , Mock ())
5768 with pytest .raises (SystemExit ) as e :
58- args = Namespace (script = fake_script , accelerator = accelerator , strategy = None , devices = "1" , num_nodes = 1 , node_rank = 0 , main_address = "127.0.0.1" , main_port = 29400 , precision = None )
69+ args = Namespace (
70+ script = fake_script ,
71+ accelerator = accelerator ,
72+ strategy = None ,
73+ devices = "1" ,
74+ num_nodes = 1 ,
75+ node_rank = 0 ,
76+ main_address = "127.0.0.1" ,
77+ main_port = 29400 ,
78+ precision = None ,
79+ )
5980 _run_main (args )
6081 assert e .value .code == 0
6182 assert os .environ ["LT_ACCELERATOR" ] == accelerator
@@ -67,7 +88,17 @@ def test_run_env_vars_accelerator(_, accelerator, monkeypatch, fake_script):
6788def test_run_env_vars_strategy (_ , strategy , monkeypatch , fake_script ):
6889 monkeypatch .setitem (sys .modules , "torch.distributed.run" , Mock ())
6990 with pytest .raises (SystemExit ) as e :
70- args = Namespace (script = fake_script , accelerator = None , strategy = strategy , devices = "1" , num_nodes = 1 , node_rank = 0 , main_address = "127.0.0.1" , main_port = 29400 , precision = None )
91+ args = Namespace (
92+ script = fake_script ,
93+ accelerator = None ,
94+ strategy = strategy ,
95+ devices = "1" ,
96+ num_nodes = 1 ,
97+ node_rank = 0 ,
98+ main_address = "127.0.0.1" ,
99+ main_port = 29400 ,
100+ precision = None ,
101+ )
71102 _run_main (args )
72103 assert e .value .code == 0
73104 assert os .environ ["LT_STRATEGY" ] == strategy
@@ -96,7 +127,17 @@ def test_run_env_vars_unsupported_strategy(strategy, fake_script):
96127def test_run_env_vars_devices_cuda (_ , devices , monkeypatch , fake_script ):
97128 monkeypatch .setitem (sys .modules , "torch.distributed.run" , Mock ())
98129 with pytest .raises (SystemExit ) as e :
99- args = Namespace (script = fake_script , accelerator = "cuda" , strategy = None , devices = devices , num_nodes = 1 , node_rank = 0 , main_address = "127.0.0.1" , main_port = 29400 , precision = None )
130+ args = Namespace (
131+ script = fake_script ,
132+ accelerator = "cuda" ,
133+ strategy = None ,
134+ devices = devices ,
135+ num_nodes = 1 ,
136+ node_rank = 0 ,
137+ main_address = "127.0.0.1" ,
138+ main_port = 29400 ,
139+ precision = None ,
140+ )
100141 _run_main (args )
101142 assert e .value .code == 0
102143 assert os .environ ["LT_DEVICES" ] == devices
@@ -108,7 +149,17 @@ def test_run_env_vars_devices_cuda(_, devices, monkeypatch, fake_script):
108149def test_run_env_vars_devices_mps (accelerator , monkeypatch , fake_script ):
109150 monkeypatch .setitem (sys .modules , "torch.distributed.run" , Mock ())
110151 with pytest .raises (SystemExit ) as e :
111- args = Namespace (script = fake_script , accelerator = accelerator , strategy = None , devices = "1" , num_nodes = 1 , node_rank = 0 , main_address = "127.0.0.1" , main_port = 29400 , precision = None )
152+ args = Namespace (
153+ script = fake_script ,
154+ accelerator = accelerator ,
155+ strategy = None ,
156+ devices = "1" ,
157+ num_nodes = 1 ,
158+ node_rank = 0 ,
159+ main_address = "127.0.0.1" ,
160+ main_port = 29400 ,
161+ precision = None ,
162+ )
112163 _run_main (args )
113164 assert e .value .code == 0
114165 assert os .environ ["LT_DEVICES" ] == "1"
@@ -119,7 +170,17 @@ def test_run_env_vars_devices_mps(accelerator, monkeypatch, fake_script):
119170def test_run_env_vars_num_nodes (num_nodes , monkeypatch , fake_script ):
120171 monkeypatch .setitem (sys .modules , "torch.distributed.run" , Mock ())
121172 with pytest .raises (SystemExit ) as e :
122- args = Namespace (script = fake_script , accelerator = None , strategy = None , devices = "1" , num_nodes = int (num_nodes ), node_rank = 0 , main_address = "127.0.0.1" , main_port = 29400 , precision = None )
173+ args = Namespace (
174+ script = fake_script ,
175+ accelerator = None ,
176+ strategy = None ,
177+ devices = "1" ,
178+ num_nodes = int (num_nodes ),
179+ node_rank = 0 ,
180+ main_address = "127.0.0.1" ,
181+ main_port = 29400 ,
182+ precision = None ,
183+ )
123184 _run_main (args )
124185 assert e .value .code == 0
125186 assert os .environ ["LT_NUM_NODES" ] == num_nodes
@@ -130,7 +191,17 @@ def test_run_env_vars_num_nodes(num_nodes, monkeypatch, fake_script):
130191def test_run_env_vars_precision (precision , monkeypatch , fake_script ):
131192 monkeypatch .setitem (sys .modules , "torch.distributed.run" , Mock ())
132193 with pytest .raises (SystemExit ) as e :
133- args = Namespace (script = fake_script , accelerator = None , strategy = None , devices = "1" , num_nodes = 1 , node_rank = 0 , main_address = "127.0.0.1" , main_port = 29400 , precision = precision )
194+ args = Namespace (
195+ script = fake_script ,
196+ accelerator = None ,
197+ strategy = None ,
198+ devices = "1" ,
199+ num_nodes = 1 ,
200+ node_rank = 0 ,
201+ main_address = "127.0.0.1" ,
202+ main_port = 29400 ,
203+ precision = precision ,
204+ )
134205 _run_main (args )
135206 assert e .value .code == 0
136207 assert os .environ ["LT_PRECISION" ] == precision
@@ -141,7 +212,17 @@ def test_run_torchrun_defaults(monkeypatch, fake_script):
141212 torchrun_mock = Mock ()
142213 monkeypatch .setitem (sys .modules , "torch.distributed.run" , torchrun_mock )
143214 with pytest .raises (SystemExit ) as e :
144- args = Namespace (script = fake_script , accelerator = None , strategy = None , devices = "1" , num_nodes = 1 , node_rank = 0 , main_address = "127.0.0.1" , main_port = 29400 , precision = None )
215+ args = Namespace (
216+ script = fake_script ,
217+ accelerator = None ,
218+ strategy = None ,
219+ devices = "1" ,
220+ num_nodes = 1 ,
221+ node_rank = 0 ,
222+ main_address = "127.0.0.1" ,
223+ main_port = 29400 ,
224+ precision = None ,
225+ )
145226 _run_main (args )
146227 assert e .value .code == 0
147228 torchrun_mock .main .assert_called_with ([
@@ -170,7 +251,17 @@ def test_run_torchrun_num_processes_launched(_, devices, expected, monkeypatch,
170251 torchrun_mock = Mock ()
171252 monkeypatch .setitem (sys .modules , "torch.distributed.run" , torchrun_mock )
172253 with pytest .raises (SystemExit ) as e :
173- args = Namespace (script = fake_script , accelerator = "cuda" , strategy = None , devices = devices , num_nodes = 1 , node_rank = 0 , main_address = "127.0.0.1" , main_port = 29400 , precision = None )
254+ args = Namespace (
255+ script = fake_script ,
256+ accelerator = "cuda" ,
257+ strategy = None ,
258+ devices = devices ,
259+ num_nodes = 1 ,
260+ node_rank = 0 ,
261+ main_address = "127.0.0.1" ,
262+ main_port = 29400 ,
263+ precision = None ,
264+ )
174265 _run_main (args )
175266 assert e .value .code == 0
176267 torchrun_mock .main .assert_called_with ([
0 commit comments