Skip to content

Commit 305033e

Browse files
mauvilsalexierule
authored andcommitted
Removed from_argparse_args tests in test_cli.py (#14597)
1 parent f7acd4e commit 305033e

File tree

1 file changed

+0
-155
lines changed

1 file changed

+0
-155
lines changed

tests/tests_pytorch/test_cli.py

Lines changed: 0 additions & 155 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,6 @@
1414
import inspect
1515
import json
1616
import os
17-
import pickle
18-
import sys
19-
from argparse import Namespace
2017
from contextlib import contextmanager, ExitStack, redirect_stdout
2118
from io import StringIO
2219
from typing import Callable, List, Optional, Union
@@ -46,7 +43,6 @@
4643
from pytorch_lightning.plugins.environments import SLURMEnvironment
4744
from pytorch_lightning.strategies import DDPStrategy
4845
from pytorch_lightning.trainer.states import TrainerFn
49-
from pytorch_lightning.utilities import _TPU_AVAILABLE
5046
from pytorch_lightning.utilities.exceptions import MisconfigurationException
5147
from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
5248
from tests_pytorch.helpers.runif import RunIf
@@ -67,42 +63,6 @@ def mock_subclasses(baseclass, *subclasses):
6763
yield None
6864

6965

70-
@mock.patch("argparse.ArgumentParser.parse_args")
71-
def test_default_args(mock_argparse):
72-
"""Tests default argument parser for Trainer."""
73-
mock_argparse.return_value = Namespace(**Trainer.default_attributes())
74-
75-
parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
76-
args = parser.parse_args([])
77-
78-
args.max_epochs = 5
79-
trainer = Trainer.from_argparse_args(args)
80-
81-
assert isinstance(trainer, Trainer)
82-
assert trainer.max_epochs == 5
83-
84-
85-
@pytest.mark.parametrize("cli_args", [["--accumulate_grad_batches=22"], []])
86-
def test_add_argparse_args_redefined(cli_args):
87-
"""Redefines some default Trainer arguments via the cli and tests the Trainer initialization correctness."""
88-
parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
89-
parser.add_lightning_class_args(Trainer, None)
90-
91-
args = parser.parse_args(cli_args)
92-
93-
# make sure we can pickle args
94-
pickle.dumps(args)
95-
96-
# Check few deprecated args are not in namespace:
97-
for depr_name in ("gradient_clip", "nb_gpu_nodes", "max_nb_epochs"):
98-
assert depr_name not in args
99-
100-
trainer = Trainer.from_argparse_args(args=args)
101-
pickle.dumps(trainer)
102-
103-
assert isinstance(trainer, Trainer)
104-
105-
10666
@pytest.mark.parametrize("cli_args", [["--callbacks=1", "--logger"], ["--foo", "--bar=1"]])
10767
def test_add_argparse_args_redefined_error(cli_args, monkeypatch):
10868
"""Asserts error raised in case of passing not default cli arguments."""
@@ -122,121 +82,6 @@ def _raise():
12282
parser.parse_args(cli_args)
12383

12484

125-
@pytest.mark.parametrize(
126-
["cli_args", "expected"],
127-
[
128-
("--auto_lr_find=True --auto_scale_batch_size=power", dict(auto_lr_find=True, auto_scale_batch_size="power")),
129-
(
130-
"--auto_lr_find any_string --auto_scale_batch_size ON",
131-
dict(auto_lr_find="any_string", auto_scale_batch_size=True),
132-
),
133-
("--auto_lr_find=Yes --auto_scale_batch_size=On", dict(auto_lr_find=True, auto_scale_batch_size=True)),
134-
("--auto_lr_find Off --auto_scale_batch_size No", dict(auto_lr_find=False, auto_scale_batch_size=False)),
135-
("--auto_lr_find TRUE --auto_scale_batch_size FALSE", dict(auto_lr_find=True, auto_scale_batch_size=False)),
136-
("--tpu_cores=8", dict(tpu_cores=8)),
137-
("--tpu_cores=1,", dict(tpu_cores="1,")),
138-
("--limit_train_batches=100", dict(limit_train_batches=100)),
139-
("--limit_train_batches 0.8", dict(limit_train_batches=0.8)),
140-
("--enable_model_summary FALSE", dict(enable_model_summary=False)),
141-
(
142-
"",
143-
dict(
144-
# These parameters are marked as Optional[...] in Trainer.__init__,
145-
# with None as default. They should not be changed by the argparse
146-
# interface.
147-
min_steps=None,
148-
accelerator=None,
149-
profiler=None,
150-
),
151-
),
152-
],
153-
)
154-
def test_parse_args_parsing(cli_args, expected):
155-
"""Test parsing simple types and None optionals not modified."""
156-
cli_args = cli_args.split(" ") if cli_args else []
157-
with mock.patch("sys.argv", ["any.py"] + cli_args):
158-
parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
159-
parser.add_lightning_class_args(Trainer, None)
160-
args = parser.parse_args()
161-
162-
for k, v in expected.items():
163-
assert getattr(args, k) == v
164-
if "tpu_cores" not in expected or _TPU_AVAILABLE:
165-
assert Trainer.from_argparse_args(args)
166-
167-
168-
@pytest.mark.parametrize(
169-
["cli_args", "expected", "instantiate"],
170-
[
171-
(["--gpus", "[0, 2]"], dict(gpus=[0, 2]), False),
172-
(["--tpu_cores=[1,3]"], dict(tpu_cores=[1, 3]), False),
173-
(['--accumulate_grad_batches={"5":3,"10":20}'], dict(accumulate_grad_batches={5: 3, 10: 20}), True),
174-
],
175-
)
176-
def test_parse_args_parsing_complex_types(cli_args, expected, instantiate):
177-
"""Test parsing complex types."""
178-
with mock.patch("sys.argv", ["any.py"] + cli_args):
179-
parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
180-
parser.add_lightning_class_args(Trainer, None)
181-
args = parser.parse_args()
182-
183-
for k, v in expected.items():
184-
assert getattr(args, k) == v
185-
if instantiate:
186-
assert Trainer.from_argparse_args(args)
187-
188-
189-
@pytest.mark.parametrize(
190-
["cli_args", "expected_gpu"],
191-
[
192-
("--accelerator gpu --devices 1", [0]),
193-
("--accelerator gpu --devices 0,", [0]),
194-
("--accelerator gpu --devices 1,", [1]),
195-
("--accelerator gpu --devices 0,1", [0, 1]),
196-
],
197-
)
198-
def test_parse_args_parsing_gpus(monkeypatch, cli_args, expected_gpu):
199-
"""Test parsing of gpus and instantiation of Trainer."""
200-
monkeypatch.setattr("pytorch_lightning.utilities.device_parser.num_cuda_devices", lambda: 2)
201-
monkeypatch.setattr("pytorch_lightning.utilities.device_parser.is_cuda_available", lambda: True)
202-
cli_args = cli_args.split(" ") if cli_args else []
203-
with mock.patch("sys.argv", ["any.py"] + cli_args):
204-
parser = LightningArgumentParser(add_help=False, parse_as_dict=False)
205-
parser.add_lightning_class_args(Trainer, None)
206-
args = parser.parse_args()
207-
208-
trainer = Trainer.from_argparse_args(args)
209-
assert trainer.device_ids == expected_gpu
210-
211-
212-
@pytest.mark.skipif(
213-
sys.version_info < (3, 7),
214-
reason="signature inspection while mocking is not working in Python < 3.7 despite autospec",
215-
)
216-
@pytest.mark.parametrize(
217-
["cli_args", "extra_args"],
218-
[
219-
({}, {}),
220-
(dict(logger=False), {}),
221-
(dict(logger=False), dict(logger=True)),
222-
(dict(logger=False), dict(enable_checkpointing=True)),
223-
],
224-
)
225-
def test_init_from_argparse_args(cli_args, extra_args):
226-
unknown_args = dict(unknown_arg=0)
227-
228-
# unknown args in the argparser/namespace should be ignored
229-
with mock.patch("pytorch_lightning.Trainer.__init__", autospec=True, return_value=None) as init:
230-
trainer = Trainer.from_argparse_args(Namespace(**cli_args, **unknown_args), **extra_args)
231-
expected = dict(cli_args)
232-
expected.update(extra_args) # extra args should override any cli arg
233-
init.assert_called_with(trainer, **expected)
234-
235-
# passing in unknown manual args should throw an error
236-
with pytest.raises(TypeError, match=r"__init__\(\) got an unexpected keyword argument 'unknown_arg'"):
237-
Trainer.from_argparse_args(Namespace(**cli_args), **extra_args, **unknown_args)
238-
239-
24085
class Model(LightningModule):
24186
def __init__(self, model_param: int):
24287
super().__init__()

0 commit comments

Comments
 (0)