20
20
from unittest .mock import Mock
21
21
22
22
import pytest
23
- from lightning .fabric .cli import _get_supported_strategies , _run
23
+ from lightning .fabric .cli import _consolidate , _get_supported_strategies , _run
24
24
25
25
from tests_fabric .helpers .runif import RunIf
26
26
@@ -33,7 +33,7 @@ def fake_script(tmp_path):
33
33
34
34
35
35
@mock .patch .dict (os .environ , os .environ .copy (), clear = True )
36
- def test_cli_env_vars_defaults (monkeypatch , fake_script ):
36
+ def test_run_env_vars_defaults (monkeypatch , fake_script ):
37
37
monkeypatch .setitem (sys .modules , "torch.distributed.run" , Mock ())
38
38
with pytest .raises (SystemExit ) as e :
39
39
_run .main ([fake_script ])
@@ -49,7 +49,7 @@ def test_cli_env_vars_defaults(monkeypatch, fake_script):
49
49
@pytest .mark .parametrize ("accelerator" , ["cpu" , "gpu" , "cuda" , pytest .param ("mps" , marks = RunIf (mps = True ))])
50
50
@mock .patch .dict (os .environ , os .environ .copy (), clear = True )
51
51
@mock .patch ("lightning.fabric.accelerators.cuda.num_cuda_devices" , return_value = 2 )
52
- def test_cli_env_vars_accelerator (_ , accelerator , monkeypatch , fake_script ):
52
+ def test_run_env_vars_accelerator (_ , accelerator , monkeypatch , fake_script ):
53
53
monkeypatch .setitem (sys .modules , "torch.distributed.run" , Mock ())
54
54
with pytest .raises (SystemExit ) as e :
55
55
_run .main ([fake_script , "--accelerator" , accelerator ])
@@ -60,23 +60,23 @@ def test_cli_env_vars_accelerator(_, accelerator, monkeypatch, fake_script):
60
60
@pytest .mark .parametrize ("strategy" , _get_supported_strategies ())
61
61
@mock .patch .dict (os .environ , os .environ .copy (), clear = True )
62
62
@mock .patch ("lightning.fabric.accelerators.cuda.num_cuda_devices" , return_value = 2 )
63
- def test_cli_env_vars_strategy (_ , strategy , monkeypatch , fake_script ):
63
+ def test_run_env_vars_strategy (_ , strategy , monkeypatch , fake_script ):
64
64
monkeypatch .setitem (sys .modules , "torch.distributed.run" , Mock ())
65
65
with pytest .raises (SystemExit ) as e :
66
66
_run .main ([fake_script , "--strategy" , strategy ])
67
67
assert e .value .code == 0
68
68
assert os .environ ["LT_STRATEGY" ] == strategy
69
69
70
70
71
- def test_cli_get_supported_strategies ():
71
+ def test_run_get_supported_strategies ():
72
72
"""Test to ensure that when new strategies get added, we must consider updating the list of supported ones in the
73
73
CLI."""
74
74
assert len (_get_supported_strategies ()) == 7
75
75
assert "fsdp" in _get_supported_strategies ()
76
76
77
77
78
78
@pytest .mark .parametrize ("strategy" , ["ddp_spawn" , "ddp_fork" , "ddp_notebook" , "deepspeed_stage_3_offload" ])
79
- def test_cli_env_vars_unsupported_strategy (strategy , fake_script ):
79
+ def test_run_env_vars_unsupported_strategy (strategy , fake_script ):
80
80
ioerr = StringIO ()
81
81
with pytest .raises (SystemExit ) as e , contextlib .redirect_stderr (ioerr ):
82
82
_run .main ([fake_script , "--strategy" , strategy ])
@@ -87,7 +87,7 @@ def test_cli_env_vars_unsupported_strategy(strategy, fake_script):
87
87
@pytest .mark .parametrize ("devices" , ["1" , "2" , "0," , "1,0" , "-1" ])
88
88
@mock .patch .dict (os .environ , os .environ .copy (), clear = True )
89
89
@mock .patch ("lightning.fabric.accelerators.cuda.num_cuda_devices" , return_value = 2 )
90
- def test_cli_env_vars_devices_cuda (_ , devices , monkeypatch , fake_script ):
90
+ def test_run_env_vars_devices_cuda (_ , devices , monkeypatch , fake_script ):
91
91
monkeypatch .setitem (sys .modules , "torch.distributed.run" , Mock ())
92
92
with pytest .raises (SystemExit ) as e :
93
93
_run .main ([fake_script , "--accelerator" , "cuda" , "--devices" , devices ])
@@ -98,7 +98,7 @@ def test_cli_env_vars_devices_cuda(_, devices, monkeypatch, fake_script):
98
98
@RunIf (mps = True )
99
99
@pytest .mark .parametrize ("accelerator" , ["mps" , "gpu" ])
100
100
@mock .patch .dict (os .environ , os .environ .copy (), clear = True )
101
- def test_cli_env_vars_devices_mps (accelerator , monkeypatch , fake_script ):
101
+ def test_run_env_vars_devices_mps (accelerator , monkeypatch , fake_script ):
102
102
monkeypatch .setitem (sys .modules , "torch.distributed.run" , Mock ())
103
103
with pytest .raises (SystemExit ) as e :
104
104
_run .main ([fake_script , "--accelerator" , accelerator ])
@@ -108,7 +108,7 @@ def test_cli_env_vars_devices_mps(accelerator, monkeypatch, fake_script):
108
108
109
109
@pytest .mark .parametrize ("num_nodes" , ["1" , "2" , "3" ])
110
110
@mock .patch .dict (os .environ , os .environ .copy (), clear = True )
111
- def test_cli_env_vars_num_nodes (num_nodes , monkeypatch , fake_script ):
111
+ def test_run_env_vars_num_nodes (num_nodes , monkeypatch , fake_script ):
112
112
monkeypatch .setitem (sys .modules , "torch.distributed.run" , Mock ())
113
113
with pytest .raises (SystemExit ) as e :
114
114
_run .main ([fake_script , "--num-nodes" , num_nodes ])
@@ -118,7 +118,7 @@ def test_cli_env_vars_num_nodes(num_nodes, monkeypatch, fake_script):
118
118
119
119
@pytest .mark .parametrize ("precision" , ["64-true" , "64" , "32-true" , "32" , "16-mixed" , "bf16-mixed" ])
120
120
@mock .patch .dict (os .environ , os .environ .copy (), clear = True )
121
- def test_cli_env_vars_precision (precision , monkeypatch , fake_script ):
121
+ def test_run_env_vars_precision (precision , monkeypatch , fake_script ):
122
122
monkeypatch .setitem (sys .modules , "torch.distributed.run" , Mock ())
123
123
with pytest .raises (SystemExit ) as e :
124
124
_run .main ([fake_script , "--precision" , precision ])
@@ -127,7 +127,7 @@ def test_cli_env_vars_precision(precision, monkeypatch, fake_script):
127
127
128
128
129
129
@mock .patch .dict (os .environ , os .environ .copy (), clear = True )
130
- def test_cli_torchrun_defaults (monkeypatch , fake_script ):
130
+ def test_run_torchrun_defaults (monkeypatch , fake_script ):
131
131
torchrun_mock = Mock ()
132
132
monkeypatch .setitem (sys .modules , "torch.distributed.run" , torchrun_mock )
133
133
with pytest .raises (SystemExit ) as e :
@@ -155,7 +155,7 @@ def test_cli_torchrun_defaults(monkeypatch, fake_script):
155
155
)
156
156
@mock .patch .dict (os .environ , os .environ .copy (), clear = True )
157
157
@mock .patch ("lightning.fabric.accelerators.cuda.num_cuda_devices" , return_value = 5 )
158
- def test_cli_torchrun_num_processes_launched (_ , devices , expected , monkeypatch , fake_script ):
158
+ def test_run_torchrun_num_processes_launched (_ , devices , expected , monkeypatch , fake_script ):
159
159
torchrun_mock = Mock ()
160
160
monkeypatch .setitem (sys .modules , "torch.distributed.run" , torchrun_mock )
161
161
with pytest .raises (SystemExit ) as e :
@@ -171,15 +171,15 @@ def test_cli_torchrun_num_processes_launched(_, devices, expected, monkeypatch,
171
171
])
172
172
173
173
174
- def test_cli_through_fabric_entry_point ():
174
+ def test_run_through_fabric_entry_point ():
175
175
result = subprocess .run ("fabric run --help" , capture_output = True , text = True , shell = True )
176
176
177
177
message = "Usage: fabric run [OPTIONS] SCRIPT [SCRIPT_ARGS]"
178
178
assert message in result .stdout or message in result .stderr
179
179
180
180
181
181
@pytest .mark .skipif ("lightning.fabric" == "lightning_fabric" , reason = "standalone package" )
182
- def test_cli_through_lightning_entry_point ():
182
+ def test_run_through_lightning_entry_point ():
183
183
result = subprocess .run ("lightning run model --help" , capture_output = True , text = True , shell = True )
184
184
185
185
deprecation_message = (
@@ -189,3 +189,22 @@ def test_cli_through_lightning_entry_point():
189
189
message = "Usage: lightning run [OPTIONS] SCRIPT [SCRIPT_ARGS]"
190
190
assert deprecation_message in result .stdout
191
191
assert message in result .stdout or message in result .stderr
192
+
193
+
194
+ @mock .patch ("lightning.fabric.cli._process_cli_args" )
195
+ @mock .patch ("lightning.fabric.cli._load_distributed_checkpoint" )
196
+ @mock .patch ("lightning.fabric.cli.torch.save" )
197
+ def test_consolidate (save_mock , _ , __ , tmp_path ):
198
+ ioerr = StringIO ()
199
+ with pytest .raises (SystemExit ) as e , contextlib .redirect_stderr (ioerr ):
200
+ _consolidate .main (["not exist" ])
201
+ assert e .value .code == 2
202
+ assert "Path 'not exist' does not exist" in ioerr .getvalue ()
203
+
204
+ checkpoint_folder = tmp_path / "checkpoint"
205
+ checkpoint_folder .mkdir ()
206
+ ioerr = StringIO ()
207
+ with pytest .raises (SystemExit ) as e , contextlib .redirect_stderr (ioerr ):
208
+ _consolidate .main ([str (checkpoint_folder )])
209
+ assert e .value .code == 0
210
+ save_mock .assert_called_once ()
0 commit comments