1414import inspect
1515import json
1616import os
17- import pickle
18- import sys
19- from argparse import Namespace
2017from contextlib import contextmanager , ExitStack , redirect_stdout
2118from io import StringIO
2219from typing import Callable , List , Optional , Union
4643from pytorch_lightning .plugins .environments import SLURMEnvironment
4744from pytorch_lightning .strategies import DDPStrategy
4845from pytorch_lightning .trainer .states import TrainerFn
49- from pytorch_lightning .utilities import _TPU_AVAILABLE
5046from pytorch_lightning .utilities .exceptions import MisconfigurationException
5147from pytorch_lightning .utilities .imports import _TORCHVISION_AVAILABLE
5248from tests_pytorch .helpers .runif import RunIf
@@ -67,42 +63,6 @@ def mock_subclasses(baseclass, *subclasses):
6763 yield None
6864
6965
@mock.patch("argparse.ArgumentParser.parse_args")
def test_default_args(mock_argparse):
    """Tests default argument parser for Trainer."""
    mock_argparse.return_value = Namespace(**Trainer.default_attributes())

    # Parse with no CLI input; defaults come from the mocked argparse result.
    arg_parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
    parsed = arg_parser.parse_args([])
    parsed.max_epochs = 5

    trainer = Trainer.from_argparse_args(parsed)
    assert isinstance(trainer, Trainer)
    assert trainer.max_epochs == 5
83-
84-
@pytest.mark.parametrize("cli_args", [["--accumulate_grad_batches=22"], []])
def test_add_argparse_args_redefined(cli_args):
    """Redefines some default Trainer arguments via the cli and tests the Trainer initialization correctness."""
    parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
    parser.add_lightning_class_args(Trainer, None)

    parsed = parser.parse_args(cli_args)

    # The namespace must survive pickling (e.g. for spawn-based workers).
    pickle.dumps(parsed)

    # Long-removed deprecated arguments must not resurface in the namespace.
    assert all(name not in parsed for name in ("gradient_clip", "nb_gpu_nodes", "max_nb_epochs"))

    trainer = Trainer.from_argparse_args(args=parsed)
    pickle.dumps(trainer)

    assert isinstance(trainer, Trainer)
104-
105-
10666@pytest .mark .parametrize ("cli_args" , [["--callbacks=1" , "--logger" ], ["--foo" , "--bar=1" ]])
10767def test_add_argparse_args_redefined_error (cli_args , monkeypatch ):
10868 """Asserts error raised in case of passing not default cli arguments."""
@@ -122,121 +82,6 @@ def _raise():
12282 parser .parse_args (cli_args )
12383
12484
@pytest.mark.parametrize(
    ["cli_args", "expected"],
    [
        ("--auto_lr_find=True --auto_scale_batch_size=power", {"auto_lr_find": True, "auto_scale_batch_size": "power"}),
        (
            "--auto_lr_find any_string --auto_scale_batch_size ON",
            {"auto_lr_find": "any_string", "auto_scale_batch_size": True},
        ),
        ("--auto_lr_find=Yes --auto_scale_batch_size=On", {"auto_lr_find": True, "auto_scale_batch_size": True}),
        ("--auto_lr_find Off --auto_scale_batch_size No", {"auto_lr_find": False, "auto_scale_batch_size": False}),
        ("--auto_lr_find TRUE --auto_scale_batch_size FALSE", {"auto_lr_find": True, "auto_scale_batch_size": False}),
        ("--tpu_cores=8", {"tpu_cores": 8}),
        ("--tpu_cores=1,", {"tpu_cores": "1,"}),
        ("--limit_train_batches=100", {"limit_train_batches": 100}),
        ("--limit_train_batches 0.8", {"limit_train_batches": 0.8}),
        ("--enable_model_summary FALSE", {"enable_model_summary": False}),
        (
            "",
            {
                # These parameters are marked as Optional[...] in Trainer.__init__,
                # with None as default. They should not be changed by the argparse
                # interface.
                "min_steps": None,
                "accelerator": None,
                "profiler": None,
            },
        ),
    ],
)
def test_parse_args_parsing(cli_args, expected):
    """Test parsing simple types and None optionals not modified."""
    argv = cli_args.split(" ") if cli_args else []
    with mock.patch("sys.argv", ["any.py"] + argv):
        parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
        parser.add_lightning_class_args(Trainer, None)
        parsed = parser.parse_args()

    for key, value in expected.items():
        assert getattr(parsed, key) == value
    # Only instantiate when TPU args are absent or a TPU is actually available.
    if "tpu_cores" not in expected or _TPU_AVAILABLE:
        assert Trainer.from_argparse_args(parsed)
166-
167-
@pytest.mark.parametrize(
    ["cli_args", "expected", "instantiate"],
    [
        (["--gpus", "[0, 2]"], {"gpus": [0, 2]}, False),
        (["--tpu_cores=[1,3]"], {"tpu_cores": [1, 3]}, False),
        (['--accumulate_grad_batches={"5":3,"10":20}'], {"accumulate_grad_batches": {5: 3, 10: 20}}, True),
    ],
)
def test_parse_args_parsing_complex_types(cli_args, expected, instantiate):
    """Test parsing complex types."""
    with mock.patch("sys.argv", ["any.py"] + cli_args):
        parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
        parser.add_lightning_class_args(Trainer, None)
        parsed = parser.parse_args()

    for key, value in expected.items():
        assert getattr(parsed, key) == value
    # Some combinations (gpus/tpu_cores) need real hardware, so only
    # instantiate the Trainer for the cases flagged as safe.
    if instantiate:
        assert Trainer.from_argparse_args(parsed)
187-
188-
@pytest.mark.parametrize(
    ["cli_args", "expected_gpu"],
    [
        ("--accelerator gpu --devices 1", [0]),
        ("--accelerator gpu --devices 0,", [0]),
        ("--accelerator gpu --devices 1,", [1]),
        ("--accelerator gpu --devices 0,1", [0, 1]),
    ],
)
def test_parse_args_parsing_gpus(monkeypatch, cli_args, expected_gpu):
    """Test parsing of gpus and instantiation of Trainer."""
    # Pretend two CUDA devices exist so device parsing works on CPU-only CI.
    monkeypatch.setattr("pytorch_lightning.utilities.device_parser.num_cuda_devices", lambda: 2)
    monkeypatch.setattr("pytorch_lightning.utilities.device_parser.is_cuda_available", lambda: True)

    argv = cli_args.split(" ") if cli_args else []
    with mock.patch("sys.argv", ["any.py"] + argv):
        parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
        parser.add_lightning_class_args(Trainer, None)
        parsed = parser.parse_args()

    trainer = Trainer.from_argparse_args(parsed)
    assert trainer.device_ids == expected_gpu
210-
211-
@pytest.mark.skipif(
    sys.version_info < (3, 7),
    reason="signature inspection while mocking is not working in Python < 3.7 despite autospec",
)
@pytest.mark.parametrize(
    ["cli_args", "extra_args"],
    [
        ({}, {}),
        ({"logger": False}, {}),
        ({"logger": False}, {"logger": True}),
        ({"logger": False}, {"enable_checkpointing": True}),
    ],
)
def test_init_from_argparse_args(cli_args, extra_args):
    """Unknown namespace entries are ignored; unknown keyword args raise."""
    unknown_args = dict(unknown_arg=0)

    # unknown args in the argparser/namespace should be ignored
    with mock.patch("pytorch_lightning.Trainer.__init__", autospec=True, return_value=None) as init:
        trainer = Trainer.from_argparse_args(Namespace(**cli_args, **unknown_args), **extra_args)
        # extra args should override any cli arg
        expected = {**cli_args, **extra_args}
        init.assert_called_with(trainer, **expected)

    # passing in unknown manual args should throw an error
    with pytest.raises(TypeError, match=r"__init__\(\) got an unexpected keyword argument 'unknown_arg'"):
        Trainer.from_argparse_args(Namespace(**cli_args), **extra_args, **unknown_args)
238-
239-
24085class Model (LightningModule ):
24186 def __init__ (self , model_param : int ):
24287 super ().__init__ ()
0 commit comments