Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 82 additions & 1 deletion hatch_build/cli.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from argparse import ArgumentParser
from typing import Callable, List, Optional, Tuple
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Type, get_args, get_origin

from hatchling.cli.build import build_command

if TYPE_CHECKING:
from pydantic import BaseModel

__all__ = (
"hatchling",
"parse_extra_args",
Expand All @@ -17,6 +20,84 @@ def parse_extra_args(subparser: Optional[ArgumentParser] = None) -> List[str]:
return vars(kwargs), extras


def recurse_add_fields(parser: ArgumentParser, model: "BaseModel", prefix: str = ""):
from pydantic import BaseModel

for field_name, field in model.__class__.model_fields.items():
field_type = field.annotation
arg_name = f"--{prefix}{field_name.replace('_', '-')}"
if field_type is bool:
parser.add_argument(arg_name, action="store_true", default=field.default)
elif isinstance(field_type, Type) and issubclass(field_type, BaseModel):
# Nested model, add its fields with a prefix
recurse_add_fields(parser, getattr(model, field_name), prefix=f"{field_name}.")
elif get_origin(field_type) in (list, List):
# TODO: if list arg is complex type, raise as not implemented for now
if get_args(field_type) and get_args(field_type)[0] not in (str, int, float, bool):
raise NotImplementedError("Only lists of str, int, float, or bool are supported")
parser.add_argument(arg_name, type=str, default=",".join(map(str, field.default)))
elif get_origin(field_type) in (dict, Dict):
# TODO: if key args are complex type, raise as not implemented for now
key_type, value_type = get_args(field_type)
if key_type not in (str, int, float, bool):
raise NotImplementedError("Only dicts with str keys are supported")
if value_type not in (str, int, float, bool):
raise NotImplementedError("Only dicts with str values are supported")
parser.add_argument(arg_name, type=str, default=",".join(f"{k}={v}" for k, v in field.default.items()))
else:
try:
parser.add_argument(arg_name, type=field_type, default=field.default)
except TypeError:
# TODO: handle more complex types if needed
parser.add_argument(arg_name, type=str, default=field.default)
return parser


def parse_extra_args_model(model: "BaseModel"):
try:
from pydantic import TypeAdapter
except ImportError:
raise ImportError("pydantic is required to use parse_extra_args_model")
# Recursively parse fields from a pydantic model and its sub-models
# and create an argument parser to parse extra args
parser = ArgumentParser(prog="hatch-build-extras-model", allow_abbrev=False)
parser = recurse_add_fields(parser, model)

# Parse the extra args and update the model
args, kwargs = parse_extra_args(parser)
for key, value in args.items():
# Handle nested fields
if "." in key:
parts = key.split(".")
sub_model = model
for part in parts[:-1]:
model_to_set = getattr(sub_model, part)
key = parts[-1]
else:
model_to_set = model

# Grab the field from the model class and make a type adapter
field = model_to_set.__class__.model_fields[key]
adapter = TypeAdapter(field.annotation)

# Convert the value using the type adapter
if get_origin(field.annotation) in (list, List):
value = adapter.validate_python(value.split(","))
elif get_origin(field.annotation) in (dict, Dict):
dict_items = value.split(",")
dict_value = {}
for item in dict_items:
k, v = item.split("=", 1)
dict_value[k] = v
value = adapter.validate_python(dict_value)
else:
value = adapter.validate_python(value)

# Set the value on the model
setattr(model_to_set, key, value)
return model, kwargs


def _hatchling_internal() -> Tuple[Optional[Callable], Optional[dict], List[str]]:
parser = ArgumentParser(prog="hatch-build", allow_abbrev=False)
subparsers = parser.add_subparsers()
Expand Down
34 changes: 18 additions & 16 deletions hatch_build/tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,22 +49,24 @@ def ok_extra_argv():

@pytest.fixture
def get_arg():
tmp_argv = sys.argv
sys.argv = [
"hatch-build",
"--",
"--extra-arg",
"--extra-arg-with-value",
"value",
"--extra-arg-with-value-equals=value2",
"--extra-arg-not-in-parser",
]
parser = ArgumentParser()
parser.add_argument("--extra-arg", action="store_true")
parser.add_argument("--extra-arg-with-value")
parser.add_argument("--extra-arg-with-value-equals")
yield parser
sys.argv = tmp_argv
with patch.object(
sys,
"argv",
[
"hatch-build",
"--",
"--extra-arg",
"--extra-arg-with-value",
"value",
"--extra-arg-with-value-equals=value2",
"--extra-arg-not-in-parser",
],
):
parser = ArgumentParser()
parser.add_argument("--extra-arg", action="store_true")
parser.add_argument("--extra-arg-with-value")
parser.add_argument("--extra-arg-with-value-equals")
yield parser


class TestHatchBuild:
Expand Down
65 changes: 65 additions & 0 deletions hatch_build/tests/test_cli_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import sys
from typing import Dict, List
from unittest.mock import patch

from pydantic import BaseModel

from hatch_build.cli import hatchling, parse_extra_args_model


class SubModel(BaseModel, validate_assignment=True):
sub_arg: int = 42
sub_arg_with_value: str = "sub_default"


class MyTopLevelModel(BaseModel, validate_assignment=True):
extra_arg: bool = False
extra_arg_with_value: str = "default"
extra_arg_with_value_equals: str = "default_equals"

list_arg: List[int] = [1, 2, 3]
dict_arg: Dict[str, str] = {"key": "value"}

submodel: SubModel
submodel2: SubModel = SubModel()


class TestCLIMdel:
def test_get_arg_from_model(self):
with patch.object(
sys,
"argv",
[
"hatch-build",
"--",
"--extra-arg",
"--extra-arg-with-value",
"value",
"--extra-arg-with-value-equals=value2",
"--extra-arg-not-in-parser",
"--list-arg",
"1,2,3",
"--dict-arg",
"key1=value1,key2=value2",
"--submodel.sub-arg",
"100",
"--submodel.sub-arg-with-value",
"sub_value",
"--submodel2.sub-arg",
"200",
"--submodel2.sub-arg-with-value",
"sub_value2",
],
):
assert hatchling() == 0
model, extras = parse_extra_args_model(MyTopLevelModel(submodel=SubModel()))

assert model.extra_arg is True
assert model.extra_arg_with_value == "value"
assert model.extra_arg_with_value_equals == "value2"
assert model.list_arg == [1, 2, 3]
assert model.dict_arg == {"key1": "value1", "key2": "value2"}
assert model.submodel.sub_arg == 100
assert model.submodel.sub_arg_with_value == "sub_value"
assert model.submodel2.sub_arg == 200
assert model.submodel2.sub_arg_with_value == "sub_value2"
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ develop = [
"hatchling",
"mdformat>=0.7.22,<1.1",
"mdformat-tables>=1",
"pydantic>=2,<3",
"pytest",
"pytest-cov",
"ruff>=0.9,<0.15",
Expand Down