Skip to content

Commit 18cdfab

Browse files
mauvilsacarmoccaakihironitta
authored
Register torch's unresolvable import paths in cli module (Lightning-AI#13153)
Co-authored-by: Carlos Mocholí <[email protected]> Co-authored-by: Akihiro Nitta <[email protected]>
1 parent 4f81d6f commit 18cdfab

File tree

4 files changed

+28
-4
lines changed

4 files changed

+28
-4
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
244244
- Fixed issue where the CLI could not pass a `Profiler` to the `Trainer` ([#13084](https://github.com/PyTorchLightning/pytorch-lightning/pull/13084))
245245

246246

247+
- Fixed issue where the CLI fails with certain torch objects ([#13153](https://github.com/PyTorchLightning/pytorch-lightning/pull/13153))
248+
249+
247250
- Fixed logging's step values when multiple dataloaders are used during evaluation ([#12184](https://github.com/PyTorchLightning/pytorch-lightning/pull/12184))
248251

249252

pytorch_lightning/utilities/cli.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,22 @@
3434
from pytorch_lightning.utilities.model_helpers import is_overridden
3535
from pytorch_lightning.utilities.rank_zero import _warn, rank_zero_deprecation, rank_zero_warn
3636

37-
_JSONARGPARSE_SIGNATURES_AVAILABLE = _RequirementAvailable("jsonargparse[signatures]>=4.7.1")
37+
_JSONARGPARSE_SIGNATURES_AVAILABLE = _RequirementAvailable("jsonargparse[signatures]>=4.8.0")
3838

3939
if _JSONARGPARSE_SIGNATURES_AVAILABLE:
4040
import docstring_parser
41-
from jsonargparse import ActionConfigFile, ArgumentParser, class_from_function, Namespace, set_config_read_mode
41+
from jsonargparse import (
42+
ActionConfigFile,
43+
ArgumentParser,
44+
class_from_function,
45+
Namespace,
46+
register_unresolvable_import_paths,
47+
set_config_read_mode,
48+
)
4249
from jsonargparse.typehints import get_all_subclass_paths
4350
from jsonargparse.util import import_object
4451

52+
register_unresolvable_import_paths(torch) # Required until fix https://github.com/pytorch/pytorch/issues/74483
4553
set_config_read_mode(fsspec_enabled=True)
4654
else:
4755
locals()["ArgumentParser"] = object

requirements/extra.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,6 @@ matplotlib>3.1, <3.5.3
44
torchtext>=0.9.*, <=0.12.0
55
omegaconf>=2.0.5, <=2.1.*
66
hydra-core>=1.0.5, <=1.1.*
7-
jsonargparse[signatures]>=4.7.1, <4.7.4
7+
jsonargparse[signatures]>=4.8.0, <=4.8.0
88
gcsfs>=2021.5.0, <=2022.2.0
99
rich>=10.2.2,!=10.15.*, <=12.0.0

tests/utilities/test_cli.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from argparse import Namespace
2121
from contextlib import contextmanager, ExitStack, redirect_stdout
2222
from io import StringIO
23-
from typing import List, Optional, Union
23+
from typing import Callable, List, Optional, Union
2424
from unittest import mock
2525
from unittest.mock import ANY
2626

@@ -1561,3 +1561,16 @@ def test_cli_auto_seeding():
15611561
cli = LightningCLI(TestModel, run=False)
15621562
assert cli.seed_everything_default is True
15631563
assert cli.config["seed_everything"] == 123 # the original seed is kept
1564+
1565+
1566+
def test_unresolvable_import_paths():
1567+
class TestModel(BoringModel):
1568+
def __init__(self, a_func: Callable = torch.softmax):
1569+
super().__init__()
1570+
self.a_func = a_func
1571+
1572+
out = StringIO()
1573+
with mock.patch("sys.argv", ["any.py", "--print_config"]), redirect_stdout(out), pytest.raises(SystemExit):
1574+
LightningCLI(TestModel, run=False)
1575+
1576+
assert "a_func: torch.softmax" in out.getvalue()

0 commit comments

Comments
 (0)