@@ -271,14 +271,14 @@ def normalize_default(self, default):
271271 default = default .name
272272 elif is_callable_type (self ._typehint ) and callable (default ) and not inspect .isclass (default ):
273273 default = get_import_path (default )
274+ elif ActionTypeHint .is_return_subclass_typehint (self ._typehint ) and inspect .isclass (default ):
275+ default = {"class_path" : get_import_path (default )}
274276 elif is_subclass_type and not allow_default_instance .get ():
275277 from ._parameter_resolvers import UnknownDefault
276278
277279 default_type = type (default )
278280 if not is_subclass (default_type , UnknownDefault ) and self .is_subclass_typehint (default_type ):
279281 raise ValueError ("Subclass types require as default either a dict with class_path or a lazy instance." )
280- elif ActionTypeHint .is_return_subclass_typehint (self ._typehint ) and inspect .isclass (default ):
281- default = {"class_path" : get_import_path (default )}
282282 return default
283283
284284 @staticmethod
@@ -352,7 +352,7 @@ def is_subclass_typehint(typehint, all_subtypes=True, also_lists=False):
352352 def is_return_subclass_typehint (typehint ):
353353 typehint = get_unaliased_type (get_optional_arg (get_unaliased_type (typehint )))
354354 typehint_origin = get_typehint_origin (typehint )
355- if typehint_origin in callable_origin_types :
355+ if typehint_origin in callable_origin_types or is_instance_factory_protocol ( typehint ) :
356356 return_type = get_callable_return_type (typehint )
357357 if ActionTypeHint .is_subclass_typehint (return_type ):
358358 return True
@@ -626,15 +626,18 @@ def instantiate_classes(self, value):
626626 return value if islist else value [0 ]
627627
628628 @staticmethod
629- def get_class_parser (val_class , sub_add_kwargs = None , skip_args = 0 ):
629+ def get_class_parser (val_class , sub_add_kwargs = None , skip_args = None ):
630630 if isinstance (val_class , str ):
631631 val_class = import_object (val_class )
632632 kwargs = dict (sub_add_kwargs ) if sub_add_kwargs else {}
633633 if skip_args :
634- kwargs .setdefault ("skip" , set ()).add (skip_args )
634+ kwargs .setdefault ("skip" , set ()).update (skip_args )
635635 if is_subclass_spec (kwargs .get ("default" )):
636636 kwargs ["default" ] = kwargs ["default" ].get ("init_args" )
637637 parser = parent_parser .get ()
638+ from ._core import ArgumentParser
639+
640+ assert isinstance (parser , ArgumentParser )
638641 parser = type (parser )(exit_on_error = False , logger = parser .logger , parser_mode = parser .parser_mode )
639642 remove_actions (parser , (ActionConfigFile , _ActionPrintConfig ))
640643 if inspect .isclass (val_class ) or inspect .isclass (get_typehint_origin (val_class )):
@@ -658,11 +661,10 @@ def get_class_parser(val_class, sub_add_kwargs=None, skip_args=0):
658661 def extra_help (self ):
659662 extra = ""
660663 typehint = get_optional_arg (self ._typehint )
661- if self .is_subclass_typehint (typehint , all_subtypes = False ) or get_typehint_origin (
662- typehint
663- ) in callable_origin_types .union ({Type , type }):
664- if self .is_callable_typehint (typehint ) and getattr (typehint , "__args__" , None ):
665- typehint = get_callable_return_type (get_optional_arg (typehint ))
664+ typehint = get_callable_return_type (typehint ) or typehint
665+ if get_typehint_origin (typehint ) is type :
666+ typehint = typehint .__args__ [0 ]
667+ if self .is_subclass_typehint (typehint , all_subtypes = False ):
666668 class_paths = get_all_subclass_paths (typehint )
667669 if class_paths :
668670 extra = ", known subclasses: " + ", " .join (class_paths )
@@ -967,11 +969,15 @@ def adapt_typehints(
967969 val = adapt_typehints (val , subtypehints [0 ], ** adapt_kwargs )
968970
969971 # Callable
970- elif typehint_origin in callable_origin_types or typehint in callable_origin_types :
972+ elif (
973+ typehint_origin in callable_origin_types
974+ or typehint in callable_origin_types
975+ or is_instance_factory_protocol (typehint , logger )
976+ ):
971977 if serialize :
972978 if is_subclass_spec (val ):
973- val , _ , num_partial_args = adapt_partial_callable_class (typehint , val )
974- val = adapt_class_type (val , True , False , sub_add_kwargs , skip_args = num_partial_args )
979+ val , partial_skip_args = adapt_partial_callable_class (typehint , val )
980+ val = adapt_class_type (val , True , False , sub_add_kwargs , partial_skip_args = partial_skip_args )
975981 else :
976982 val = object_path_serializer (val )
977983 else :
@@ -1000,21 +1006,21 @@ def adapt_typehints(
10001006 raise ImportError (
10011007 f"Dict must include a class_path and optionally init_args, but got { val_input } "
10021008 )
1003- val , partial_classes , num_partial_args = adapt_partial_callable_class (typehint , val )
1009+ val , partial_skip_args = adapt_partial_callable_class (typehint , val )
10041010 val_class = import_object (val ["class_path" ])
1005- if inspect .isclass (val_class ) and not (partial_classes or callable_instances (val_class )):
1011+ if inspect .isclass (val_class ) and not (partial_skip_args or callable_instances (val_class )):
1012+ base_type = get_callable_return_type (typehint ) or typehint
10061013 raise ImportError (
10071014 f"Expected '{ val ['class_path' ]} ' to be a class that instantiates into callable "
1008- f"or a subclass of { partial_classes } ."
1015+ f"or a subclass of { base_type } ."
10091016 )
10101017 val ["class_path" ] = get_import_path (val_class )
10111018 val = adapt_class_type (
10121019 val ,
10131020 False ,
10141021 instantiate_classes ,
10151022 sub_add_kwargs ,
1016- skip_args = num_partial_args ,
1017- partial_classes = partial_classes ,
1023+ partial_skip_args = partial_skip_args ,
10181024 prev_val = prev_val ,
10191025 )
10201026 except (ImportError , AttributeError , ArgumentError ) as ex :
@@ -1172,6 +1178,15 @@ def is_instance_or_supports_protocol(value, class_type):
11721178 return isinstance (value , class_type )
11731179
11741180
1181+ def is_instance_factory_protocol (class_type , logger = None ):
1182+ if not is_protocol (class_type ) or not callable_instances (class_type ):
1183+ return False
1184+ from ._postponed_annotations import get_return_type
1185+
1186+ return_type = get_return_type (class_type .__call__ , logger )
1187+ return ActionTypeHint .is_subclass_typehint (return_type )
1188+
1189+
11751190def is_subclass_spec (val ):
11761191 is_class = isinstance (val , (dict , Namespace )) and "class_path" in val
11771192 if is_class :
@@ -1214,9 +1229,14 @@ def subclass_spec_as_namespace(val, prev_val=None):
12141229
12151230def get_callable_return_type (typehint ):
12161231 return_type = None
1217- args = getattr (typehint , "__args__" , None )
1218- if isinstance (args , tuple ) and len (args ) > 0 :
1219- return_type = args [- 1 ]
1232+ if is_instance_factory_protocol (typehint ):
1233+ from ._postponed_annotations import get_return_type
1234+
1235+ return_type = get_return_type (typehint .__call__ )
1236+ elif get_typehint_origin (typehint ) in callable_origin_types :
1237+ args = getattr (typehint , "__args__" , None )
1238+ if isinstance (args , tuple ) and len (args ) > 0 :
1239+ return_type = args [- 1 ]
12201240 return return_type
12211241
12221242
@@ -1238,7 +1258,7 @@ def yield_subclass_types(typehint, also_lists=False, callable_return=False):
12381258 return
12391259 typehint = get_unaliased_type (get_optional_arg (get_unaliased_type (typehint )))
12401260 typehint_origin = get_typehint_origin (typehint )
1241- if callable_return and typehint_origin in callable_origin_types :
1261+ if callable_return and ( typehint_origin in callable_origin_types or is_instance_factory_protocol ( typehint )) :
12421262 return_type = get_callable_return_type (typehint )
12431263 if return_type :
12441264 k = {"also_lists" : also_lists , "callable_return" : callable_return }
@@ -1261,18 +1281,26 @@ def get_subclass_names(typehint, callable_return=False):
12611281
12621282
12631283def adapt_partial_callable_class (callable_type , subclass_spec ):
1264- partial_classes = False
1265- num_partial_args = 0
1284+ partial_skip_args = None
12661285 return_type = get_callable_return_type (callable_type )
12671286 if return_type :
12681287 subclass_types = get_subclass_types (return_type )
12691288 class_type = import_object (resolve_class_path_by_name (return_type , subclass_spec .class_path ))
12701289 if subclass_types and is_subclass (class_type , subclass_types ):
12711290 subclass_spec = subclass_spec .clone ()
12721291 subclass_spec ["class_path" ] = get_import_path (class_type )
1273- partial_classes = True
1274- num_partial_args = len (callable_type .__args__ ) - 1
1275- return subclass_spec , partial_classes , num_partial_args
1292+ if is_protocol (callable_type ):
1293+ from ._parameter_resolvers import get_signature_parameters
1294+
1295+ params = get_signature_parameters (callable_type , "__call__" )
1296+ partial_skip_args = set ()
1297+ positionals = [p for p in params if "POSITIONAL_ONLY" in str (p .kind )]
1298+ if positionals :
1299+ partial_skip_args .add (len (positionals ))
1300+ partial_skip_args .update (p .name for p in params if "POSITIONAL_ONLY" not in str (p .kind ))
1301+ else :
1302+ partial_skip_args = {len (callable_type .__args__ ) - 1 }
1303+ return subclass_spec , partial_skip_args
12761304
12771305
12781306def get_all_subclass_paths (cls : Type ) -> List [str ]:
@@ -1318,9 +1346,15 @@ def add_subclasses(cl):
13181346 return subclass_list
13191347
13201348
1321- def resolve_class_path_by_name (cls : Type , name : str ) -> str :
1349+ def resolve_class_path_by_name (cls : Union [ Type , Tuple [ Type ]] , name : str ) -> str :
13221350 class_path = name
13231351 if "." not in class_path :
1352+ if isinstance (cls , tuple ):
1353+ for cls_n in cls :
1354+ class_path = resolve_class_path_by_name (cls_n , name )
1355+ if "." in class_path :
1356+ break
1357+ return class_path
13241358 subclass_dict = defaultdict (list )
13251359 for subclass in get_all_subclass_paths (cls ):
13261360 subclass_name = subclass .rsplit ("." , 1 )[1 ]
@@ -1376,13 +1410,11 @@ def discard_init_args_on_class_path_change(parser_or_action, prev_val, value):
13761410 )
13771411
13781412
1379- def adapt_class_type (
1380- value , serialize , instantiate_classes , sub_add_kwargs , prev_val = None , skip_args = 0 , partial_classes = False
1381- ):
1413+ def adapt_class_type (value , serialize , instantiate_classes , sub_add_kwargs , prev_val = None , partial_skip_args = None ):
13821414 prev_val = subclass_spec_as_namespace (prev_val )
13831415 value = subclass_spec_as_namespace (value )
13841416 val_class = import_object (value .class_path )
1385- parser = ActionTypeHint .get_class_parser (val_class , sub_add_kwargs , skip_args = skip_args )
1417+ parser = ActionTypeHint .get_class_parser (val_class , sub_add_kwargs , skip_args = partial_skip_args )
13861418
13871419 # No need to re-create the linked arg but just "inform" the corresponding parser actions that it exists upstream.
13881420 for target in sub_add_kwargs .get ("linked_targets" , []):
@@ -1415,7 +1447,7 @@ def adapt_class_type(
14151447
14161448 instantiator_fn = get_class_instantiator ()
14171449
1418- if partial_classes :
1450+ if partial_skip_args :
14191451 return partial (
14201452 instantiator_fn ,
14211453 val_class ,
0 commit comments