Skip to content

Commit 4721f31

Browse files
committed
Add fix for default args vs predefined args
1 parent 8c4da7d commit 4721f31

File tree

2 files changed

+127
-54
lines changed

2 files changed

+127
-54
lines changed

hatch_build/cli.py

Lines changed: 88 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -44,16 +44,18 @@ def parse_extra_args(subparser: Optional[ArgumentParser] = None) -> List[str]:
4444

4545

4646
def _is_supported_type(field_type: type) -> bool:
47-
if not isinstance(field_type, type):
48-
return False
4947
if get_origin(field_type) is Optional:
5048
field_type = get_args(field_type)[0]
5149
elif get_origin(field_type) is Union:
5250
non_none_types = [t for t in get_args(field_type) if t is not type(None)]
51+
if all(_is_supported_type(t) for t in non_none_types):
52+
return True
5353
if len(non_none_types) == 1:
5454
field_type = non_none_types[0]
5555
elif get_origin(field_type) is Literal:
5656
return all(isinstance(arg, (str, int, float, bool, Enum)) for arg in get_args(field_type))
57+
if not isinstance(field_type, type):
58+
return False
5759
return field_type in (str, int, float, bool) or issubclass(field_type, Enum)
5860

5961

@@ -100,6 +102,8 @@ def _recurse_add_fields(parser: ArgumentParser, model: Union["BaseModel", Type["
100102
# Default value, promote PydanticUndefined to None
101103
if field.default is PydanticUndefined:
102104
default_value = None
105+
elif field_instance:
106+
default_value = field_instance
103107
else:
104108
default_value = field.default
105109

@@ -167,15 +171,36 @@ def _recurse_add_fields(parser: ArgumentParser, model: Union["BaseModel", Type["
167171
if get_args(field_type) and not _is_supported_type(get_args(field_type)[0]):
168172
# If theres already something here, we can procede by adding the command with a positional indicator
169173
if field_instance:
170-
########################
171-
# MARK: List[BaseModel]
172174
for i, value in enumerate(field_instance):
173-
_recurse_add_fields(parser, value, prefix=f"{field_name}.{i}.")
175+
if isinstance(value, BaseModel):
176+
########################
177+
# MARK: List[BaseModel]
178+
_recurse_add_fields(parser, value, prefix=f"{field_name}.{i}.")
179+
continue
180+
else:
181+
########################
182+
# MARK: List[str|int|float|bool]
183+
_add_argument(
184+
parser=parser,
185+
name=f"{arg_name}.{i}",
186+
arg_type=type(value),
187+
default_value=value,
188+
)
174189
continue
175190
# If there's nothing here, we don't know how to address them
176191
# TODO: we could just prefill e.g. --field.0, --field.1 up to some limit
177192
_log.warning(f"Only lists of str, int, float, or bool are supported - field `{field_name}` got {get_args(field_type)[0]}")
178193
continue
194+
if field_instance:
195+
for i, value in enumerate(field_instance):
196+
########################
197+
# MARK: List[str|int|float|bool]
198+
_add_argument(
199+
parser=parser,
200+
name=f"{arg_name}.{i}",
201+
arg_type=type(value),
202+
default_value=value,
203+
)
179204
#################################
180205
# MARK: List[str|int|float|bool]
181206
_add_argument(
@@ -414,6 +439,21 @@ def parse_extra_args_model(model: "BaseModel"):
414439

415440
_log.debug(f"Set dict key '{key}' on parent model '{parent_model.__class__.__name__}' with value '{value}'")
416441

442+
# Now adjust our variable accounting to set the whole dict back on the parent model,
443+
# allowing us to trigger any validation
444+
key = part
445+
value = model_to_set
446+
model_to_set = parent_model
447+
elif isinstance(model_to_set, list):
448+
if value is None:
449+
continue
450+
451+
# We allow setting list values directly
452+
# Grab the list from the parent model, set the value, and continue
453+
model_to_set[int(key)] = value
454+
455+
_log.debug(f"Set list index '{key}' on parent model '{parent_model.__class__.__name__}' with value '{value}'")
456+
417457
# Now adjust our variable accounting to set the whole dict back on the parent model,
418458
# allowing us to trigger any validation
419459
key = part
@@ -427,46 +467,44 @@ def parse_extra_args_model(model: "BaseModel"):
427467
field = model_to_set.__class__.model_fields[key]
428468
adapter = TypeAdapter(field.annotation)
429469

430-
_log.debug(f"Setting field '{key}' on model '{model_to_set.__class__.__name__}' with raw value '{value}'")
431-
432-
# Convert the value using the type adapter
433-
if get_origin(field.annotation) in (list, List):
434-
value = value or ""
435-
if isinstance(value, list):
436-
# Already a list, use as is
437-
pass
438-
elif isinstance(value, str):
439-
# Convert from comma-separated values
440-
value = value.split(",")
441-
else:
442-
# Unknown, raise
443-
raise ValueError(f"Cannot convert value '{value}' to list for field '{key}'")
444-
elif get_origin(field.annotation) in (dict, Dict):
445-
value = value or ""
446-
if isinstance(value, dict):
447-
# Already a dict, use as is
448-
pass
449-
elif isinstance(value, str):
450-
# Convert from comma-separated key=value pairs
451-
dict_items = value.split(",")
452-
dict_value = {}
453-
for item in dict_items:
454-
if item:
455-
k, v = item.split("=", 1)
456-
# If the key type is an enum, convert
457-
dict_value[k] = v
458-
459-
# Grab any previously existing dict to preserve other keys
460-
existing_dict = getattr(model_to_set, key, {}) or {}
461-
_log.debug(f"Existing dict for field '{key}': {existing_dict}")
462-
_log.debug(f"New dict items for field '{key}': {dict_value}")
463-
dict_value.update(existing_dict)
464-
value = dict_value
465-
else:
466-
# Unknown, raise
467-
raise ValueError(f"Cannot convert value '{value}' to dict for field '{key}'")
468-
try:
469-
if value is not None:
470+
if value is not None:
471+
_log.debug(f"Setting field '{key}' on model '{model_to_set.__class__.__name__}' with raw value '{value}'")
472+
473+
# Convert the value using the type adapter
474+
if get_origin(field.annotation) in (list, List):
475+
if isinstance(value, list):
476+
# Already a list, use as is
477+
pass
478+
elif isinstance(value, str):
479+
# Convert from comma-separated values
480+
value = value.split(",")
481+
else:
482+
# Unknown, raise
483+
raise ValueError(f"Cannot convert value '{value}' to list for field '{key}'")
484+
elif get_origin(field.annotation) in (dict, Dict):
485+
if isinstance(value, dict):
486+
# Already a dict, use as is
487+
pass
488+
elif isinstance(value, str):
489+
# Convert from comma-separated key=value pairs
490+
dict_items = value.split(",")
491+
dict_value = {}
492+
for item in dict_items:
493+
if item:
494+
k, v = item.split("=", 1)
495+
# If the key type is an enum, convert
496+
dict_value[k] = v
497+
498+
# Grab any previously existing dict to preserve other keys
499+
existing_dict = getattr(model_to_set, key, {}) or {}
500+
_log.debug(f"Existing dict for field '{key}': {existing_dict}")
501+
_log.debug(f"New dict items for field '{key}': {dict_value}")
502+
dict_value.update(existing_dict)
503+
value = dict_value
504+
else:
505+
# Unknown, raise
506+
raise ValueError(f"Cannot convert value '{value}' to dict for field '{key}'")
507+
try:
470508
# Post process and convert keys if needed
471509
# pydantic shouldve done this automatically, but alas
472510
if isinstance(value, dict) and get_args(field.annotation):
@@ -482,10 +520,11 @@ def parse_extra_args_model(model: "BaseModel"):
482520

483521
# Set the value on the model
484522
setattr(model_to_set, key, value)
485-
486-
except ValidationError:
487-
_log.warning(f"Failed to validate field '{key}' with value '{value}' for model '{model_to_set.__class__.__name__}'")
488-
continue
523+
except ValidationError:
524+
_log.warning(f"Failed to validate field '{key}' with value '{value}' for model '{model_to_set.__class__.__name__}'")
525+
continue
526+
else:
527+
_log.debug(f"Skipping setting field '{key}' on model '{model_to_set.__class__.__name__}' with None value")
489528

490529
return model, kwargs
491530

hatch_build/tests/test_cli_model.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ class SubModel(BaseModel, validate_assignment=True):
2121
sub_arg: int = 42
2222
sub_arg_with_value: str = "sub_default"
2323
sub_arg_enum: MyEnum = MyEnum.OPTION_A
24+
sub_arg_literal: Literal["x", "y", "z"] = "x"
2425

2526

2627
class MyTopLevelModel(BaseModel, validate_assignment=True):
@@ -35,25 +36,30 @@ class MyTopLevelModel(BaseModel, validate_assignment=True):
3536
dict_arg_default_values: Dict[str, str] = {"existing-key": "existing-value"}
3637
path_arg: Path = Path(".")
3738

39+
list_literal: List[Literal["a", "b", "c"]] = ["a"]
40+
dict_literal_key: Dict[Literal["a", "b", "c"], str] = {"a": "first"}
41+
dict_literal_value: Dict[str, Literal["a", "b", "c"]] = {"first": "a"}
42+
3843
list_enum: List[MyEnum] = [MyEnum.OPTION_A]
3944
dict_enum: Dict[str, MyEnum] = {"first": MyEnum.OPTION_A}
4045
dict_enum_key: Dict[MyEnum, str] = {MyEnum.OPTION_A: "first"}
4146
dict_enum_key_model_value: Dict[MyEnum, SubModel] = {MyEnum.OPTION_A: SubModel()}
4247

4348
submodel: SubModel
44-
submodel2: SubModel = SubModel()
49+
submodel2: SubModel = SubModel(sub_args=84, sub_arg_with_value="predefined", sub_arg_enum=MyEnum.OPTION_B, sub_arg_literal="z")
4550
submodel3: Optional[SubModel] = None
4651

47-
submodel_list: List[SubModel] = []
4852
submodel_list_instanced: List[SubModel] = [SubModel()]
49-
submodel_dict: Dict[str, SubModel] = {}
5053
submodel_dict_instanced: Dict[str, SubModel] = {"a": SubModel()}
5154

5255
unsupported_literal: Literal[b"test"] = b"test"
5356
unsupported_dict: Dict[SubModel, str] = {}
5457
unsupported_dict_mixed_types: Dict[str, Union[str, SubModel]] = {}
5558
unsupported_random_type: Optional[set] = None
5659

60+
unsupported_submodel_list: List[SubModel] = []
61+
unsupported_submodel_dict: Dict[str, SubModel] = {}
62+
5763

5864
class TestCLIMdel:
5965
def test_get_arg_from_model(self):
@@ -81,6 +87,12 @@ def test_get_arg_from_model(self):
8187
"new-value",
8288
"--path-arg",
8389
"/some/path",
90+
"--list-literal",
91+
"a,b",
92+
"--dict-literal-key.a",
93+
"first",
94+
"--dict-literal-value.first",
95+
"a",
8496
"--list-enum",
8597
"option_a,option_b",
8698
"--dict-enum.first",
@@ -95,6 +107,10 @@ def test_get_arg_from_model(self):
95107
"100",
96108
"--submodel.sub-arg-with-value",
97109
"sub_value",
110+
"--submodel.sub-arg-enum",
111+
"option_a",
112+
"--submodel.sub-arg-literal",
113+
"y",
98114
"--submodel2.sub-arg",
99115
"200",
100116
"--submodel2.sub-arg-with-value",
@@ -103,6 +119,12 @@ def test_get_arg_from_model(self):
103119
"300",
104120
"--submodel-list-instanced.0.sub-arg",
105121
"400",
122+
"--submodel-list-instanced.0.sub-arg-with-value",
123+
"list_value",
124+
"--submodel-list-instanced.0.sub-arg-enum",
125+
"option_b",
126+
"--submodel-list-instanced.0.sub-arg-literal",
127+
"z",
106128
"--submodel-dict-instanced.a.sub-arg",
107129
"500",
108130
],
@@ -127,25 +149,37 @@ def test_get_arg_from_model(self):
127149
assert model.dict_arg_default_values == {"existing-key": "new-value"}
128150
assert model.path_arg == Path("/some/path")
129151

152+
assert model.list_literal == ["a", "b"]
153+
assert model.dict_literal_key == {"a": "first"}
154+
assert model.dict_literal_value == {"first": "a"}
155+
130156
assert model.list_enum == [MyEnum.OPTION_A, MyEnum.OPTION_B]
131157
assert model.dict_enum == {"first": MyEnum.OPTION_B}
132158
assert model.dict_enum_key == {MyEnum.OPTION_A: "first", MyEnum.OPTION_B: "second", MyEnum.OPTION_C: "third"}
133159
assert model.dict_enum_key_model_value[MyEnum.OPTION_A].sub_arg == 600
134160

135161
assert model.submodel.sub_arg == 100
136162
assert model.submodel.sub_arg_with_value == "sub_value"
163+
assert model.submodel.sub_arg_enum == MyEnum.OPTION_A
164+
assert model.submodel.sub_arg_literal == "y"
137165
assert model.submodel2.sub_arg == 200
138166
assert model.submodel2.sub_arg_with_value == "sub_value2"
167+
assert model.submodel2.sub_arg_enum == MyEnum.OPTION_B
168+
assert model.submodel2.sub_arg_literal == "z"
169+
139170
assert model.submodel3.sub_arg == 300
140171
assert model.submodel_list_instanced[0].sub_arg == 400
172+
assert model.submodel_list_instanced[0].sub_arg_with_value == "list_value"
173+
assert model.submodel_list_instanced[0].sub_arg_enum == MyEnum.OPTION_B
174+
assert model.submodel_list_instanced[0].sub_arg_literal == "z"
141175
assert model.submodel_dict_instanced["a"].sub_arg == 500
142176

143177
stderr = mock_stderr.getvalue()
144178
for text in (
145179
f"[sdist]\ndist/hatch_build-{__version__}.tar.gz",
146180
f"[wheel]\ndist/hatch_build-{__version__}-py3-none-any.whl",
147-
"[hatch_build.cli][WARNING]: Only lists of str, int, float, or bool are supported - field `submodel_list` got <class 'test_cli_model.SubModel'>",
148-
"[hatch_build.cli][WARNING]: Only dicts with str, int, float, bool, or enum values are supported - field `submodel_dict` got value type <class 'test_cli_model.SubModel'>",
181+
"[hatch_build.cli][WARNING]: Only lists of str, int, float, or bool are supported - field `unsupported_submodel_list` got <class 'test_cli_model.SubModel'>",
182+
"[hatch_build.cli][WARNING]: Only dicts with str, int, float, bool, or enum values are supported - field `unsupported_submodel_dict` got value type <class 'test_cli_model.SubModel'>",
149183
"[hatch_build.cli][WARNING]: Only Literal types of str, int, float, or bool are supported - field `unsupported_literal` got (b'test',)",
150184
"[hatch_build.cli][WARNING]: Only dicts with str, int, float, bool, or enum keys are supported - field `unsupported_dict` got key type <class 'test_cli_model.SubModel'>",
151185
"[hatch_build.cli][WARNING]: Only dicts with str, int, float, bool, or enum values are supported - field `unsupported_dict_mixed_types` got value type typing.Union[str, test_cli_model.SubModel]",

0 commit comments

Comments
 (0)