Skip to content

Commit 29ed5f9

Browse files
committed
Add initial cut of model parser, fixes #8
1 parent 8460861 commit 29ed5f9

File tree

4 files changed

+166
-17
lines changed

4 files changed

+166
-17
lines changed

hatch_build/cli.py

Lines changed: 82 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
from argparse import ArgumentParser
2-
from typing import Callable, List, Optional, Tuple
2+
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Type, get_args, get_origin
33

44
from hatchling.cli.build import build_command
55

6+
if TYPE_CHECKING:
7+
from pydantic import BaseModel
8+
69
__all__ = (
710
"hatchling",
811
"parse_extra_args",
@@ -17,6 +20,84 @@ def parse_extra_args(subparser: Optional[ArgumentParser] = None) -> List[str]:
1720
return vars(kwargs), extras
1821

1922

23+
def recurse_add_fields(parser: ArgumentParser, model: "BaseModel", prefix: str = ""):
24+
from pydantic import BaseModel
25+
26+
for field_name, field in model.__class__.model_fields.items():
27+
field_type = field.annotation
28+
arg_name = f"--{prefix}{field_name.replace('_', '-')}"
29+
if field_type is bool:
30+
parser.add_argument(arg_name, action="store_true", default=field.default)
31+
elif isinstance(field_type, Type) and issubclass(field_type, BaseModel):
32+
# Nested model, add its fields with a prefix
33+
recurse_add_fields(parser, getattr(model, field_name), prefix=f"{field_name}.")
34+
elif get_origin(field_type) in (list, List):
35+
# TODO: if list arg is complex type, raise as not implemented for now
36+
if get_args(field_type) and get_args(field_type)[0] not in (str, int, float, bool):
37+
raise NotImplementedError("Only lists of str, int, float, or bool are supported")
38+
parser.add_argument(arg_name, type=str, default=",".join(map(str, field.default)))
39+
elif get_origin(field_type) in (dict, Dict):
40+
# TODO: if key args are complex type, raise as not implemented for now
41+
key_type, value_type = get_args(field_type)
42+
if key_type not in (str, int, float, bool):
43+
raise NotImplementedError("Only dicts with str keys are supported")
44+
if value_type not in (str, int, float, bool):
45+
raise NotImplementedError("Only dicts with str values are supported")
46+
parser.add_argument(arg_name, type=str, default=",".join(f"{k}={v}" for k, v in field.default.items()))
47+
else:
48+
try:
49+
parser.add_argument(arg_name, type=field_type, default=field.default)
50+
except TypeError:
51+
# TODO: handle more complex types if needed
52+
parser.add_argument(arg_name, type=str, default=field.default)
53+
return parser
54+
55+
56+
def parse_extra_args_model(model: "BaseModel"):
57+
try:
58+
from pydantic import TypeAdapter
59+
except ImportError:
60+
raise ImportError("pydantic is required to use parse_extra_args_model")
61+
# Recursively parse fields from a pydantic model and its sub-models
62+
# and create an argument parser to parse extra args
63+
parser = ArgumentParser(prog="hatch-build-extras-model", allow_abbrev=False)
64+
parser = recurse_add_fields(parser, model)
65+
66+
# Parse the extra args and update the model
67+
args, kwargs = parse_extra_args(parser)
68+
for key, value in args.items():
69+
# Handle nested fields
70+
if "." in key:
71+
parts = key.split(".")
72+
sub_model = model
73+
for part in parts[:-1]:
74+
model_to_set = getattr(sub_model, part)
75+
key = parts[-1]
76+
else:
77+
model_to_set = model
78+
79+
# Grab the field from the model class and make a type adapter
80+
field = model_to_set.__class__.model_fields[key]
81+
adapter = TypeAdapter(field.annotation)
82+
83+
# Convert the value using the type adapter
84+
if get_origin(field.annotation) in (list, List):
85+
value = adapter.validate_python(value.split(","))
86+
elif get_origin(field.annotation) in (dict, Dict):
87+
dict_items = value.split(",")
88+
dict_value = {}
89+
for item in dict_items:
90+
k, v = item.split("=", 1)
91+
dict_value[k] = v
92+
value = adapter.validate_python(dict_value)
93+
else:
94+
value = adapter.validate_python(value)
95+
96+
# Set the value on the model
97+
setattr(model_to_set, key, value)
98+
return model, kwargs
99+
100+
20101
def _hatchling_internal() -> Tuple[Optional[Callable], Optional[dict], List[str]]:
21102
parser = ArgumentParser(prog="hatch-build", allow_abbrev=False)
22103
subparsers = parser.add_subparsers()

hatch_build/tests/test_cli.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -49,22 +49,24 @@ def ok_extra_argv():
4949

5050
@pytest.fixture
5151
def get_arg():
52-
tmp_argv = sys.argv
53-
sys.argv = [
54-
"hatch-build",
55-
"--",
56-
"--extra-arg",
57-
"--extra-arg-with-value",
58-
"value",
59-
"--extra-arg-with-value-equals=value2",
60-
"--extra-arg-not-in-parser",
61-
]
62-
parser = ArgumentParser()
63-
parser.add_argument("--extra-arg", action="store_true")
64-
parser.add_argument("--extra-arg-with-value")
65-
parser.add_argument("--extra-arg-with-value-equals")
66-
yield parser
67-
sys.argv = tmp_argv
52+
with patch.object(
53+
sys,
54+
"argv",
55+
[
56+
"hatch-build",
57+
"--",
58+
"--extra-arg",
59+
"--extra-arg-with-value",
60+
"value",
61+
"--extra-arg-with-value-equals=value2",
62+
"--extra-arg-not-in-parser",
63+
],
64+
):
65+
parser = ArgumentParser()
66+
parser.add_argument("--extra-arg", action="store_true")
67+
parser.add_argument("--extra-arg-with-value")
68+
parser.add_argument("--extra-arg-with-value-equals")
69+
yield parser
6870

6971

7072
class TestHatchBuild:
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import sys
2+
from typing import Dict, List
3+
from unittest.mock import patch
4+
5+
from pydantic import BaseModel
6+
7+
from hatch_build.cli import hatchling, parse_extra_args_model
8+
9+
10+
class SubModel(BaseModel, validate_assignment=True):
11+
sub_arg: int = 42
12+
sub_arg_with_value: str = "sub_default"
13+
14+
15+
class MyTopLevelModel(BaseModel, validate_assignment=True):
16+
extra_arg: bool = False
17+
extra_arg_with_value: str = "default"
18+
extra_arg_with_value_equals: str = "default_equals"
19+
20+
list_arg: List[int] = [1, 2, 3]
21+
dict_arg: Dict[str, str] = {"key": "value"}
22+
23+
submodel: SubModel
24+
submodel2: SubModel = SubModel()
25+
26+
27+
class TestCLIMdel:
28+
def test_get_arg_from_model(self):
29+
with patch.object(
30+
sys,
31+
"argv",
32+
[
33+
"hatch-build",
34+
"--",
35+
"--extra-arg",
36+
"--extra-arg-with-value",
37+
"value",
38+
"--extra-arg-with-value-equals=value2",
39+
"--extra-arg-not-in-parser",
40+
"--list-arg",
41+
"1,2,3",
42+
"--dict-arg",
43+
"key1=value1,key2=value2",
44+
"--submodel.sub-arg",
45+
"100",
46+
"--submodel.sub-arg-with-value",
47+
"sub_value",
48+
"--submodel2.sub-arg",
49+
"200",
50+
"--submodel2.sub-arg-with-value",
51+
"sub_value2",
52+
],
53+
):
54+
assert hatchling() == 0
55+
model, extras = parse_extra_args_model(MyTopLevelModel(submodel=SubModel()))
56+
57+
assert model.extra_arg is True
58+
assert model.extra_arg_with_value == "value"
59+
assert model.extra_arg_with_value_equals == "value2"
60+
assert model.list_arg == [1, 2, 3]
61+
assert model.dict_arg == {"key1": "value1", "key2": "value2"}
62+
assert model.submodel.sub_arg == 100
63+
assert model.submodel.sub_arg_with_value == "sub_value"
64+
assert model.submodel2.sub_arg == 200
65+
assert model.submodel2.sub_arg_with_value == "sub_value2"

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ develop = [
4343
"hatchling",
4444
"mdformat>=0.7.22,<1.1",
4545
"mdformat-tables>=1",
46+
"pydantic>=2,<3",
4647
"pytest",
4748
"pytest-cov",
4849
"ruff>=0.9,<0.15",

0 commit comments

Comments
 (0)