@@ -44,16 +44,18 @@ def parse_extra_args(subparser: Optional[ArgumentParser] = None) -> List[str]:
4444
4545
4646def _is_supported_type (field_type : type ) -> bool :
47- if not isinstance (field_type , type ):
48- return False
4947 if get_origin (field_type ) is Optional :
5048 field_type = get_args (field_type )[0 ]
5149 elif get_origin (field_type ) is Union :
5250 non_none_types = [t for t in get_args (field_type ) if t is not type (None )]
51+ if all (_is_supported_type (t ) for t in non_none_types ):
52+ return True
5353 if len (non_none_types ) == 1 :
5454 field_type = non_none_types [0 ]
5555 elif get_origin (field_type ) is Literal :
5656 return all (isinstance (arg , (str , int , float , bool , Enum )) for arg in get_args (field_type ))
57+ if not isinstance (field_type , type ):
58+ return False
5759 return field_type in (str , int , float , bool ) or issubclass (field_type , Enum )
5860
5961
@@ -100,6 +102,8 @@ def _recurse_add_fields(parser: ArgumentParser, model: Union["BaseModel", Type["
100102 # Default value, promote PydanticUndefined to None
101103 if field .default is PydanticUndefined :
102104 default_value = None
105+ elif field_instance :
106+ default_value = field_instance
103107 else :
104108 default_value = field .default
105109
@@ -167,15 +171,36 @@ def _recurse_add_fields(parser: ArgumentParser, model: Union["BaseModel", Type["
167171 if get_args (field_type ) and not _is_supported_type (get_args (field_type )[0 ]):
168172 # If theres already something here, we can procede by adding the command with a positional indicator
169173 if field_instance :
170- ########################
171- # MARK: List[BaseModel]
172174 for i , value in enumerate (field_instance ):
173- _recurse_add_fields (parser , value , prefix = f"{ field_name } .{ i } ." )
175+ if isinstance (value , BaseModel ):
176+ ########################
177+ # MARK: List[BaseModel]
178+ _recurse_add_fields (parser , value , prefix = f"{ field_name } .{ i } ." )
179+ continue
180+ else :
181+ ########################
182+ # MARK: List[str|int|float|bool]
183+ _add_argument (
184+ parser = parser ,
185+ name = f"{ arg_name } .{ i } " ,
186+ arg_type = type (value ),
187+ default_value = value ,
188+ )
174189 continue
175190 # If there's nothing here, we don't know how to address them
176191 # TODO: we could just prefill e.g. --field.0, --field.1 up to some limit
177192 _log .warning (f"Only lists of str, int, float, or bool are supported - field `{ field_name } ` got { get_args (field_type )[0 ]} " )
178193 continue
194+ if field_instance :
195+ for i , value in enumerate (field_instance ):
196+ ########################
197+ # MARK: List[str|int|float|bool]
198+ _add_argument (
199+ parser = parser ,
200+ name = f"{ arg_name } .{ i } " ,
201+ arg_type = type (value ),
202+ default_value = value ,
203+ )
179204 #################################
180205 # MARK: List[str|int|float|bool]
181206 _add_argument (
@@ -414,6 +439,21 @@ def parse_extra_args_model(model: "BaseModel"):
414439
415440 _log .debug (f"Set dict key '{ key } ' on parent model '{ parent_model .__class__ .__name__ } ' with value '{ value } '" )
416441
442+ # Now adjust our variable accounting to set the whole dict back on the parent model,
443+ # allowing us to trigger any validation
444+ key = part
445+ value = model_to_set
446+ model_to_set = parent_model
447+ elif isinstance (model_to_set , list ):
448+ if value is None :
449+ continue
450+
451+ # We allow setting list values directly
452+ # Grab the list from the parent model, set the value, and continue
453+ model_to_set [int (key )] = value
454+
455+ _log .debug (f"Set list index '{ key } ' on parent model '{ parent_model .__class__ .__name__ } ' with value '{ value } '" )
456+
417457 # Now adjust our variable accounting to set the whole dict back on the parent model,
418458 # allowing us to trigger any validation
419459 key = part
@@ -427,46 +467,44 @@ def parse_extra_args_model(model: "BaseModel"):
427467 field = model_to_set .__class__ .model_fields [key ]
428468 adapter = TypeAdapter (field .annotation )
429469
430- _log .debug (f"Setting field '{ key } ' on model '{ model_to_set .__class__ .__name__ } ' with raw value '{ value } '" )
431-
432- # Convert the value using the type adapter
433- if get_origin (field .annotation ) in (list , List ):
434- value = value or ""
435- if isinstance (value , list ):
436- # Already a list, use as is
437- pass
438- elif isinstance (value , str ):
439- # Convert from comma-separated values
440- value = value .split ("," )
441- else :
442- # Unknown, raise
443- raise ValueError (f"Cannot convert value '{ value } ' to list for field '{ key } '" )
444- elif get_origin (field .annotation ) in (dict , Dict ):
445- value = value or ""
446- if isinstance (value , dict ):
447- # Already a dict, use as is
448- pass
449- elif isinstance (value , str ):
450- # Convert from comma-separated key=value pairs
451- dict_items = value .split ("," )
452- dict_value = {}
453- for item in dict_items :
454- if item :
455- k , v = item .split ("=" , 1 )
456- # If the key type is an enum, convert
457- dict_value [k ] = v
458-
459- # Grab any previously existing dict to preserve other keys
460- existing_dict = getattr (model_to_set , key , {}) or {}
461- _log .debug (f"Existing dict for field '{ key } ': { existing_dict } " )
462- _log .debug (f"New dict items for field '{ key } ': { dict_value } " )
463- dict_value .update (existing_dict )
464- value = dict_value
465- else :
466- # Unknown, raise
467- raise ValueError (f"Cannot convert value '{ value } ' to dict for field '{ key } '" )
468- try :
469- if value is not None :
470+ if value is not None :
471+ _log .debug (f"Setting field '{ key } ' on model '{ model_to_set .__class__ .__name__ } ' with raw value '{ value } '" )
472+
473+ # Convert the value using the type adapter
474+ if get_origin (field .annotation ) in (list , List ):
475+ if isinstance (value , list ):
476+ # Already a list, use as is
477+ pass
478+ elif isinstance (value , str ):
479+ # Convert from comma-separated values
480+ value = value .split ("," )
481+ else :
482+ # Unknown, raise
483+ raise ValueError (f"Cannot convert value '{ value } ' to list for field '{ key } '" )
484+ elif get_origin (field .annotation ) in (dict , Dict ):
485+ if isinstance (value , dict ):
486+ # Already a dict, use as is
487+ pass
488+ elif isinstance (value , str ):
489+ # Convert from comma-separated key=value pairs
490+ dict_items = value .split ("," )
491+ dict_value = {}
492+ for item in dict_items :
493+ if item :
494+ k , v = item .split ("=" , 1 )
495+ # If the key type is an enum, convert
496+ dict_value [k ] = v
497+
498+ # Grab any previously existing dict to preserve other keys
499+ existing_dict = getattr (model_to_set , key , {}) or {}
500+ _log .debug (f"Existing dict for field '{ key } ': { existing_dict } " )
501+ _log .debug (f"New dict items for field '{ key } ': { dict_value } " )
502+ dict_value .update (existing_dict )
503+ value = dict_value
504+ else :
505+ # Unknown, raise
506+ raise ValueError (f"Cannot convert value '{ value } ' to dict for field '{ key } '" )
507+ try :
470508 # Post process and convert keys if needed
471509 # pydantic shouldve done this automatically, but alas
472510 if isinstance (value , dict ) and get_args (field .annotation ):
@@ -482,10 +520,11 @@ def parse_extra_args_model(model: "BaseModel"):
482520
483521 # Set the value on the model
484522 setattr (model_to_set , key , value )
485-
486- except ValidationError :
487- _log .warning (f"Failed to validate field '{ key } ' with value '{ value } ' for model '{ model_to_set .__class__ .__name__ } '" )
488- continue
523+ except ValidationError :
524+ _log .warning (f"Failed to validate field '{ key } ' with value '{ value } ' for model '{ model_to_set .__class__ .__name__ } '" )
525+ continue
526+ else :
527+ _log .debug (f"Skipping setting field '{ key } ' on model '{ model_to_set .__class__ .__name__ } ' with None value" )
489528
490529 return model , kwargs
491530
0 commit comments