Skip to content

Commit 48b944c

Browse files
committed
Add patch to test
1 parent 68677d3 commit 48b944c

File tree

3 files changed

+18
-15
lines changed

3 files changed

+18
-15
lines changed

.actions/assistant.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -483,16 +483,9 @@ def convert_version2nightly(ver_file: str = "src/version.info") -> None:
483483

484484

485485
if __name__ == "__main__":
486-
import sys
487-
488486
import jsonargparse
489-
from jsonargparse import ArgumentParser
490-
491-
def _parse_known_args_patch(self: ArgumentParser, args: Any = None, namespace: Any = None) -> tuple[Any, Any]:
492-
namespace, args = super(ArgumentParser, self)._parse_known_args(args, namespace, intermixed=False) # type: ignore
493-
return namespace, args
487+
from lightning.pytorch.cli import patch_jsonargparse_python_3_12_8
494488

495-
if sys.version_info >= (3, 12, 8):
496-
setattr(ArgumentParser, "_parse_known_args", _parse_known_args_patch)
489+
patch_jsonargparse_python_3_12_8() # Required until fix https://github.com/omni-us/jsonargparse/issues/641
497490

498491
jsonargparse.CLI(AssistantCLI, as_positional=False)

src/lightning/pytorch/cli.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,18 @@
3737

3838
_JSONARGPARSE_SIGNATURES_AVAILABLE = RequirementCache("jsonargparse[signatures]>=4.27.7")
3939

40+
41+
def patch_jsonargparse_python_3_12_8():
42+
if sys.version_info < (3, 12, 8):
43+
return
44+
45+
def _parse_known_args_patch(self: ArgumentParser, args: Any = None, namespace: Any = None) -> tuple[Any, Any]:
46+
namespace, args = super(ArgumentParser, self)._parse_known_args(args, namespace, intermixed=False) # type: ignore
47+
return namespace, args
48+
49+
setattr(ArgumentParser, "_parse_known_args", _parse_known_args_patch)
50+
51+
4052
if _JSONARGPARSE_SIGNATURES_AVAILABLE:
4153
import docstring_parser
4254
from jsonargparse import (
@@ -48,12 +60,7 @@
4860
set_config_read_mode,
4961
)
5062

51-
def _parse_known_args_patch(self: ArgumentParser, args: Any = None, namespace: Any = None) -> tuple[Any, Any]:
52-
namespace, args = super(ArgumentParser, self)._parse_known_args(args, namespace, intermixed=False) # type: ignore
53-
return namespace, args
54-
55-
if sys.version_info >= (3, 12, 8):
56-
setattr(ArgumentParser, "_parse_known_args", _parse_known_args_patch)
63+
patch_jsonargparse_python_3_12_8() # Required until fix https://github.com/omni-us/jsonargparse/issues/641
5764

5865
register_unresolvable_import_paths(torch) # Required until fix https://github.com/pytorch/pytorch/issues/74483
5966
set_config_read_mode(fsspec_enabled=True)

tests/parity_fabric/test_parity_ddp.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,5 +162,8 @@ def run_parity_test(accelerator: str = "cpu", devices: int = 2, tolerance: float
162162

163163
if __name__ == "__main__":
164164
from jsonargparse import CLI
165+
from lightning.pytorch.cli import patch_jsonargparse_python_3_12_8
166+
167+
patch_jsonargparse_python_3_12_8() # Required until fix https://github.com/omni-us/jsonargparse/issues/641
165168

166169
CLI(run_parity_test)

0 commit comments

Comments
 (0)