From 29ed5f99b4187119916d6ccc2068cf3accbb95f3 Mon Sep 17 00:00:00 2001 From: Tim Paine <3105306+timkpaine@users.noreply.github.com> Date: Thu, 6 Nov 2025 15:12:40 -0500 Subject: [PATCH] Add initial cut of model parser, fixes #8 --- hatch_build/cli.py | 83 ++++++++++++++++++++++++++++- hatch_build/tests/test_cli.py | 34 ++++++------ hatch_build/tests/test_cli_model.py | 65 ++++++++++++++++++++++ pyproject.toml | 1 + 4 files changed, 166 insertions(+), 17 deletions(-) create mode 100644 hatch_build/tests/test_cli_model.py diff --git a/hatch_build/cli.py b/hatch_build/cli.py index b232766..ea46986 100644 --- a/hatch_build/cli.py +++ b/hatch_build/cli.py @@ -1,8 +1,11 @@ from argparse import ArgumentParser -from typing import Callable, List, Optional, Tuple +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Type, get_args, get_origin from hatchling.cli.build import build_command +if TYPE_CHECKING: + from pydantic import BaseModel + __all__ = ( "hatchling", "parse_extra_args", @@ -17,6 +20,84 @@ def parse_extra_args(subparser: Optional[ArgumentParser] = None) -> List[str]: return vars(kwargs), extras +def recurse_add_fields(parser: ArgumentParser, model: "BaseModel", prefix: str = ""): + from pydantic import BaseModel + + for field_name, field in model.__class__.model_fields.items(): + field_type = field.annotation + arg_name = f"--{prefix}{field_name.replace('_', '-')}" + if field_type is bool: + parser.add_argument(arg_name, action="store_true", default=field.default) + elif isinstance(field_type, Type) and issubclass(field_type, BaseModel): + # Nested model, add its fields with a prefix + recurse_add_fields(parser, getattr(model, field_name), prefix=f"{field_name}.") + elif get_origin(field_type) in (list, List): + # TODO: if list arg is complex type, raise as not implemented for now + if get_args(field_type) and get_args(field_type)[0] not in (str, int, float, bool): + raise NotImplementedError("Only lists of str, int, float, or bool are supported") + parser.add_argument(arg_name, type=str, default=",".join(map(str, field.default))) + elif get_origin(field_type) in (dict, Dict): + # TODO: if key args are complex type, raise as not implemented for now + key_type, value_type = get_args(field_type) + if key_type not in (str, int, float, bool): + raise NotImplementedError("Only dicts with str keys are supported") + if value_type not in (str, int, float, bool): + raise NotImplementedError("Only dicts with str values are supported") + parser.add_argument(arg_name, type=str, default=",".join(f"{k}={v}" for k, v in field.default.items())) + else: + try: + parser.add_argument(arg_name, type=field_type, default=field.default) + except TypeError: + # TODO: handle more complex types if needed + parser.add_argument(arg_name, type=str, default=field.default) + return parser + + +def parse_extra_args_model(model: "BaseModel"): + try: + from pydantic import TypeAdapter + except ImportError: + raise ImportError("pydantic is required to use parse_extra_args_model") + # Recursively parse fields from a pydantic model and its sub-models + # and create an argument parser to parse extra args + parser = ArgumentParser(prog="hatch-build-extras-model", allow_abbrev=False) + parser = recurse_add_fields(parser, model) + + # Parse the extra args and update the model + args, kwargs = parse_extra_args(parser) + for key, value in args.items(): + # Handle nested fields + if "." in key: + parts = key.split(".") + sub_model = model + for part in parts[:-1]: + model_to_set = getattr(sub_model, part) + key = parts[-1] + else: + model_to_set = model + + # Grab the field from the model class and make a type adapter + field = model_to_set.__class__.model_fields[key] + adapter = TypeAdapter(field.annotation) + + # Convert the value using the type adapter + if get_origin(field.annotation) in (list, List): + value = adapter.validate_python(value.split(",")) + elif get_origin(field.annotation) in (dict, Dict): + dict_items = value.split(",") + dict_value = {} + for item in dict_items: + k, v = item.split("=", 1) + dict_value[k] = v + value = adapter.validate_python(dict_value) + else: + value = adapter.validate_python(value) + + # Set the value on the model + setattr(model_to_set, key, value) + return model, kwargs + + def _hatchling_internal() -> Tuple[Optional[Callable], Optional[dict], List[str]]: parser = ArgumentParser(prog="hatch-build", allow_abbrev=False) subparsers = parser.add_subparsers() diff --git a/hatch_build/tests/test_cli.py b/hatch_build/tests/test_cli.py index cdb61cf..e74f8f2 100644 --- a/hatch_build/tests/test_cli.py +++ b/hatch_build/tests/test_cli.py @@ -49,22 +49,24 @@ def ok_extra_argv(): @pytest.fixture def get_arg(): - tmp_argv = sys.argv - sys.argv = [ - "hatch-build", - "--", - "--extra-arg", - "--extra-arg-with-value", - "value", - "--extra-arg-with-value-equals=value2", - "--extra-arg-not-in-parser", - ] - parser = ArgumentParser() - parser.add_argument("--extra-arg", action="store_true") - parser.add_argument("--extra-arg-with-value") - parser.add_argument("--extra-arg-with-value-equals") - yield parser - sys.argv = tmp_argv + with patch.object( + sys, + "argv", + [ + "hatch-build", + "--", + "--extra-arg", + "--extra-arg-with-value", + "value", + "--extra-arg-with-value-equals=value2", + "--extra-arg-not-in-parser", + ], + ): + parser = ArgumentParser() + parser.add_argument("--extra-arg", action="store_true") + parser.add_argument("--extra-arg-with-value") + parser.add_argument("--extra-arg-with-value-equals") + yield parser class TestHatchBuild: diff --git a/hatch_build/tests/test_cli_model.py b/hatch_build/tests/test_cli_model.py new file mode 100644 index 0000000..30bdb70 --- /dev/null +++ b/hatch_build/tests/test_cli_model.py @@ -0,0 +1,65 @@ +import sys +from typing import Dict, List +from unittest.mock import patch + +from pydantic import BaseModel + +from hatch_build.cli import hatchling, parse_extra_args_model + + +class SubModel(BaseModel, validate_assignment=True): + sub_arg: int = 42 + sub_arg_with_value: str = "sub_default" + + +class MyTopLevelModel(BaseModel, validate_assignment=True): + extra_arg: bool = False + extra_arg_with_value: str = "default" + extra_arg_with_value_equals: str = "default_equals" + + list_arg: List[int] = [1, 2, 3] + dict_arg: Dict[str, str] = {"key": "value"} + + submodel: SubModel + submodel2: SubModel = SubModel() + + +class TestCLIMdel: + def test_get_arg_from_model(self): + with patch.object( + sys, + "argv", + [ + "hatch-build", + "--", + "--extra-arg", + "--extra-arg-with-value", + "value", + "--extra-arg-with-value-equals=value2", + "--extra-arg-not-in-parser", + "--list-arg", + "1,2,3", + "--dict-arg", + "key1=value1,key2=value2", + "--submodel.sub-arg", + "100", + "--submodel.sub-arg-with-value", + "sub_value", + "--submodel2.sub-arg", + "200", + "--submodel2.sub-arg-with-value", + "sub_value2", + ], + ): + assert hatchling() == 0 + model, extras = parse_extra_args_model(MyTopLevelModel(submodel=SubModel())) + + assert model.extra_arg is True + assert model.extra_arg_with_value == "value" + assert model.extra_arg_with_value_equals == "value2" + assert model.list_arg == [1, 2, 3] + assert model.dict_arg == {"key1": "value1", "key2": "value2"} + assert model.submodel.sub_arg == 100 + assert model.submodel.sub_arg_with_value == "sub_value" + assert model.submodel2.sub_arg == 200 + assert model.submodel2.sub_arg_with_value == "sub_value2" diff --git a/pyproject.toml b/pyproject.toml index 876b697..e93b5ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,7 @@ develop = [ "hatchling", "mdformat>=0.7.22,<1.1", "mdformat-tables>=1", + "pydantic>=2,<3", "pytest", "pytest-cov", "ruff>=0.9,<0.15",