Skip to content

Commit fe51315

Browse files
committed
support extra value checking
1 parent e55a4bd commit fe51315

File tree

2 files changed

+15
-4
lines changed

2 files changed

+15
-4
lines changed

dargs/dargs.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def __init__(self,
4040
optional: bool = False,
4141
default: Any = _Flags.NONE,
4242
alias: Optional[Iterable[str]] = None,
43+
extra_check: Optional[Callable[[Any], bool]] = None,
4344
doc: str = ""):
4445
self.name = name
4546
self.dtype = dtype
@@ -49,6 +50,7 @@ def __init__(self,
4950
self.optional = optional
5051
self.default = default
5152
self.alias = alias if alias is not None else []
53+
self.extra_check = extra_check
5254
self.doc = doc
5355
# handle the format of dtype, makeit a tuple
5456
self.reorg_dtype()
@@ -164,13 +166,13 @@ def check(self, argdict: dict, strict: bool = False):
164166
"use check_value if you are checking subfields")
165167
self.traverse(argdict,
166168
key_hook=Argument._check_exist,
167-
value_hook=Argument._check_dtype,
169+
value_hook=Argument._check_value,
168170
sub_hook=Argument._check_strict if strict else DUMMYHOOK)
169171

170172
def check_value(self, argdict: dict, strict: bool = False):
171173
self.traverse_value(argdict,
172174
key_hook=Argument._check_exist,
173-
value_hook=Argument._check_dtype,
175+
value_hook=Argument._check_value,
174176
sub_hook=Argument._check_strict if strict else DUMMYHOOK)
175177

176178
def _check_exist(self, argdict: dict):
@@ -180,10 +182,13 @@ def _check_exist(self, argdict: dict):
180182
raise KeyError(f"key `{self.name}` is required "
181183
"in arguments but not found")
182184

183-
def _check_dtype(self, value: Any):
185+
def _check_value(self, value: Any):
184186
if not isinstance(value, self.dtype):
185187
raise TypeError(f"key `{self.name}` gets wrong value type: "
186188
f"requires {self.dtype} but gets {type(value)}")
189+
if self.extra_check is not None and not self.extra_check(value):
190+
raise ValueError(f"key `{self.name}` gets bad value "
191+
"that fails to pass its extra checking")
187192

188193
def _check_strict(self, value: dict):
189194
allowed = self._get_allowed_sub(value)

tests/test_checker.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,14 @@ def test_name_type(self):
2525
ca = Argument("key1", [int, None])
2626
ca.check({"key1": None})
2727
# optional case
28-
ca = Argument("Key1", int, optional=True)
28+
ca = Argument("key1", int, optional=True)
2929
ca.check({})
30+
# extra checker
31+
ca = Argument("key1", int, extra_check=lambda v: v > 0)
32+
ca.check({"key1": 1})
33+
with self.assertRaises(ValueError):
34+
ca.check({"key1": 0})
35+
3036

3137
def test_sub_fields(self):
3238
ca = Argument("base", dict, [

0 commit comments

Comments
 (0)