Skip to content

Commit 83f3cdc

Browse files
committed
add trim unrequired keys
1 parent 483bda8 commit 83f3cdc

File tree

3 files changed

+61
-11
lines changed

3 files changed

+61
-11
lines changed

dargs/dargs.py

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from typing import Union, Any, List, Iterable, Optional, Callable
2222
from textwrap import wrap, fill, indent
2323
from copy import deepcopy
24+
import fnmatch, re
2425

2526

2627
INDENT = " " # doc is indented by four spaces
@@ -149,6 +150,10 @@ def _traverse_subvariant(self, value: dict, *args, **kwargs):
149150
# below are type checking part
150151

151152
def check(self, argdict: dict, strict: bool = False):
153+
if strict and len(argdict) != 1:
154+
raise KeyError("only one single key of arg name is allowed "
155+
"for check in strict mode at top level, "
156+
"use check_value if you are checking subfields")
152157
self.traverse(argdict,
153158
key_hook=Argument._check_exist,
154159
value_hook=Argument._check_dtype,
@@ -170,7 +175,7 @@ def _check_exist(self, argdict: dict):
170175
def _check_dtype(self, value: Any):
171176
if not isinstance(value, self.dtype):
172177
raise TypeError(f"key `{self.name}` gets wrong value type: "
173-
f"requires: {self.dtype} but gets {type(value)}")
178+
f"requires {self.dtype} but gets {type(value)}")
174179

175180
def _check_strict(self, value: dict):
176181
allowed = self._get_allowed_sub(value)
@@ -181,7 +186,7 @@ def _check_strict(self, value: dict):
181186
raise KeyError(f"undefined key `{name}` is "
182187
"not allowed in strict mode")
183188

184-
def _get_allowed_sub(self, value: dict):
189+
def _get_allowed_sub(self, value: dict) -> List[str]:
185190
allowed = [subarg.name for subarg in self.sub_fields]
186191
for subvrnt in self.sub_variants:
187192
allowed.extend(subvrnt._get_allowed_sub(value))
@@ -191,23 +196,32 @@ def _get_allowed_sub(self, value: dict):
191196
# below are normalizing part
192197

193198
def normalize(self, argdict: dict, inplace: bool = False,
194-
do_default: bool = True, do_alias: bool = True):
199+
do_default: bool = True, do_alias: bool = True,
200+
trim_pattern: Optional[str] = None):
195201
if not inplace:
196202
argdict = deepcopy(argdict)
197203
if do_alias:
198204
self.traverse(argdict, key_hook=Argument._convert_alias)
199205
if do_default:
200206
self.traverse(argdict, key_hook=Argument._assign_default)
207+
if trim_pattern is not None:
208+
self._trim_unrequired(argdict, trim_pattern, reserved=[self.name])
209+
self.traverse(argdict, sub_hook=lambda a, d:
210+
Argument._trim_unrequired(d, trim_pattern, a._get_allowed_sub(d)))
201211
return argdict
202212

203213
def normalize_value(self, value: Any, inplace: bool = False,
204-
do_default: bool = True, do_alias: bool = True):
214+
do_default: bool = True, do_alias: bool = True,
215+
trim_pattern: Optional[str] = None):
205216
if not inplace:
206217
value = deepcopy(value)
207218
if do_alias:
208219
self.traverse_value(value, key_hook=Argument._convert_alias)
209220
if do_default:
210221
self.traverse_value(value, key_hook=Argument._assign_default)
222+
if trim_pattern is not None:
223+
self.traverse_value(value, sub_hook=lambda a, d:
224+
Argument._trim_unrequired(d, trim_pattern, a._get_allowed_sub(d)))
211225
return value
212226

213227
def _assign_default(self, argdict: dict):
@@ -221,6 +235,21 @@ def _convert_alias(self, argdict: dict):
221235
argdict[self.name] = argdict.pop(alias)
222236
return
223237

238+
@staticmethod
239+
def _trim_unrequired(argdict: dict, pattern: str,
240+
reserved: Optional[List[str]] = None,
241+
use_regex: bool = False):
242+
rep = fnmatch.translate(pattern) if not use_regex else pattern
243+
rem = re.compile(rep)
244+
if reserved:
245+
conflict = list(filter(rem.match, reserved))
246+
if conflict:
247+
raise ValueError(f"pattern `{pattern}` conflicts with the "
248+
f"following reserved names: {', '.join(conflict)}")
249+
unrequired = list(filter(rem.match, argdict.keys()))
250+
for key in unrequired:
251+
argdict.pop(key)
252+
224253
# above are normalizing part
225254
# below are doc generation part
226255

tests/test_checker.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ def test_name_type(self):
1717
# possible error
1818
with self.assertRaises(KeyError):
1919
ca.check({"key2": 1})
20+
with self.assertRaises(KeyError):
21+
ca.check({"key1": 1, "key2": 1}, strict=True)
2022
with self.assertRaises(TypeError):
2123
ca.check({"key1": 1.0})
2224
# special handle of None

tests/test_normalization.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,22 @@ def test_alias(self):
3434
self.assertDictEqual(end1, ref)
3535
self.assertTrue(end1 is beg1)
3636

37+
def test_trim(self):
38+
ca = Argument("Key1", int)
39+
beg = {"Key1": 1, "_comment": 123}
40+
end = ca.normalize(beg, trim_pattern="_*")
41+
ref = {"Key1": 1}
42+
self.assertTrue(end == ref)
43+
self.assertTrue(beg == {"Key1": 1, "_comment": 123})
44+
self.assertTrue(end is not beg)
45+
# conflict pattern
46+
with self.assertRaises(ValueError):
47+
ca.normalize(beg, trim_pattern="Key1")
48+
# inplace
49+
end1 = ca.normalize(beg, inplace=True, trim_pattern="_*")
50+
self.assertTrue(end1 == ref)
51+
self.assertTrue(end1 is beg)
52+
3753
def test_combined(self):
3854
ca = Argument("base", dict, [
3955
Argument("sub1", int, optional=True, default=1, alias=["sub1a"]),
@@ -43,10 +59,10 @@ def test_combined(self):
4359
ref1 = {"base":{"sub1":1, "sub2": "haha"}}
4460
self.assertDictEqual(ca.normalize(beg1), ref1)
4561
self.assertDictEqual(ca.normalize_value(beg1["base"]), ref1["base"])
46-
beg2 = {"base":{"sub1a": 2, "sub2a": "hoho"}}
62+
beg2 = {"base":{"sub1a": 2, "sub2a": "hoho", "_comment": None}}
4763
ref2 = {"base":{"sub1": 2, "sub2": "hoho"}}
48-
self.assertDictEqual(ca.normalize(beg2), ref2)
49-
self.assertDictEqual(ca.normalize_value(beg2["base"]), ref2["base"])
64+
self.assertDictEqual(ca.normalize(beg2, trim_pattern="_*"), ref2)
65+
self.assertDictEqual(ca.normalize_value(beg2["base"], trim_pattern="_*"), ref2["base"])
5066

5167
def test_complicated(self):
5268
ca = Argument("base", dict, [
@@ -79,10 +95,11 @@ def test_complicated(self):
7995
beg2 = {
8096
"base": {
8197
"sub1a": 2,
82-
"sub2a": [{"ss1a":22}, {}],
98+
"sub2a": [{"ss1a":22}, {"_comment1": None}],
8399
"vnt_flag": "type2",
84100
"sharedb": -3,
85-
"vnt2a": 223}
101+
"vnt2a": 223,
102+
"_comment2": None}
86103
}
87104
ref2 = {
88105
'base': {
@@ -92,8 +109,10 @@ def test_complicated(self):
92109
'shared': -3,
93110
'vnt2': 223}
94111
}
95-
self.assertDictEqual(ca.normalize(beg2), ref2)
96-
self.assertDictEqual(ca.normalize_value(beg2["base"]), ref2["base"])
112+
self.assertDictEqual(ca.normalize(beg2, trim_pattern="_*"), ref2)
113+
self.assertDictEqual(ca.normalize_value(beg2["base"], trim_pattern="_*"), ref2["base"])
114+
with self.assertRaises(ValueError):
115+
ca.normalize(beg2, trim_pattern="sub*")
97116

98117

99118
if __name__ == "__main__":

0 commit comments

Comments
 (0)