Skip to content

Commit d3175a3

Browse files
authored
Merge pull request #11 from python-project-templates/tkp/hf
Support literal strings
2 parents 37f653f + 64ca16c commit d3175a3

File tree

2 files changed

+35
-14
lines changed

2 files changed

+35
-14
lines changed

hatch_build/cli.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from argparse import ArgumentParser
2-
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Type, get_args, get_origin
2+
from logging import getLogger
3+
from typing import TYPE_CHECKING, Callable, Dict, List, Literal, Optional, Tuple, Type, get_args, get_origin
34

45
from hatchling.cli.build import build_command
56

@@ -13,6 +14,8 @@
1314
)
1415
_extras = None
1516

17+
_log = getLogger(__name__)
18+
1619

1720
def parse_extra_args(subparser: Optional[ArgumentParser] = None) -> List[str]:
1821
if subparser is None:
@@ -29,28 +32,40 @@ def _recurse_add_fields(parser: ArgumentParser, model: "BaseModel", prefix: str
2932
arg_name = f"--{prefix}{field_name.replace('_', '-')}"
3033
if field_type is bool:
3134
parser.add_argument(arg_name, action="store_true", default=field.default)
35+
elif field_type in (str, int, float):
36+
try:
37+
parser.add_argument(arg_name, type=field_type, default=field.default)
38+
except TypeError:
39+
# TODO: handle more complex types if needed
40+
parser.add_argument(arg_name, type=str, default=field.default)
3241
elif isinstance(field_type, Type) and issubclass(field_type, BaseModel):
3342
# Nested model, add its fields with a prefix
3443
_recurse_add_fields(parser, getattr(model, field_name), prefix=f"{field_name}.")
44+
elif get_origin(field_type) is Literal:
45+
literal_args = get_args(field_type)
46+
if not all(isinstance(arg, (str, int, float, bool)) for arg in literal_args):
47+
_log.warning(f"Only Literal types of str, int, float, or bool are supported - got {literal_args}")
48+
else:
49+
parser.add_argument(arg_name, type=type(literal_args[0]), choices=literal_args, default=field.default)
3550
elif get_origin(field_type) in (list, List):
36-
# TODO: if list arg is complex type, raise as not implemented for now
51+
# TODO: if list arg is complex type, warn as not implemented for now
3752
if get_args(field_type) and get_args(field_type)[0] not in (str, int, float, bool):
38-
raise NotImplementedError("Only lists of str, int, float, or bool are supported")
39-
parser.add_argument(arg_name, type=str, default=",".join(map(str, field.default)))
53+
_log.warning(f"Only lists of str, int, float, or bool are supported - got {get_args(field_type)[0]}")
54+
else:
55+
parser.add_argument(arg_name, type=str, default=",".join(map(str, field.default)) if isinstance(field, str) else None)
4056
elif get_origin(field_type) in (dict, Dict):
41-
# TODO: if key args are complex type, raise as not implemented for now
57+
# TODO: if key args are complex type, warn as not implemented for now
4258
key_type, value_type = get_args(field_type)
4359
if key_type not in (str, int, float, bool):
44-
raise NotImplementedError("Only dicts with str keys are supported")
60+
_log.warning(f"Only dicts with str keys are supported - got key type {key_type}")
4561
if value_type not in (str, int, float, bool):
46-
raise NotImplementedError("Only dicts with str values are supported")
47-
parser.add_argument(arg_name, type=str, default=",".join(f"{k}={v}" for k, v in field.default.items()))
62+
_log.warning(f"Only dicts with str values are supported - got value type {value_type}")
63+
else:
64+
parser.add_argument(
65+
arg_name, type=str, default=",".join(f"{k}={v}" for k, v in field.default.items()) if isinstance(field.default, dict) else None
66+
)
4867
else:
49-
try:
50-
parser.add_argument(arg_name, type=field_type, default=field.default)
51-
except TypeError:
52-
# TODO: handle more complex types if needed
53-
parser.add_argument(arg_name, type=str, default=field.default)
68+
_log.warning(f"Unsupported field type for argument '{arg_name}': {field_type}")
5469
return parser
5570

5671

hatch_build/tests/test_cli_model.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import sys
2-
from typing import Dict, List
2+
from typing import Dict, List, Literal
33
from unittest.mock import patch
44

55
from pydantic import BaseModel
@@ -16,6 +16,7 @@ class MyTopLevelModel(BaseModel, validate_assignment=True):
1616
extra_arg: bool = False
1717
extra_arg_with_value: str = "default"
1818
extra_arg_with_value_equals: str = "default_equals"
19+
extra_arg_literal: Literal["a", "b", "c"] = "a"
1920

2021
list_arg: List[int] = [1, 2, 3]
2122
dict_arg: Dict[str, str] = {"key": "value"}
@@ -37,6 +38,8 @@ def test_get_arg_from_model(self):
3738
"value",
3839
"--extra-arg-with-value-equals=value2",
3940
"--extra-arg-not-in-parser",
41+
"--extra-arg-literal",
42+
"b",
4043
"--list-arg",
4144
"1,2,3",
4245
"--dict-arg",
@@ -57,9 +60,12 @@ def test_get_arg_from_model(self):
5760
assert model.extra_arg is True
5861
assert model.extra_arg_with_value == "value"
5962
assert model.extra_arg_with_value_equals == "value2"
63+
assert model.extra_arg_literal == "b"
6064
assert model.list_arg == [1, 2, 3]
6165
assert model.dict_arg == {"key1": "value1", "key2": "value2"}
6266
assert model.submodel.sub_arg == 100
6367
assert model.submodel.sub_arg_with_value == "sub_value"
6468
assert model.submodel2.sub_arg == 200
6569
assert model.submodel2.sub_arg_with_value == "sub_value2"
70+
71+
assert "--extra-arg-not-in-parser" in extras

0 commit comments

Comments
 (0)