Skip to content

Commit 2c95877

Browse files
akihironittapre-commit-ci[bot]
authored andcommitted
Fix add_argparse_args raising TypeError with Python 3.6 (#9554)
* Add test * Accept TypeError for arg_type.__args__ being None Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 83925f1 commit 2c95877

File tree

3 files changed

+21
-2
lines changed

3 files changed

+21
-2
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1111
- Added PL_RECONCILE_PROCESS environment variable to enable process reconciliation regardless of cluster environment settings (#9389)
1212

1313

14+
- Fixed `add_argparse_args` raising `TypeError` when args are typed as `typing.Generic` in Python 3.6 ([#9554](https://github.com/PyTorchLightning/pytorch-lightning/pull/9554))
15+
16+
1417
## [1.4.7] - 2021-09-14
1518

1619
- Fixed logging of nan parameters ([#9364](https://github.com/PyTorchLightning/pytorch-lightning/pull/9364))

pytorch_lightning/utilities/argparse.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
126126
arg_default = cls_default_params[arg].default
127127
try:
128128
arg_types = tuple(arg_type.__args__)
129-
except AttributeError:
129+
except (AttributeError, TypeError):
130130
arg_types = (arg_type,)
131131

132132
name_type_default.append((arg, arg_types, arg_default))

tests/utilities/test_argparse.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import io
22
from argparse import ArgumentParser, Namespace
3-
from typing import List
3+
from typing import Generic, List, TypeVar
44
from unittest.mock import MagicMock
55

66
import pytest
@@ -136,6 +136,16 @@ def __init__(self, my_parameter: int = 0):
136136
pass
137137

138138

139+
class AddArgparseArgsExampleClassGeneric:
140+
T = TypeVar("T")
141+
142+
class SomeClass(Generic[T]):
143+
pass
144+
145+
def __init__(self, invalid_class: SomeClass):
146+
pass
147+
148+
139149
def extract_help_text(parser):
140150
help_str_buffer = io.StringIO()
141151
parser.print_help(file=help_str_buffer)
@@ -207,6 +217,12 @@ def test_add_argparse_args_no_argument_group():
207217
assert args.my_parameter == 2
208218

209219

220+
def test_add_argparse_args_invalid():
221+
"""Test that `add_argparse_args` doesn't raise `TypeError` when a class has args typed as `typing.Generic` in
222+
Python 3.6."""
223+
add_argparse_args(AddArgparseArgsExampleClassGeneric, ArgumentParser())
224+
225+
210226
def test_gpus_allowed_type():
211227
assert _gpus_allowed_type("1,2") == "1,2"
212228
assert _gpus_allowed_type("1") == 1

0 commit comments

Comments
 (0)