11from argparse import ArgumentParser
2- from logging import getLogger
2+ from logging import Formatter , StreamHandler , getLogger
33from pathlib import Path
44from typing import TYPE_CHECKING , Callable , Dict , List , Literal , Optional , Tuple , Type , Union , get_args , get_origin
55
1616_extras = None
1717
# Module-level logger, named after this module.
_log = getLogger(__name__)
# Attach a dedicated stream handler to this logger so its records are emitted
# with a bracketed "[time][logger name][level]: message" prefix.
_handler = StreamHandler()
# datefmt renders %(asctime)s as date, "T", time, then the UTC offset (%z).
_formatter = Formatter("[%(asctime)s][%(name)s][%(levelname)s]: %(message)s", datefmt="%Y-%m-%dT%H:%M:%S%z")
_handler.setFormatter(_formatter)
_log.addHandler(_handler)
1923
2024
2125def parse_extra_args (subparser : Optional [ArgumentParser ] = None ) -> List [str ]:
@@ -27,18 +31,34 @@ def parse_extra_args(subparser: Optional[ArgumentParser] = None) -> List[str]:
2731
2832def _recurse_add_fields (parser : ArgumentParser , model : Union ["BaseModel" , Type ["BaseModel" ]], prefix : str = "" ):
2933 from pydantic import BaseModel
34+ from pydantic_core import PydanticUndefined
3035
36+ # Model is required
3137 if model is None :
3238 raise ValueError ("Model instance cannot be None" )
39+
40+ # Extract the fields from a model instance or class
3341 if isinstance (model , type ):
3442 model_fields = model .model_fields
3543 else :
3644 model_fields = model .__class__ .model_fields
45+
46+ # For each available field, add an argument to the parser
3747 for field_name , field in model_fields .items ():
48+ # Grab the annotation to map to type
3849 field_type = field .annotation
39- arg_name = f"--{ prefix } { field_name .replace ('_' , '-' )} "
50+ # Build the argument name converting underscores to dashes
51+ arg_name = f"--{ prefix .replace ('_' , '-' )} { field_name .replace ('_' , '-' )} "
52+
53+ # If there's an instance, use that so we have concrete values
54+ model_instance = model if not isinstance (model , type ) else None
4055
41- # Wrappers
56+ # If we have an instance, grab the field value
57+ field_instance = getattr (model_instance , field_name , None ) if model_instance else None
58+
59+ # MARK: Wrappers:
60+ # - Optional[T]
61+ # - Union[T, None]
4262 if get_origin (field_type ) is Optional :
4363 field_type = get_args (field_type )[0 ]
4464 elif get_origin (field_type ) is Union :
@@ -49,44 +69,126 @@ def _recurse_add_fields(parser: ArgumentParser, model: Union["BaseModel", Type["
4969 _log .warning (f"Unsupported Union type for argument '{ field_name } ': { field_type } " )
5070 continue
5171
72+ # Default value, promote PydanticUndefined to None
73+ if field .default is PydanticUndefined :
74+ default_value = None
75+ else :
76+ default_value = field .default
77+
5278 # Handled types
79+ # - bool, str, int, float
80+ # - Path
81+ # - Nested BaseModel
82+ # - Literal
83+ # - List[T]
84+ # - where T is bool, str, int, float
85+ # - List[BaseModel] where we have an instance to recurse into
86+ # - Dict[str, T]
87+ # - where T is bool, str, int, float
88+ # - Dict[str, BaseModel] where we have an instance to recurse into
5389 if field_type is bool :
54- parser .add_argument (arg_name , action = "store_true" , default = field .default )
90+ #############
91+ # MARK: bool
92+ parser .add_argument (arg_name , action = "store_true" , default = default_value )
5593 elif field_type in (str , int , float ):
94+ ########################
95+ # MARK: str, int, float
5696 try :
57- parser .add_argument (arg_name , type = field_type , default = field . default )
97+ parser .add_argument (arg_name , type = field_type , default = default_value )
5898 except TypeError :
5999 # TODO: handle more complex types if needed
60- parser .add_argument (arg_name , type = str , default = field . default )
100+ parser .add_argument (arg_name , type = str , default = default_value )
61101 elif isinstance (field_type , type ) and issubclass (field_type , Path ):
102+ #############
103+ # MARK: Path
62104 # Promote to/from string
63- parser .add_argument (arg_name , type = str , default = str (field .default ) if isinstance (field .default , Path ) else None )
64- elif isinstance (field_type , Type ) and issubclass (field_type , BaseModel ):
105+ parser .add_argument (arg_name , type = str , default = str (default_value ) if isinstance (default_value , Path ) else None )
106+ elif isinstance (field_instance , BaseModel ):
107+ ############################
108+ # MARK: instance(BaseModel)
65109 # Nested model, add its fields with a prefix
110+ _recurse_add_fields (parser , field_instance , prefix = f"{ field_name } ." )
111+ elif isinstance (field_type , Type ) and issubclass (field_type , BaseModel ):
112+ ########################
113+ # MARK: type(BaseModel)
114+ # Nested model class, add its fields with a prefix
66115 _recurse_add_fields (parser , field_type , prefix = f"{ field_name } ." )
67116 elif get_origin (field_type ) is Literal :
117+ ################
118+ # MARK: Literal
68119 literal_args = get_args (field_type )
69120 if not all (isinstance (arg , (str , int , float , bool )) for arg in literal_args ):
70- _log .warning (f"Only Literal types of str, int, float, or bool are supported - got { literal_args } " )
71- else :
72- parser .add_argument (arg_name , type = type (literal_args [0 ]), choices = literal_args , default = field .default )
121+ # Only support simple literal types for now
122+ _log .warning (f"Only Literal types of str, int, float, or bool are supported - field `{ field_name } ` got { literal_args } " )
123+ continue
124+ ####################################
125+ # MARK: Literal[str|int|float|bool]
126+ parser .add_argument (arg_name , type = type (literal_args [0 ]), choices = literal_args , default = default_value )
73127 elif get_origin (field_type ) in (list , List ):
74- # TODO: if list arg is complex type, warn as not implemented for now
128+ ################
129+ # MARK: List[T]
75130 if get_args (field_type ) and get_args (field_type )[0 ] not in (str , int , float , bool ):
76- _log .warning (f"Only lists of str, int, float, or bool are supported - got { get_args (field_type )[0 ]} " )
77- else :
78- parser .add_argument (arg_name , type = str , default = "," .join (map (str , field .default )) if isinstance (field , str ) else None )
131+ # If there's already something here, we can proceed by adding the command with a positional indicator
132+ if field_instance :
133+ ########################
134+ # MARK: List[BaseModel]
135+ for i , value in enumerate (field_instance ):
136+ _recurse_add_fields (parser , value , prefix = f"{ field_name } .{ i } ." )
137+ continue
138+ # If there's nothing here, we don't know how to address them
139+ # TODO: we could just prefill e.g. --field.0, --field.1 up to some limit
140+ _log .warning (f"Only lists of str, int, float, or bool are supported - field `{ field_name } ` got { get_args (field_type )[0 ]} " )
141+ continue
142+ #################################
143+ # MARK: List[str|int|float|bool]
144+ parser .add_argument (arg_name , type = str , default = "," .join (map (str , default_value )) if isinstance (field , str ) else None )
79145 elif get_origin (field_type ) in (dict , Dict ):
80- # TODO: if key args are complex type, warn as not implemented for now
146+ ######################
147+ # MARK: Dict[str, T]
81148 key_type , value_type = get_args (field_type )
82- if key_type not in (str , int , float , bool ):
83- _log .warning (f"Only dicts with str keys are supported - got key type { key_type } " )
84- if value_type not in (str , int , float , bool ):
85- _log .warning (f"Only dicts with str values are supported - got value type { value_type } " )
86- else :
87- parser .add_argument (
88- arg_name , type = str , default = "," .join (f"{ k } ={ v } " for k , v in field .default .items ()) if isinstance (field .default , dict ) else None
89- )
149+
150+ if key_type not in (str , int , float , bool ) and not (
151+ get_origin (key_type ) is Literal and all (isinstance (arg , (str , int , float , bool )) for arg in get_args (key_type ))
152+ ):
153+ # Check Key type, must be str, int, float, bool
154+ _log .warning (f"Only dicts with str keys are supported - field `{ field_name } ` got key type { key_type } " )
155+ continue
156+
157+ if value_type not in (str , int , float , bool ) and not field_instance :
158+ # Check Value type, must be str, int, float, bool if an instance isn't provided
159+ _log .warning (f"Only dicts with str values are supported - field `{ field_name } ` got value type { value_type } " )
160+ continue
161+
162+ # If there's already something here, we can proceed by adding the command by keyword
163+ if field_instance :
164+ if all (isinstance (v , BaseModel ) for v in field_instance .values ()):
165+ #############################
166+ # MARK: Dict[str, BaseModel]
167+ for key , value in field_instance .items ():
168+ _recurse_add_fields (parser , value , prefix = f"{ field_name } .{ key } ." )
169+ continue
170+ # If we have mixed, we don't support
171+ elif any (isinstance (v , BaseModel ) for v in field_instance .values ()):
172+ _log .warning (f"Mixed dict value types are not supported - field `{ field_name } ` has mixed BaseModel and non-BaseModel values" )
173+ continue
174+ # If we have non BaseModel values, we can still add a parser by route
175+ if all (isinstance (v , (str , int , float , bool )) for v in field_instance .values ()):
176+ # We can set "known" values here
177+ for key , value in field_instance .items ():
178+ ##########################################
179+ # MARK: Dict[str, str|int|float|bool]
180+ parser .add_argument (
181+ f"{ arg_name } .{ key } " ,
182+ type = type (value ),
183+ default = value ,
184+ )
185+ # NOTE: don't continue to allow adding the full setter below
186+ # Finally add the full setter for unknown values
187+ ##########################################
188+ # MARK: Dict[str, str|int|float|bool|str]
189+ parser .add_argument (
190+ arg_name , type = str , default = "," .join (f"{ k } ={ v } " for k , v in default_value .items ()) if isinstance (default_value , dict ) else None
191+ )
90192 else :
91193 _log .warning (f"Unsupported field type for argument '{ field_name } ': { field_type } " )
92194 return parser
@@ -107,20 +209,46 @@ def parse_extra_args_model(model: "BaseModel"):
107209 for key , value in args .items ():
108210 # Handle nested fields
109211 if "." in key :
212+ # We're going to walk down the model tree to get to the right sub-model
110213 parts = key .split ("." )
214+
215+ # Accounting
111216 sub_model = model
112- for part in parts [:- 1 ]:
113- model_to_set = getattr (sub_model , part )
217+ parent_model = None
218+
219+ for i , part in enumerate (parts [:- 1 ]):
220+ if part .isdigit () and isinstance (sub_model , list ):
221+ # List index
222+ index = int (part )
223+
224+ # Should never be out of bounds, but check to be sure
225+ if index >= len (sub_model ):
226+ raise IndexError (f"Index { index } out of range for field '{ parts [i - 1 ]} ' on model '{ parent_model .__class__ .__name__ } '" )
227+
228+ # Grab the model instance from the list
229+ model_to_set = sub_model [index ]
230+ elif isinstance (sub_model , dict ):
231+ # Dict key
232+ if part not in sub_model :
233+ raise KeyError (f"Key '{ part } ' not found for field '{ parts [i - 1 ]} ' on model '{ parent_model .__class__ .__name__ } '" )
234+
235+ # Grab the model instance from the dict
236+ model_to_set = sub_model [part ]
237+ else :
238+ model_to_set = getattr (sub_model , part )
239+
114240 if model_to_set is None :
115241 # Create a new instance of model
116242 field = sub_model .__class__ .model_fields [part ]
117- # if field annotation is an optional or union with none, extrat type
243+
244+ # if field annotation is an optional or union with none, extract type
118245 if get_origin (field .annotation ) is Optional :
119246 model_to_instance = get_args (field .annotation )[0 ]
120247 elif get_origin (field .annotation ) is Union :
121248 non_none_types = [t for t in get_args (field .annotation ) if t is not type (None )]
122249 if len (non_none_types ) == 1 :
123250 model_to_instance = non_none_types [0 ]
251+
124252 else :
125253 model_to_instance = field .annotation
126254 if not isinstance (model_to_instance , type ) or not issubclass (model_to_instance , BaseModel ):
@@ -129,35 +257,85 @@ def parse_extra_args_model(model: "BaseModel"):
129257 )
130258 model_to_set = model_to_instance ()
131259 setattr (sub_model , part , model_to_set )
260+
261+ parent_model = sub_model
262+ sub_model = model_to_set
263+
132264 key = parts [- 1 ]
133265 else :
266+ # Accounting
267+ sub_model = model
268+ parent_model = model
134269 model_to_set = model
135270
271+ if not isinstance (model_to_set , BaseModel ):
272+ if isinstance (model_to_set , dict ):
273+ # We allow setting dict values directly
274+ # Grab the dict from the parent model, set the value, and continue
275+ if key in model_to_set :
276+ model_to_set [key ] = value
277+ elif key .replace ("_" , "-" ) in model_to_set :
278+ # Argparse converts dashes back to underscores, so undo
279+ model_to_set [key .replace ("_" , "-" )] = value
280+ else :
281+ # Raise
282+ raise KeyError (f"Key '{ key } ' not found in dict field on model '{ parent_model .__class__ .__name__ } '" )
283+
284+ # Now adjust our variable accounting to set the whole dict back on the parent model,
285+ # allowing us to trigger any validation
286+ key = part
287+ value = model_to_set
288+ model_to_set = parent_model
289+ else :
290+ _log .warning (f"Cannot set field '{ key } ' on non-BaseModel instance of type '{ type (model_to_set ).__name__ } '" )
291+ continue
292+
136293 # Grab the field from the model class and make a type adapter
137294 field = model_to_set .__class__ .model_fields [key ]
138295 adapter = TypeAdapter (field .annotation )
139296
140297 # Convert the value using the type adapter
141298 if get_origin (field .annotation ) in (list , List ):
142299 value = value or ""
143- value = value .split ("," )
300+ if isinstance (value , list ):
301+ # Already a list, use as is
302+ pass
303+ elif isinstance (value , str ):
304+ # Convert from comma-separated values
305+ value = value .split ("," )
306+ else :
307+ # Unknown, raise
308+ raise ValueError (f"Cannot convert value '{ value } ' to list for field '{ key } '" )
144309 elif get_origin (field .annotation ) in (dict , Dict ):
145310 value = value or ""
146- dict_items = value .split ("," )
147- dict_value = {}
148- for item in dict_items :
149- if item :
150- k , v = item .split ("=" , 1 )
151- dict_value [k ] = v
152- value = dict_value
311+ if isinstance (value , dict ):
312+ # Already a dict, use as is
313+ pass
314+ elif isinstance (value , str ):
315+ # Convert from comma-separated key=value pairs
316+ dict_items = value .split ("," )
317+ dict_value = {}
318+ for item in dict_items :
319+ if item :
320+ k , v = item .split ("=" , 1 )
321+ dict_value [k ] = v
322+ # Grab any previously existing dict to preserve other keys
323+ existing_dict = getattr (model_to_set , key , {}) or {}
324+ dict_value .update (existing_dict )
325+ value = dict_value
326+ else :
327+ # Unknown, raise
328+ raise ValueError (f"Cannot convert value '{ value } ' to dict for field '{ key } '" )
153329 try :
154- value = adapter .validate_python (value )
330+ if value is not None :
331+ value = adapter .validate_python (value )
332+
333+ # Set the value on the model
334+ setattr (model_to_set , key , value )
155335 except ValidationError :
156336 _log .warning (f"Failed to validate field '{ key } ' with value '{ value } ' for model '{ model_to_set .__class__ .__name__ } '" )
157337 continue
158338
159- # Set the value on the model
160- setattr (model_to_set , key , value )
161339 return model , kwargs
162340
163341
0 commit comments