11from argparse import ArgumentParser
2- from typing import Callable , List , Optional , Tuple
2+ from typing import TYPE_CHECKING , Callable , Dict , List , Optional , Tuple , Type , get_args , get_origin
33
44from hatchling .cli .build import build_command
55
6+ if TYPE_CHECKING :
7+ from pydantic import BaseModel
8+
69__all__ = (
710 "hatchling" ,
811 "parse_extra_args" ,
@@ -17,6 +20,84 @@ def parse_extra_args(subparser: Optional[ArgumentParser] = None) -> List[str]:
1720 return vars (kwargs ), extras
1821
1922
23+ def recurse_add_fields (parser : ArgumentParser , model : "BaseModel" , prefix : str = "" ):
24+ from pydantic import BaseModel
25+
26+ for field_name , field in model .__class__ .model_fields .items ():
27+ field_type = field .annotation
28+ arg_name = f"--{ prefix } { field_name .replace ('_' , '-' )} "
29+ if field_type is bool :
30+ parser .add_argument (arg_name , action = "store_true" , default = field .default )
31+ elif isinstance (field_type , Type ) and issubclass (field_type , BaseModel ):
32+ # Nested model, add its fields with a prefix
33+ recurse_add_fields (parser , getattr (model , field_name ), prefix = f"{ field_name } ." )
34+ elif get_origin (field_type ) in (list , List ):
35+ # TODO: if list arg is complex type, raise as not implemented for now
36+ if get_args (field_type ) and get_args (field_type )[0 ] not in (str , int , float , bool ):
37+ raise NotImplementedError ("Only lists of str, int, float, or bool are supported" )
38+ parser .add_argument (arg_name , type = str , default = "," .join (map (str , field .default )))
39+ elif get_origin (field_type ) in (dict , Dict ):
40+ # TODO: if key args are complex type, raise as not implemented for now
41+ key_type , value_type = get_args (field_type )
42+ if key_type not in (str , int , float , bool ):
43+ raise NotImplementedError ("Only dicts with str keys are supported" )
44+ if value_type not in (str , int , float , bool ):
45+ raise NotImplementedError ("Only dicts with str values are supported" )
46+ parser .add_argument (arg_name , type = str , default = "," .join (f"{ k } ={ v } " for k , v in field .default .items ()))
47+ else :
48+ try :
49+ parser .add_argument (arg_name , type = field_type , default = field .default )
50+ except TypeError :
51+ # TODO: handle more complex types if needed
52+ parser .add_argument (arg_name , type = str , default = field .default )
53+ return parser
54+
55+
56+ def parse_extra_args_model (model : "BaseModel" ):
57+ try :
58+ from pydantic import TypeAdapter
59+ except ImportError :
60+ raise ImportError ("pydantic is required to use parse_extra_args_model" )
61+ # Recursively parse fields from a pydantic model and its sub-models
62+ # and create an argument parser to parse extra args
63+ parser = ArgumentParser (prog = "hatch-build-extras-model" , allow_abbrev = False )
64+ parser = recurse_add_fields (parser , model )
65+
66+ # Parse the extra args and update the model
67+ args , kwargs = parse_extra_args (parser )
68+ for key , value in args .items ():
69+ # Handle nested fields
70+ if "." in key :
71+ parts = key .split ("." )
72+ sub_model = model
73+ for part in parts [:- 1 ]:
74+ model_to_set = getattr (sub_model , part )
75+ key = parts [- 1 ]
76+ else :
77+ model_to_set = model
78+
79+ # Grab the field from the model class and make a type adapter
80+ field = model_to_set .__class__ .model_fields [key ]
81+ adapter = TypeAdapter (field .annotation )
82+
83+ # Convert the value using the type adapter
84+ if get_origin (field .annotation ) in (list , List ):
85+ value = adapter .validate_python (value .split ("," ))
86+ elif get_origin (field .annotation ) in (dict , Dict ):
87+ dict_items = value .split ("," )
88+ dict_value = {}
89+ for item in dict_items :
90+ k , v = item .split ("=" , 1 )
91+ dict_value [k ] = v
92+ value = adapter .validate_python (dict_value )
93+ else :
94+ value = adapter .validate_python (value )
95+
96+ # Set the value on the model
97+ setattr (model_to_set , key , value )
98+ return model , kwargs
99+
100+
20101def _hatchling_internal () -> Tuple [Optional [Callable ], Optional [dict ], List [str ]]:
21102 parser = ArgumentParser (prog = "hatch-build" , allow_abbrev = False )
22103 subparsers = parser .add_subparsers ()
0 commit comments