Skip to content

Commit 7781048

Browse files
committed
Fix help for Protocol types not working correctly.
1 parent cd090e3 commit 7781048

File tree

4 files changed

+51
-9
lines changed

4 files changed

+51
-9
lines changed

CHANGELOG.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,15 @@ The semantic versioning only considers the public API as described in
1212
paths are considered internals and can change in minor and patch releases.
1313

1414

15+
v4.35.1 (2024-12-??)
16+
--------------------
17+
18+
Fixed
19+
^^^^^
20+
- Help for ``Protocol`` types not working correctly (`#???
21+
<https://github.com/omni-us/jsonargparse/pull/???>`__).
22+
23+
1524
v4.35.0 (2024-12-16)
1625
--------------------
1726

jsonargparse/_actions.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -350,18 +350,31 @@ def __init__(self, typehint=None, **kwargs):
350350
super().__init__(**kwargs)
351351

352352
def update_init_kwargs(self, kwargs):
353-
from ._typehints import get_optional_arg, get_subclass_names, get_unaliased_type
353+
from ._typehints import (
354+
get_optional_arg,
355+
get_subclass_names,
356+
get_subclass_types,
357+
get_unaliased_type,
358+
is_protocol,
359+
)
354360

355361
typehint = get_unaliased_type(get_optional_arg(self._typehint))
356362
if get_typehint_origin(typehint) is not Union:
357363
assert "nargs" not in kwargs
358364
kwargs["nargs"] = "?"
359365
self._basename = iter_to_set_str(get_subclass_names(self._typehint, callable_return=True))
366+
self._baseclasses = get_subclass_types(typehint, callable_return=True)
367+
assert self._baseclasses
368+
369+
self._kind = "subclass of"
370+
if any(is_protocol(b) for b in self._baseclasses):
371+
self._kind = "subclass or implementer of protocol"
372+
360373
kwargs.update(
361374
{
362375
"metavar": "CLASS_PATH_OR_NAME",
363376
"default": SUPPRESS,
364-
"help": f"Show the help for the given subclass of {self._basename} and exit.",
377+
"help": f"Show the help for the given {self._kind} {self._basename} and exit.",
365378
}
366379
)
367380

@@ -375,23 +388,22 @@ def print_help(self, call_args):
375388
from ._typehints import (
376389
ActionTypeHint,
377390
get_optional_arg,
378-
get_subclass_types,
379391
get_unaliased_type,
392+
implements_protocol,
380393
resolve_class_path_by_name,
381394
)
382395

383396
parser, _, value, option_string = call_args
384397
try:
385398
typehint = get_unaliased_type(get_optional_arg(self._typehint))
386-
baseclasses = get_subclass_types(typehint, callable_return=True)
387399
if self.nargs == "?" and value is None:
388400
val_class = typehint
389401
else:
390402
val_class = import_object(resolve_class_path_by_name(typehint, value))
391403
except Exception as ex:
392404
raise TypeError(f"{option_string}: {ex}") from ex
393-
if not any(is_subclass(val_class, b) for b in baseclasses):
394-
raise TypeError(f'{option_string}: Class "{value}" is not a subclass of {self._basename}')
405+
if not any(is_subclass(val_class, b) or implements_protocol(val_class, b) for b in self._baseclasses):
406+
raise TypeError(f'{option_string}: Class "{value}" is not a {self._kind} {self._basename}')
395407
dest = re.sub("\\.help$", "", self.dest)
396408
subparser = type(parser)(description=f"Help for {option_string}={get_import_path(val_class)}")
397409
if ActionTypeHint.is_callable_typehint(typehint) and hasattr(typehint, "__args__"):

jsonargparse/_typehints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1103,7 +1103,7 @@ def implements_protocol(value, protocol) -> bool:
11031103
from jsonargparse._parameter_resolvers import get_signature_parameters
11041104
from jsonargparse._postponed_annotations import get_return_type
11051105

1106-
if not inspect.isclass(value) or value is object:
1106+
if not inspect.isclass(value) or value is object or not is_protocol(protocol):
11071107
return False
11081108
members = 0
11091109
for name, _ in inspect.getmembers(protocol, predicate=inspect.isfunction):

jsonargparse_tests/test_subclasses.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1443,6 +1443,14 @@ def predict(self, items: List[float]) -> List[float]:
14431443
return items
14441444

14451445

1446+
class SubclassImplementsInterface(Interface):
1447+
def __init__(self, max_items: int):
1448+
self.max_items = max_items
1449+
1450+
def predict(self, items: List[float]) -> List[float]:
1451+
return items
1452+
1453+
14461454
class NotImplementsInterface1:
14471455
def predict(self, items: str) -> List[float]:
14481456
return []
@@ -1462,6 +1470,7 @@ def predict(self, items: List[float]) -> None:
14621470
"expected, value",
14631471
[
14641472
(True, ImplementsInterface),
1473+
(True, SubclassImplementsInterface),
14651474
(False, ImplementsInterface(1)),
14661475
(False, NotImplementsInterface1),
14671476
(False, NotImplementsInterface2),
@@ -1488,14 +1497,22 @@ def test_is_instance_or_supports_protocol(expected, value):
14881497

14891498
def test_parse_implements_protocol(parser):
14901499
parser.add_argument("--cls", type=Interface)
1491-
assert "known subclasses:" not in get_parser_help(parser)
14921500
cfg = parser.parse_args([f"--cls={__name__}.ImplementsInterface", "--cls.batch_size=5"])
14931501
assert cfg.cls.class_path == f"{__name__}.ImplementsInterface"
14941502
assert cfg.cls.init_args == Namespace(batch_size=5)
14951503
init = parser.instantiate_classes(cfg)
14961504
assert isinstance(init.cls, ImplementsInterface)
14971505
assert init.cls.batch_size == 5
14981506
assert init.cls.predict([1.0, 2.0]) == [1.0, 2.0]
1507+
1508+
help_str = get_parser_help(parser)
1509+
assert "known subclasses:" in help_str
1510+
assert f"{__name__}.SubclassImplementsInterface" in help_str
1511+
help_str = get_parse_args_stdout(parser, [f"--cls.help={__name__}.SubclassImplementsInterface"])
1512+
assert "--cls.max_items" in help_str
1513+
with pytest.raises(ArgumentError, match="not a subclass or implementer of protocol"):
1514+
parser.parse_args([f"--cls.help={__name__}.NotImplementsInterface1"])
1515+
14991516
with pytest.raises(ArgumentError, match="is a protocol"):
15001517
parser.parse_args([f"--cls={__name__}.Interface"])
15011518
with pytest.raises(ArgumentError, match="does not implement protocol"):
@@ -1551,13 +1568,17 @@ def test_implements_callable_protocol(expected, value):
15511568

15521569
def test_parse_implements_callable_protocol(parser):
15531570
parser.add_argument("--cls", type=CallableInterface)
1554-
assert "known subclasses:" not in get_parser_help(parser)
15551571
cfg = parser.parse_args([f"--cls={__name__}.ImplementsCallableInterface", "--cls.batch_size=7"])
15561572
assert cfg.cls.class_path == f"{__name__}.ImplementsCallableInterface"
15571573
assert cfg.cls.init_args == Namespace(batch_size=7)
15581574
init = parser.instantiate_classes(cfg)
15591575
assert isinstance(init.cls, ImplementsCallableInterface)
15601576
assert init.cls([1.0, 2.0]) == [1.0, 2.0]
1577+
1578+
assert "known subclasses:" not in get_parser_help(parser)
1579+
help_str = get_parse_args_stdout(parser, [f"--cls.help={__name__}.ImplementsCallableInterface"])
1580+
assert "--cls.batch_size" in help_str
1581+
15611582
with pytest.raises(ArgumentError, match="is a protocol"):
15621583
parser.parse_args([f"--cls={__name__}.CallableInterface"])
15631584
with pytest.raises(ArgumentError, match="does not implement protocol"):

0 commit comments

Comments
 (0)