Skip to content

Commit 739dbed

Browse files
authored
Merge pull request #12 from python-project-templates/tkp/hf
Instantiate missing models, support paths
2 parents d3175a3 + b3504c1 commit 739dbed

File tree

2 files changed

+58
-8
lines changed

2 files changed

+58
-8
lines changed

hatch_build/cli.py

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from argparse import ArgumentParser
22
from logging import getLogger
3-
from typing import TYPE_CHECKING, Callable, Dict, List, Literal, Optional, Tuple, Type, get_args, get_origin
3+
from pathlib import Path
4+
from typing import TYPE_CHECKING, Callable, Dict, List, Literal, Optional, Tuple, Type, Union, get_args, get_origin
45

56
from hatchling.cli.build import build_command
67

@@ -24,12 +25,31 @@ def parse_extra_args(subparser: Optional[ArgumentParser] = None) -> List[str]:
2425
return vars(kwargs), extras
2526

2627

27-
def _recurse_add_fields(parser: ArgumentParser, model: "BaseModel", prefix: str = ""):
28+
def _recurse_add_fields(parser: ArgumentParser, model: Union["BaseModel", Type["BaseModel"]], prefix: str = ""):
2829
from pydantic import BaseModel
2930

30-
for field_name, field in model.__class__.model_fields.items():
31+
if model is None:
32+
raise ValueError("Model instance cannot be None")
33+
if isinstance(model, type):
34+
model_fields = model.model_fields
35+
else:
36+
model_fields = model.__class__.model_fields
37+
for field_name, field in model_fields.items():
3138
field_type = field.annotation
3239
arg_name = f"--{prefix}{field_name.replace('_', '-')}"
40+
41+
# Wrappers
42+
if get_origin(field_type) is Optional:
43+
field_type = get_args(field_type)[0]
44+
elif get_origin(field_type) is Union:
45+
non_none_types = [t for t in get_args(field_type) if t is not type(None)]
46+
if len(non_none_types) == 1:
47+
field_type = non_none_types[0]
48+
else:
49+
_log.warning(f"Unsupported Union type for argument '{field_name}': {field_type}")
50+
continue
51+
52+
# Handled types
3353
if field_type is bool:
3454
parser.add_argument(arg_name, action="store_true", default=field.default)
3555
elif field_type in (str, int, float):
@@ -38,9 +58,12 @@ def _recurse_add_fields(parser: ArgumentParser, model: "BaseModel", prefix: str
3858
except TypeError:
3959
# TODO: handle more complex types if needed
4060
parser.add_argument(arg_name, type=str, default=field.default)
61+
elif isinstance(field_type, type) and issubclass(field_type, Path):
62+
# Promote to/from string
63+
parser.add_argument(arg_name, type=str, default=str(field.default) if isinstance(field.default, Path) else None)
4164
elif isinstance(field_type, Type) and issubclass(field_type, BaseModel):
4265
# Nested model, add its fields with a prefix
43-
_recurse_add_fields(parser, getattr(model, field_name), prefix=f"{field_name}.")
66+
_recurse_add_fields(parser, field_type, prefix=f"{field_name}.")
4467
elif get_origin(field_type) is Literal:
4568
literal_args = get_args(field_type)
4669
if not all(isinstance(arg, (str, int, float, bool)) for arg in literal_args):
@@ -65,13 +88,13 @@ def _recurse_add_fields(parser: ArgumentParser, model: "BaseModel", prefix: str
6588
arg_name, type=str, default=",".join(f"{k}={v}" for k, v in field.default.items()) if isinstance(field.default, dict) else None
6689
)
6790
else:
68-
_log.warning(f"Unsupported field type for argument '{arg_name}': {field_type}")
91+
_log.warning(f"Unsupported field type for argument '{field_name}': {field_type}")
6992
return parser
7093

7194

7295
def parse_extra_args_model(model: "BaseModel"):
7396
try:
74-
from pydantic import TypeAdapter
97+
from pydantic import BaseModel, TypeAdapter
7598
except ImportError:
7699
raise ImportError("pydantic is required to use parse_extra_args_model")
77100
# Recursively parse fields from a pydantic model and its sub-models
@@ -88,6 +111,24 @@ def parse_extra_args_model(model: "BaseModel"):
88111
sub_model = model
89112
for part in parts[:-1]:
90113
model_to_set = getattr(sub_model, part)
114+
if model_to_set is None:
115+
# Create a new instance of model
116+
field = sub_model.__class__.model_fields[part]
117+
# if field annotation is an optional or union with none, extrat type
118+
if get_origin(field.annotation) is Optional:
119+
model_to_instance = get_args(field.annotation)[0]
120+
elif get_origin(field.annotation) is Union:
121+
non_none_types = [t for t in get_args(field.annotation) if t is not type(None)]
122+
if len(non_none_types) == 1:
123+
model_to_instance = non_none_types[0]
124+
else:
125+
model_to_instance = field.annotation
126+
if not isinstance(model_to_instance, type) or not issubclass(model_to_instance, BaseModel):
127+
raise ValueError(
128+
f"Cannot create sub-model for field '{part}' on model '{sub_model.__class__.__name__}': - type is {model_to_instance}"
129+
)
130+
model_to_set = model_to_instance()
131+
setattr(sub_model, part, model_to_set)
91132
key = parts[-1]
92133
else:
93134
model_to_set = model

hatch_build/tests/test_cli_model.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import sys
2-
from typing import Dict, List, Literal
2+
from pathlib import Path
3+
from typing import Dict, List, Literal, Optional
34
from unittest.mock import patch
45

56
from pydantic import BaseModel
@@ -15,14 +16,16 @@ class SubModel(BaseModel, validate_assignment=True):
1516
class MyTopLevelModel(BaseModel, validate_assignment=True):
1617
extra_arg: bool = False
1718
extra_arg_with_value: str = "default"
18-
extra_arg_with_value_equals: str = "default_equals"
19+
extra_arg_with_value_equals: Optional[str] = "default_equals"
1920
extra_arg_literal: Literal["a", "b", "c"] = "a"
2021

2122
list_arg: List[int] = [1, 2, 3]
2223
dict_arg: Dict[str, str] = {"key": "value"}
24+
path_arg: Path = Path(".")
2325

2426
submodel: SubModel
2527
submodel2: SubModel = SubModel()
28+
submodel3: Optional[SubModel] = None
2629

2730

2831
class TestCLIMdel:
@@ -44,6 +47,8 @@ def test_get_arg_from_model(self):
4447
"1,2,3",
4548
"--dict-arg",
4649
"key1=value1,key2=value2",
50+
"--path-arg",
51+
"/some/path",
4752
"--submodel.sub-arg",
4853
"100",
4954
"--submodel.sub-arg-with-value",
@@ -52,6 +57,8 @@ def test_get_arg_from_model(self):
5257
"200",
5358
"--submodel2.sub-arg-with-value",
5459
"sub_value2",
60+
"--submodel3.sub-arg",
61+
"300",
5562
],
5663
):
5764
assert hatchling() == 0
@@ -63,9 +70,11 @@ def test_get_arg_from_model(self):
6370
assert model.extra_arg_literal == "b"
6471
assert model.list_arg == [1, 2, 3]
6572
assert model.dict_arg == {"key1": "value1", "key2": "value2"}
73+
assert model.path_arg == Path("/some/path")
6674
assert model.submodel.sub_arg == 100
6775
assert model.submodel.sub_arg_with_value == "sub_value"
6876
assert model.submodel2.sub_arg == 200
6977
assert model.submodel2.sub_arg_with_value == "sub_value2"
78+
assert model.submodel3.sub_arg == 300
7079

7180
assert "--extra-arg-not-in-parser" in extras

0 commit comments

Comments
 (0)