11from argparse import ArgumentParser
22from logging import getLogger
3- from typing import TYPE_CHECKING , Callable , Dict , List , Literal , Optional , Tuple , Type , get_args , get_origin
3+ from pathlib import Path
4+ from typing import TYPE_CHECKING , Callable , Dict , List , Literal , Optional , Tuple , Type , Union , get_args , get_origin
45
56from hatchling .cli .build import build_command
67
@@ -24,12 +25,31 @@ def parse_extra_args(subparser: Optional[ArgumentParser] = None) -> List[str]:
2425 return vars (kwargs ), extras
2526
2627
27- def _recurse_add_fields (parser : ArgumentParser , model : "BaseModel" , prefix : str = "" ):
28+ def _recurse_add_fields (parser : ArgumentParser , model : Union [ "BaseModel" , Type [ "BaseModel" ]] , prefix : str = "" ):
2829 from pydantic import BaseModel
2930
30- for field_name , field in model .__class__ .model_fields .items ():
31+ if model is None :
32+ raise ValueError ("Model instance cannot be None" )
33+ if isinstance (model , type ):
34+ model_fields = model .model_fields
35+ else :
36+ model_fields = model .__class__ .model_fields
37+ for field_name , field in model_fields .items ():
3138 field_type = field .annotation
3239 arg_name = f"--{ prefix } { field_name .replace ('_' , '-' )} "
40+
41+ # Wrappers
42+ if get_origin (field_type ) is Optional :
43+ field_type = get_args (field_type )[0 ]
44+ elif get_origin (field_type ) is Union :
45+ non_none_types = [t for t in get_args (field_type ) if t is not type (None )]
46+ if len (non_none_types ) == 1 :
47+ field_type = non_none_types [0 ]
48+ else :
49+ _log .warning (f"Unsupported Union type for argument '{ field_name } ': { field_type } " )
50+ continue
51+
52+ # Handled types
3353 if field_type is bool :
3454 parser .add_argument (arg_name , action = "store_true" , default = field .default )
3555 elif field_type in (str , int , float ):
@@ -38,9 +58,12 @@ def _recurse_add_fields(parser: ArgumentParser, model: "BaseModel", prefix: str
3858 except TypeError :
3959 # TODO: handle more complex types if needed
4060 parser .add_argument (arg_name , type = str , default = field .default )
61+ elif isinstance (field_type , type ) and issubclass (field_type , Path ):
62+ # Promote to/from string
63+ parser .add_argument (arg_name , type = str , default = str (field .default ) if isinstance (field .default , Path ) else None )
4164 elif isinstance (field_type , Type ) and issubclass (field_type , BaseModel ):
4265 # Nested model, add its fields with a prefix
43- _recurse_add_fields (parser , getattr ( model , field_name ) , prefix = f"{ field_name } ." )
66+ _recurse_add_fields (parser , field_type , prefix = f"{ field_name } ." )
4467 elif get_origin (field_type ) is Literal :
4568 literal_args = get_args (field_type )
4669 if not all (isinstance (arg , (str , int , float , bool )) for arg in literal_args ):
@@ -65,13 +88,13 @@ def _recurse_add_fields(parser: ArgumentParser, model: "BaseModel", prefix: str
6588 arg_name , type = str , default = "," .join (f"{ k } ={ v } " for k , v in field .default .items ()) if isinstance (field .default , dict ) else None
6689 )
6790 else :
68- _log .warning (f"Unsupported field type for argument '{ arg_name } ': { field_type } " )
91+ _log .warning (f"Unsupported field type for argument '{ field_name } ': { field_type } " )
6992 return parser
7093
7194
7295def parse_extra_args_model (model : "BaseModel" ):
7396 try :
74- from pydantic import TypeAdapter
97+ from pydantic import BaseModel , TypeAdapter
7598 except ImportError :
7699 raise ImportError ("pydantic is required to use parse_extra_args_model" )
77100 # Recursively parse fields from a pydantic model and its sub-models
@@ -88,6 +111,24 @@ def parse_extra_args_model(model: "BaseModel"):
88111 sub_model = model
89112 for part in parts [:- 1 ]:
90113 model_to_set = getattr (sub_model , part )
114+ if model_to_set is None :
115+ # Create a new instance of model
116+ field = sub_model .__class__ .model_fields [part ]
117+ # if field annotation is an optional or union with none, extrat type
118+ if get_origin (field .annotation ) is Optional :
119+ model_to_instance = get_args (field .annotation )[0 ]
120+ elif get_origin (field .annotation ) is Union :
121+ non_none_types = [t for t in get_args (field .annotation ) if t is not type (None )]
122+ if len (non_none_types ) == 1 :
123+ model_to_instance = non_none_types [0 ]
124+ else :
125+ model_to_instance = field .annotation
126+ if not isinstance (model_to_instance , type ) or not issubclass (model_to_instance , BaseModel ):
127+ raise ValueError (
128+ f"Cannot create sub-model for field '{ part } ' on model '{ sub_model .__class__ .__name__ } ': - type is { model_to_instance } "
129+ )
130+ model_to_set = model_to_instance ()
131+ setattr (sub_model , part , model_to_set )
91132 key = parts [- 1 ]
92133 else :
93134 model_to_set = model
0 commit comments