Skip to content

Commit c3dbf72

Browse files
authored
Merge pull request #14 from python-project-templates/tkp/hf
Handle list of models, dict with model values, and existing dict
2 parents 953cd57 + 7a14cff commit c3dbf72

File tree

2 files changed

+232
-39
lines changed

2 files changed

+232
-39
lines changed

hatch_build/cli.py

Lines changed: 216 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from argparse import ArgumentParser
2-
from logging import getLogger
2+
from logging import Formatter, StreamHandler, getLogger
33
from pathlib import Path
44
from typing import TYPE_CHECKING, Callable, Dict, List, Literal, Optional, Tuple, Type, Union, get_args, get_origin
55

@@ -16,6 +16,10 @@
1616
_extras = None
1717

1818
_log = getLogger(__name__)
19+
_handler = StreamHandler()
20+
_formatter = Formatter("[%(asctime)s][%(name)s][%(levelname)s]: %(message)s", datefmt="%Y-%m-%dT%H:%M:%S%z")
21+
_handler.setFormatter(_formatter)
22+
_log.addHandler(_handler)
1923

2024

2125
def parse_extra_args(subparser: Optional[ArgumentParser] = None) -> List[str]:
@@ -27,18 +31,34 @@ def parse_extra_args(subparser: Optional[ArgumentParser] = None) -> List[str]:
2731

2832
def _recurse_add_fields(parser: ArgumentParser, model: Union["BaseModel", Type["BaseModel"]], prefix: str = ""):
2933
from pydantic import BaseModel
34+
from pydantic_core import PydanticUndefined
3035

36+
# Model is required
3137
if model is None:
3238
raise ValueError("Model instance cannot be None")
39+
40+
# Extract the fields from a model instance or class
3341
if isinstance(model, type):
3442
model_fields = model.model_fields
3543
else:
3644
model_fields = model.__class__.model_fields
45+
46+
# For each available field, add an argument to the parser
3747
for field_name, field in model_fields.items():
48+
# Grab the annotation to map to type
3849
field_type = field.annotation
39-
arg_name = f"--{prefix}{field_name.replace('_', '-')}"
50+
# Build the argument name converting underscores to dashes
51+
arg_name = f"--{prefix.replace('_', '-')}{field_name.replace('_', '-')}"
52+
53+
# If theres an instance, use that so we have concrete values
54+
model_instance = model if not isinstance(model, type) else None
4055

41-
# Wrappers
56+
# If we have an instance, grab the field value
57+
field_instance = getattr(model_instance, field_name, None) if model_instance else None
58+
59+
# MARK: Wrappers:
60+
# - Optional[T]
61+
# - Union[T, None]
4262
if get_origin(field_type) is Optional:
4363
field_type = get_args(field_type)[0]
4464
elif get_origin(field_type) is Union:
@@ -49,44 +69,126 @@ def _recurse_add_fields(parser: ArgumentParser, model: Union["BaseModel", Type["
4969
_log.warning(f"Unsupported Union type for argument '{field_name}': {field_type}")
5070
continue
5171

72+
# Default value, promote PydanticUndefined to None
73+
if field.default is PydanticUndefined:
74+
default_value = None
75+
else:
76+
default_value = field.default
77+
5278
# Handled types
79+
# - bool, str, int, float
80+
# - Path
81+
# - Nested BaseModel
82+
# - Literal
83+
# - List[T]
84+
# - where T is bool, str, int, float
85+
# - List[BaseModel] where we have an instance to recurse into
86+
# - Dict[str, T]
87+
# - where T is bool, str, int, float
88+
# - Dict[str, BaseModel] where we have an instance to recurse into
5389
if field_type is bool:
54-
parser.add_argument(arg_name, action="store_true", default=field.default)
90+
#############
91+
# MARK: bool
92+
parser.add_argument(arg_name, action="store_true", default=default_value)
5593
elif field_type in (str, int, float):
94+
########################
95+
# MARK: str, int, float
5696
try:
57-
parser.add_argument(arg_name, type=field_type, default=field.default)
97+
parser.add_argument(arg_name, type=field_type, default=default_value)
5898
except TypeError:
5999
# TODO: handle more complex types if needed
60-
parser.add_argument(arg_name, type=str, default=field.default)
100+
parser.add_argument(arg_name, type=str, default=default_value)
61101
elif isinstance(field_type, type) and issubclass(field_type, Path):
102+
#############
103+
# MARK: Path
62104
# Promote to/from string
63-
parser.add_argument(arg_name, type=str, default=str(field.default) if isinstance(field.default, Path) else None)
64-
elif isinstance(field_type, Type) and issubclass(field_type, BaseModel):
105+
parser.add_argument(arg_name, type=str, default=str(default_value) if isinstance(default_value, Path) else None)
106+
elif isinstance(field_instance, BaseModel):
107+
############################
108+
# MARK: instance(BaseModel)
65109
# Nested model, add its fields with a prefix
110+
_recurse_add_fields(parser, field_instance, prefix=f"{field_name}.")
111+
elif isinstance(field_type, Type) and issubclass(field_type, BaseModel):
112+
########################
113+
# MARK: type(BaseModel)
114+
# Nested model class, add its fields with a prefix
66115
_recurse_add_fields(parser, field_type, prefix=f"{field_name}.")
67116
elif get_origin(field_type) is Literal:
117+
################
118+
# MARK: Literal
68119
literal_args = get_args(field_type)
69120
if not all(isinstance(arg, (str, int, float, bool)) for arg in literal_args):
70-
_log.warning(f"Only Literal types of str, int, float, or bool are supported - got {literal_args}")
71-
else:
72-
parser.add_argument(arg_name, type=type(literal_args[0]), choices=literal_args, default=field.default)
121+
# Only support simple literal types for now
122+
_log.warning(f"Only Literal types of str, int, float, or bool are supported - field `{field_name}` got {literal_args}")
123+
continue
124+
####################################
125+
# MARK: Literal[str|int|float|bool]
126+
parser.add_argument(arg_name, type=type(literal_args[0]), choices=literal_args, default=default_value)
73127
elif get_origin(field_type) in (list, List):
74-
# TODO: if list arg is complex type, warn as not implemented for now
128+
################
129+
# MARK: List[T]
75130
if get_args(field_type) and get_args(field_type)[0] not in (str, int, float, bool):
76-
_log.warning(f"Only lists of str, int, float, or bool are supported - got {get_args(field_type)[0]}")
77-
else:
78-
parser.add_argument(arg_name, type=str, default=",".join(map(str, field.default)) if isinstance(field, str) else None)
131+
# If theres already something here, we can procede by adding the command with a positional indicator
132+
if field_instance:
133+
########################
134+
# MARK: List[BaseModel]
135+
for i, value in enumerate(field_instance):
136+
_recurse_add_fields(parser, value, prefix=f"{field_name}.{i}.")
137+
continue
138+
# If there's nothing here, we don't know how to address them
139+
# TODO: we could just prefill e.g. --field.0, --field.1 up to some limit
140+
_log.warning(f"Only lists of str, int, float, or bool are supported - field `{field_name}` got {get_args(field_type)[0]}")
141+
continue
142+
#################################
143+
# MARK: List[str|int|float|bool]
144+
parser.add_argument(arg_name, type=str, default=",".join(map(str, default_value)) if isinstance(field, str) else None)
79145
elif get_origin(field_type) in (dict, Dict):
80-
# TODO: if key args are complex type, warn as not implemented for now
146+
######################
147+
# MARK: Dict[str, T]
81148
key_type, value_type = get_args(field_type)
82-
if key_type not in (str, int, float, bool):
83-
_log.warning(f"Only dicts with str keys are supported - got key type {key_type}")
84-
if value_type not in (str, int, float, bool):
85-
_log.warning(f"Only dicts with str values are supported - got value type {value_type}")
86-
else:
87-
parser.add_argument(
88-
arg_name, type=str, default=",".join(f"{k}={v}" for k, v in field.default.items()) if isinstance(field.default, dict) else None
89-
)
149+
150+
if key_type not in (str, int, float, bool) and not (
151+
get_origin(key_type) is Literal and all(isinstance(arg, (str, int, float, bool)) for arg in get_args(key_type))
152+
):
153+
# Check Key type, must be str, int, float, bool
154+
_log.warning(f"Only dicts with str keys are supported - field `{field_name}` got key type {key_type}")
155+
continue
156+
157+
if value_type not in (str, int, float, bool) and not field_instance:
158+
# Check Value type, must be str, int, float, bool if an instance isnt provided
159+
_log.warning(f"Only dicts with str values are supported - field `{field_name}` got value type {value_type}")
160+
continue
161+
162+
# If theres already something here, we can procede by adding the command by keyword
163+
if field_instance:
164+
if all(isinstance(v, BaseModel) for v in field_instance.values()):
165+
#############################
166+
# MARK: Dict[str, BaseModel]
167+
for key, value in field_instance.items():
168+
_recurse_add_fields(parser, value, prefix=f"{field_name}.{key}.")
169+
continue
170+
# If we have mixed, we don't support
171+
elif any(isinstance(v, BaseModel) for v in field_instance.values()):
172+
_log.warning(f"Mixed dict value types are not supported - field `{field_name}` has mixed BaseModel and non-BaseModel values")
173+
continue
174+
# If we have non BaseModel values, we can still add a parser by route
175+
if all(isinstance(v, (str, int, float, bool)) for v in field_instance.values()):
176+
# We can set "known" values here
177+
for key, value in field_instance.items():
178+
##########################################
179+
# MARK: Dict[str, str|int|float|bool]
180+
parser.add_argument(
181+
f"{arg_name}.{key}",
182+
type=type(value),
183+
default=value,
184+
)
185+
# NOTE: don't continue to allow adding the full setter below
186+
# Finally add the full setter for unknown values
187+
##########################################
188+
# MARK: Dict[str, str|int|float|bool|str]
189+
parser.add_argument(
190+
arg_name, type=str, default=",".join(f"{k}={v}" for k, v in default_value.items()) if isinstance(default_value, dict) else None
191+
)
90192
else:
91193
_log.warning(f"Unsupported field type for argument '{field_name}': {field_type}")
92194
return parser
@@ -107,20 +209,46 @@ def parse_extra_args_model(model: "BaseModel"):
107209
for key, value in args.items():
108210
# Handle nested fields
109211
if "." in key:
212+
# We're going to walk down the model tree to get to the right sub-model
110213
parts = key.split(".")
214+
215+
# Accounting
111216
sub_model = model
112-
for part in parts[:-1]:
113-
model_to_set = getattr(sub_model, part)
217+
parent_model = None
218+
219+
for i, part in enumerate(parts[:-1]):
220+
if part.isdigit() and isinstance(sub_model, list):
221+
# List index
222+
index = int(part)
223+
224+
# Should never be out of bounds, but check to be sure
225+
if index >= len(sub_model):
226+
raise IndexError(f"Index {index} out of range for field '{parts[i - 1]}' on model '{parent_model.__class__.__name__}'")
227+
228+
# Grab the model instance from the list
229+
model_to_set = sub_model[index]
230+
elif isinstance(sub_model, dict):
231+
# Dict key
232+
if part not in sub_model:
233+
raise KeyError(f"Key '{part}' not found for field '{parts[i - 1]}' on model '{parent_model.__class__.__name__}'")
234+
235+
# Grab the model instance from the dict
236+
model_to_set = sub_model[part]
237+
else:
238+
model_to_set = getattr(sub_model, part)
239+
114240
if model_to_set is None:
115241
# Create a new instance of model
116242
field = sub_model.__class__.model_fields[part]
117-
# if field annotation is an optional or union with none, extrat type
243+
244+
# if field annotation is an optional or union with none, extract type
118245
if get_origin(field.annotation) is Optional:
119246
model_to_instance = get_args(field.annotation)[0]
120247
elif get_origin(field.annotation) is Union:
121248
non_none_types = [t for t in get_args(field.annotation) if t is not type(None)]
122249
if len(non_none_types) == 1:
123250
model_to_instance = non_none_types[0]
251+
124252
else:
125253
model_to_instance = field.annotation
126254
if not isinstance(model_to_instance, type) or not issubclass(model_to_instance, BaseModel):
@@ -129,35 +257,85 @@ def parse_extra_args_model(model: "BaseModel"):
129257
)
130258
model_to_set = model_to_instance()
131259
setattr(sub_model, part, model_to_set)
260+
261+
parent_model = sub_model
262+
sub_model = model_to_set
263+
132264
key = parts[-1]
133265
else:
266+
# Accounting
267+
sub_model = model
268+
parent_model = model
134269
model_to_set = model
135270

271+
if not isinstance(model_to_set, BaseModel):
272+
if isinstance(model_to_set, dict):
273+
# We allow setting dict values directly
274+
# Grab the dict from the parent model, set the value, and continue
275+
if key in model_to_set:
276+
model_to_set[key] = value
277+
elif key.replace("_", "-") in model_to_set:
278+
# Argparse converts dashes back to underscores, so undo
279+
model_to_set[key.replace("_", "-")] = value
280+
else:
281+
# Raise
282+
raise KeyError(f"Key '{key}' not found in dict field on model '{parent_model.__class__.__name__}'")
283+
284+
# Now adjust our variable accounting to set the whole dict back on the parent model,
285+
# allowing us to trigger any validation
286+
key = part
287+
value = model_to_set
288+
model_to_set = parent_model
289+
else:
290+
_log.warning(f"Cannot set field '{key}' on non-BaseModel instance of type '{type(model_to_set).__name__}'")
291+
continue
292+
136293
# Grab the field from the model class and make a type adapter
137294
field = model_to_set.__class__.model_fields[key]
138295
adapter = TypeAdapter(field.annotation)
139296

140297
# Convert the value using the type adapter
141298
if get_origin(field.annotation) in (list, List):
142299
value = value or ""
143-
value = value.split(",")
300+
if isinstance(value, list):
301+
# Already a list, use as is
302+
pass
303+
elif isinstance(value, str):
304+
# Convert from comma-separated values
305+
value = value.split(",")
306+
else:
307+
# Unknown, raise
308+
raise ValueError(f"Cannot convert value '{value}' to list for field '{key}'")
144309
elif get_origin(field.annotation) in (dict, Dict):
145310
value = value or ""
146-
dict_items = value.split(",")
147-
dict_value = {}
148-
for item in dict_items:
149-
if item:
150-
k, v = item.split("=", 1)
151-
dict_value[k] = v
152-
value = dict_value
311+
if isinstance(value, dict):
312+
# Already a dict, use as is
313+
pass
314+
elif isinstance(value, str):
315+
# Convert from comma-separated key=value pairs
316+
dict_items = value.split(",")
317+
dict_value = {}
318+
for item in dict_items:
319+
if item:
320+
k, v = item.split("=", 1)
321+
dict_value[k] = v
322+
# Grab any previously existing dict to preserve other keys
323+
existing_dict = getattr(model_to_set, key, {}) or {}
324+
dict_value.update(existing_dict)
325+
value = dict_value
326+
else:
327+
# Unknown, raise
328+
raise ValueError(f"Cannot convert value '{value}' to dict for field '{key}'")
153329
try:
154-
value = adapter.validate_python(value)
330+
if value is not None:
331+
value = adapter.validate_python(value)
332+
333+
# Set the value on the model
334+
setattr(model_to_set, key, value)
155335
except ValidationError:
156336
_log.warning(f"Failed to validate field '{key}' with value '{value}' for model '{model_to_set.__class__.__name__}'")
157337
continue
158338

159-
# Set the value on the model
160-
setattr(model_to_set, key, value)
161339
return model, kwargs
162340

163341

0 commit comments

Comments
 (0)