Skip to content

Commit d88d0ce

Browse files
committed
Add support callable protocols for instance factory dependency injection (#755).
1 parent 785af2c commit d88d0ce

File tree

5 files changed

+147
-46
lines changed

5 files changed

+147
-46
lines changed

CHANGELOG.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ Added
1919
^^^^^
2020
- Support for Python 3.14 (`#753
2121
<https://github.com/omni-us/jsonargparse/pull/753>`__).
22+
- Support callable protocols for instance factory dependency injection (`#758
23+
<https://github.com/omni-us/jsonargparse/pull/758>`__).
2224

2325
Changed
2426
^^^^^^^

jsonargparse/_actions.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,6 @@ def __init__(self, typehint=None, **kwargs):
350350
if typehint is not None:
351351
self._typehint = typehint
352352
else:
353-
self._typehint = kwargs.pop("_typehint")
354353
self.update_init_kwargs(kwargs)
355354
super().__init__(**kwargs)
356355

@@ -363,13 +362,14 @@ def update_init_kwargs(self, kwargs):
363362
is_protocol,
364363
)
365364

366-
typehint = get_unaliased_type(get_optional_arg(self._typehint))
365+
typehint = get_unaliased_type(get_optional_arg(kwargs.pop("_typehint")))
367366
if get_typehint_origin(typehint) is not Union:
368367
assert "nargs" not in kwargs
369368
kwargs["nargs"] = "?"
370-
self._basename = iter_to_set_str(get_subclass_names(self._typehint, callable_return=True))
369+
self._typehint = typehint
370+
self._basename = iter_to_set_str(get_subclass_names(typehint, callable_return=True))
371371
self._baseclasses = get_subclass_types(typehint, callable_return=True)
372-
assert self._baseclasses
372+
assert self._baseclasses and all(isinstance(b, type) for b in self._baseclasses)
373373

374374
self._kind = "subclass of"
375375
if any(is_protocol(b) for b in self._baseclasses):
@@ -391,28 +391,28 @@ def __call__(self, *args, **kwargs):
391391

392392
def print_help(self, call_args):
393393
from ._typehints import (
394-
ActionTypeHint,
395-
get_optional_arg,
396-
get_unaliased_type,
394+
adapt_partial_callable_class,
397395
implements_protocol,
398396
resolve_class_path_by_name,
399397
)
400398

401399
parser, _, value, option_string = call_args
402400
try:
403-
typehint = get_unaliased_type(get_optional_arg(self._typehint))
404401
if self.nargs == "?" and value is None:
405-
val_class = typehint
402+
val_class = self._typehint
406403
else:
407-
val_class = import_object(resolve_class_path_by_name(typehint, value))
404+
val_class = import_object(resolve_class_path_by_name(self._baseclasses, value))
408405
except Exception as ex:
409406
raise TypeError(f"{option_string}: {ex}") from ex
407+
410408
if not any(is_subclass(val_class, b) or implements_protocol(val_class, b) for b in self._baseclasses):
411409
raise TypeError(f'{option_string}: Class "{value}" is not a {self._kind} {self._basename}')
412410
dest = re.sub("\\.help$", "", self.dest)
413411
subparser = type(parser)(description=f"Help for {option_string}={get_import_path(val_class)}")
414-
if ActionTypeHint.is_callable_typehint(typehint) and hasattr(typehint, "__args__"):
415-
self.sub_add_kwargs["skip"] = {max(0, len(typehint.__args__) - 1)}
412+
val = Namespace(class_path=get_import_path(val_class))
413+
_, partial_skip_args = adapt_partial_callable_class(self._typehint, val)
414+
if partial_skip_args:
415+
self.sub_add_kwargs["skip"] = partial_skip_args
416416
subparser.add_class_arguments(val_class, dest, **self.sub_add_kwargs)
417417
subparser._inner_parser = True
418418
remove_actions(subparser, (_HelpAction, _ActionPrintConfig, _ActionConfigLoad))

jsonargparse/_signatures.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ def _add_signature_arguments(
267267
"""
268268
params = get_signature_parameters(function_or_class, method_name, logger=self.logger)
269269

270-
skip_positionals = [s for s in (skip or []) if isinstance(s, int)]
270+
skip_positionals = [s for s in (skip or []) if isinstance(s, int) and s != 0]
271271
if skip_positionals:
272272
if len(skip_positionals) > 1 or any(p <= 0 for p in skip_positionals):
273273
raise ValueError(f"Unexpected number of positionals to skip: {skip_positionals}")

jsonargparse/_typehints.py

Lines changed: 65 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -271,14 +271,14 @@ def normalize_default(self, default):
271271
default = default.name
272272
elif is_callable_type(self._typehint) and callable(default) and not inspect.isclass(default):
273273
default = get_import_path(default)
274+
elif ActionTypeHint.is_return_subclass_typehint(self._typehint) and inspect.isclass(default):
275+
default = {"class_path": get_import_path(default)}
274276
elif is_subclass_type and not allow_default_instance.get():
275277
from ._parameter_resolvers import UnknownDefault
276278

277279
default_type = type(default)
278280
if not is_subclass(default_type, UnknownDefault) and self.is_subclass_typehint(default_type):
279281
raise ValueError("Subclass types require as default either a dict with class_path or a lazy instance.")
280-
elif ActionTypeHint.is_return_subclass_typehint(self._typehint) and inspect.isclass(default):
281-
default = {"class_path": get_import_path(default)}
282282
return default
283283

284284
@staticmethod
@@ -352,7 +352,7 @@ def is_subclass_typehint(typehint, all_subtypes=True, also_lists=False):
352352
def is_return_subclass_typehint(typehint):
353353
typehint = get_unaliased_type(get_optional_arg(get_unaliased_type(typehint)))
354354
typehint_origin = get_typehint_origin(typehint)
355-
if typehint_origin in callable_origin_types:
355+
if typehint_origin in callable_origin_types or is_instance_factory_protocol(typehint):
356356
return_type = get_callable_return_type(typehint)
357357
if ActionTypeHint.is_subclass_typehint(return_type):
358358
return True
@@ -626,15 +626,18 @@ def instantiate_classes(self, value):
626626
return value if islist else value[0]
627627

628628
@staticmethod
629-
def get_class_parser(val_class, sub_add_kwargs=None, skip_args=0):
629+
def get_class_parser(val_class, sub_add_kwargs=None, skip_args=None):
630630
if isinstance(val_class, str):
631631
val_class = import_object(val_class)
632632
kwargs = dict(sub_add_kwargs) if sub_add_kwargs else {}
633633
if skip_args:
634-
kwargs.setdefault("skip", set()).add(skip_args)
634+
kwargs.setdefault("skip", set()).update(skip_args)
635635
if is_subclass_spec(kwargs.get("default")):
636636
kwargs["default"] = kwargs["default"].get("init_args")
637637
parser = parent_parser.get()
638+
from ._core import ArgumentParser
639+
640+
assert isinstance(parser, ArgumentParser)
638641
parser = type(parser)(exit_on_error=False, logger=parser.logger, parser_mode=parser.parser_mode)
639642
remove_actions(parser, (ActionConfigFile, _ActionPrintConfig))
640643
if inspect.isclass(val_class) or inspect.isclass(get_typehint_origin(val_class)):
@@ -658,11 +661,10 @@ def get_class_parser(val_class, sub_add_kwargs=None, skip_args=0):
658661
def extra_help(self):
659662
extra = ""
660663
typehint = get_optional_arg(self._typehint)
661-
if self.is_subclass_typehint(typehint, all_subtypes=False) or get_typehint_origin(
662-
typehint
663-
) in callable_origin_types.union({Type, type}):
664-
if self.is_callable_typehint(typehint) and getattr(typehint, "__args__", None):
665-
typehint = get_callable_return_type(get_optional_arg(typehint))
664+
typehint = get_callable_return_type(typehint) or typehint
665+
if get_typehint_origin(typehint) is type:
666+
typehint = typehint.__args__[0]
667+
if self.is_subclass_typehint(typehint, all_subtypes=False):
666668
class_paths = get_all_subclass_paths(typehint)
667669
if class_paths:
668670
extra = ", known subclasses: " + ", ".join(class_paths)
@@ -967,11 +969,15 @@ def adapt_typehints(
967969
val = adapt_typehints(val, subtypehints[0], **adapt_kwargs)
968970

969971
# Callable
970-
elif typehint_origin in callable_origin_types or typehint in callable_origin_types:
972+
elif (
973+
typehint_origin in callable_origin_types
974+
or typehint in callable_origin_types
975+
or is_instance_factory_protocol(typehint, logger)
976+
):
971977
if serialize:
972978
if is_subclass_spec(val):
973-
val, _, num_partial_args = adapt_partial_callable_class(typehint, val)
974-
val = adapt_class_type(val, True, False, sub_add_kwargs, skip_args=num_partial_args)
979+
val, partial_skip_args = adapt_partial_callable_class(typehint, val)
980+
val = adapt_class_type(val, True, False, sub_add_kwargs, partial_skip_args=partial_skip_args)
975981
else:
976982
val = object_path_serializer(val)
977983
else:
@@ -1000,21 +1006,21 @@ def adapt_typehints(
10001006
raise ImportError(
10011007
f"Dict must include a class_path and optionally init_args, but got {val_input}"
10021008
)
1003-
val, partial_classes, num_partial_args = adapt_partial_callable_class(typehint, val)
1009+
val, partial_skip_args = adapt_partial_callable_class(typehint, val)
10041010
val_class = import_object(val["class_path"])
1005-
if inspect.isclass(val_class) and not (partial_classes or callable_instances(val_class)):
1011+
if inspect.isclass(val_class) and not (partial_skip_args or callable_instances(val_class)):
1012+
base_type = get_callable_return_type(typehint) or typehint
10061013
raise ImportError(
10071014
f"Expected '{val['class_path']}' to be a class that instantiates into callable "
1008-
f"or a subclass of {partial_classes}."
1015+
f"or a subclass of {base_type}."
10091016
)
10101017
val["class_path"] = get_import_path(val_class)
10111018
val = adapt_class_type(
10121019
val,
10131020
False,
10141021
instantiate_classes,
10151022
sub_add_kwargs,
1016-
skip_args=num_partial_args,
1017-
partial_classes=partial_classes,
1023+
partial_skip_args=partial_skip_args,
10181024
prev_val=prev_val,
10191025
)
10201026
except (ImportError, AttributeError, ArgumentError) as ex:
@@ -1172,6 +1178,15 @@ def is_instance_or_supports_protocol(value, class_type):
11721178
return isinstance(value, class_type)
11731179

11741180

1181+
def is_instance_factory_protocol(class_type, logger=None):
1182+
if not is_protocol(class_type) or not callable_instances(class_type):
1183+
return False
1184+
from ._postponed_annotations import get_return_type
1185+
1186+
return_type = get_return_type(class_type.__call__, logger)
1187+
return ActionTypeHint.is_subclass_typehint(return_type)
1188+
1189+
11751190
def is_subclass_spec(val):
11761191
is_class = isinstance(val, (dict, Namespace)) and "class_path" in val
11771192
if is_class:
@@ -1214,9 +1229,14 @@ def subclass_spec_as_namespace(val, prev_val=None):
12141229

12151230
def get_callable_return_type(typehint):
12161231
return_type = None
1217-
args = getattr(typehint, "__args__", None)
1218-
if isinstance(args, tuple) and len(args) > 0:
1219-
return_type = args[-1]
1232+
if is_instance_factory_protocol(typehint):
1233+
from ._postponed_annotations import get_return_type
1234+
1235+
return_type = get_return_type(typehint.__call__)
1236+
elif get_typehint_origin(typehint) in callable_origin_types:
1237+
args = getattr(typehint, "__args__", None)
1238+
if isinstance(args, tuple) and len(args) > 0:
1239+
return_type = args[-1]
12201240
return return_type
12211241

12221242

@@ -1238,7 +1258,7 @@ def yield_subclass_types(typehint, also_lists=False, callable_return=False):
12381258
return
12391259
typehint = get_unaliased_type(get_optional_arg(get_unaliased_type(typehint)))
12401260
typehint_origin = get_typehint_origin(typehint)
1241-
if callable_return and typehint_origin in callable_origin_types:
1261+
if callable_return and (typehint_origin in callable_origin_types or is_instance_factory_protocol(typehint)):
12421262
return_type = get_callable_return_type(typehint)
12431263
if return_type:
12441264
k = {"also_lists": also_lists, "callable_return": callable_return}
@@ -1261,18 +1281,26 @@ def get_subclass_names(typehint, callable_return=False):
12611281

12621282

12631283
def adapt_partial_callable_class(callable_type, subclass_spec):
1264-
partial_classes = False
1265-
num_partial_args = 0
1284+
partial_skip_args = None
12661285
return_type = get_callable_return_type(callable_type)
12671286
if return_type:
12681287
subclass_types = get_subclass_types(return_type)
12691288
class_type = import_object(resolve_class_path_by_name(return_type, subclass_spec.class_path))
12701289
if subclass_types and is_subclass(class_type, subclass_types):
12711290
subclass_spec = subclass_spec.clone()
12721291
subclass_spec["class_path"] = get_import_path(class_type)
1273-
partial_classes = True
1274-
num_partial_args = len(callable_type.__args__) - 1
1275-
return subclass_spec, partial_classes, num_partial_args
1292+
if is_protocol(callable_type):
1293+
from ._parameter_resolvers import get_signature_parameters
1294+
1295+
params = get_signature_parameters(callable_type, "__call__")
1296+
partial_skip_args = set()
1297+
positionals = [p for p in params if "POSITIONAL_ONLY" in str(p.kind)]
1298+
if positionals:
1299+
partial_skip_args.add(len(positionals))
1300+
partial_skip_args.update(p.name for p in params if "POSITIONAL_ONLY" not in str(p.kind))
1301+
else:
1302+
partial_skip_args = {len(callable_type.__args__) - 1}
1303+
return subclass_spec, partial_skip_args
12761304

12771305

12781306
def get_all_subclass_paths(cls: Type) -> List[str]:
@@ -1318,9 +1346,15 @@ def add_subclasses(cl):
13181346
return subclass_list
13191347

13201348

1321-
def resolve_class_path_by_name(cls: Type, name: str) -> str:
1349+
def resolve_class_path_by_name(cls: Union[Type, Tuple[Type]], name: str) -> str:
13221350
class_path = name
13231351
if "." not in class_path:
1352+
if isinstance(cls, tuple):
1353+
for cls_n in cls:
1354+
class_path = resolve_class_path_by_name(cls_n, name)
1355+
if "." in class_path:
1356+
break
1357+
return class_path
13241358
subclass_dict = defaultdict(list)
13251359
for subclass in get_all_subclass_paths(cls):
13261360
subclass_name = subclass.rsplit(".", 1)[1]
@@ -1376,13 +1410,11 @@ def discard_init_args_on_class_path_change(parser_or_action, prev_val, value):
13761410
)
13771411

13781412

1379-
def adapt_class_type(
1380-
value, serialize, instantiate_classes, sub_add_kwargs, prev_val=None, skip_args=0, partial_classes=False
1381-
):
1413+
def adapt_class_type(value, serialize, instantiate_classes, sub_add_kwargs, prev_val=None, partial_skip_args=None):
13821414
prev_val = subclass_spec_as_namespace(prev_val)
13831415
value = subclass_spec_as_namespace(value)
13841416
val_class = import_object(value.class_path)
1385-
parser = ActionTypeHint.get_class_parser(val_class, sub_add_kwargs, skip_args=skip_args)
1417+
parser = ActionTypeHint.get_class_parser(val_class, sub_add_kwargs, skip_args=partial_skip_args)
13861418

13871419
# No need to re-create the linked arg but just "inform" the corresponding parser actions that it exists upstream.
13881420
for target in sub_add_kwargs.get("linked_targets", []):
@@ -1415,7 +1447,7 @@ def adapt_class_type(
14151447

14161448
instantiator_fn = get_class_instantiator()
14171449

1418-
if partial_classes:
1450+
if partial_skip_args:
14191451
return partial(
14201452
instantiator_fn,
14211453
val_class,

jsonargparse_tests/test_typehints.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
Literal,
2222
Mapping,
2323
Optional,
24+
Protocol,
2425
Sequence,
2526
Set,
2627
Tuple,
@@ -1058,6 +1059,72 @@ def test_callable_args_return_type_class(parser, subtests):
10581059
assert "--optimizer.params" not in help_str
10591060

10601061

1062+
class OptimizerFactory(Protocol):
1063+
def __call__(self, params: List[float]) -> Optimizer: ...
1064+
1065+
1066+
class DifferentParamsOrder(Optimizer):
1067+
def __init__(self, lr: float, params: List[float], momentum: float = 0.0):
1068+
super().__init__(lr=lr, params=params, momentum=momentum)
1069+
1070+
1071+
def test_callable_protocol_instance_factory(parser, subtests):
1072+
parser.add_argument("--optimizer", type=OptimizerFactory, default=SGD)
1073+
1074+
with subtests.test("default"):
1075+
cfg = parser.get_defaults()
1076+
assert cfg.optimizer.class_path == f"{__name__}.SGD"
1077+
init = parser.instantiate_classes(cfg)
1078+
optimizer = init.optimizer(params=[1, 2])
1079+
assert isinstance(optimizer, SGD)
1080+
assert optimizer.params == [1, 2]
1081+
assert optimizer.lr == 1e-3
1082+
assert optimizer.momentum == 0.0
1083+
1084+
with subtests.test("parse dict"):
1085+
value = {
1086+
"class_path": "Adam",
1087+
"init_args": {
1088+
"lr": 0.01,
1089+
"momentum": 0.9,
1090+
},
1091+
}
1092+
cfg = parser.parse_args([f"--optimizer={json.dumps(value)}"])
1093+
init = parser.instantiate_classes(cfg)
1094+
optimizer = init.optimizer(params=[3, 2, 1])
1095+
assert isinstance(optimizer, Adam)
1096+
assert optimizer.params == [3, 2, 1]
1097+
assert optimizer.lr == 0.01
1098+
assert optimizer.momentum == 0.9
1099+
1100+
with subtests.test("params order"):
1101+
value = {
1102+
"class_path": "DifferentParamsOrder",
1103+
"init_args": {
1104+
"lr": 0.1,
1105+
"momentum": 0.8,
1106+
},
1107+
}
1108+
cfg = parser.parse_args([f"--optimizer={json.dumps(value)}"])
1109+
init = parser.instantiate_classes(cfg)
1110+
optimizer = init.optimizer(params=[3, 2])
1111+
assert isinstance(optimizer, DifferentParamsOrder)
1112+
assert optimizer.params == [3, 2]
1113+
assert optimizer.lr == 0.1
1114+
assert optimizer.momentum == 0.8
1115+
dump = parser.dump(cfg)
1116+
assert json_or_yaml_load(dump) == cfg.as_dict()
1117+
1118+
with subtests.test("help"):
1119+
help_str = get_parser_help(parser)
1120+
assert "--optimizer.help" in help_str
1121+
assert "Show the help for the given subclass or implementer of protocol {Optimizer,OptimizerFactory" in help_str
1122+
help_str = get_parse_args_stdout(parser, [f"--optimizer.help={__name__}.DifferentParamsOrder"])
1123+
assert f"Help for --optimizer.help={__name__}.DifferentParamsOrder" in help_str
1124+
assert "--optimizer.lr" in help_str
1125+
assert "--optimizer.params" not in help_str
1126+
1127+
10611128
def test_optional_callable_return_type_help(parser):
10621129
parser.add_argument("--optimizer", type=Optional[Callable[[List[float]], Optimizer]])
10631130
help_str = get_parser_help(parser)

0 commit comments

Comments
 (0)