Skip to content

Commit e55a4bd

Browse files
committed
support alias for variant
1 parent 6bec152 commit e55a4bd

File tree

2 files changed

+56
-16
lines changed

2 files changed

+56
-16
lines changed

dargs/dargs.py

Lines changed: 53 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
INDENT = " " # doc is indented by four spaces
2929
DUMMYHOOK = lambda a,x: None
30-
class _Flags(Enum): NONE = 0
30+
class _Flags(Enum): NONE = 0 # for no value in dict
3131

3232
class Argument:
3333

@@ -112,30 +112,37 @@ def add_subvariant(self, flag_name: Union[str, "Variant"],
112112
def traverse(self, argdict: dict,
113113
key_hook: Callable[["Argument", dict], None] = DUMMYHOOK,
114114
value_hook: Callable[["Argument", Any], None] = DUMMYHOOK,
115-
sub_hook: Callable[["Argument", dict], None] = DUMMYHOOK):
115+
sub_hook: Callable[["Argument", dict], None] = DUMMYHOOK,
116+
variant_hook: Callable[["Variant", dict], None] = DUMMYHOOK):
116117
# first, do something with the key
117118
# then, take out the vaule and do something with it
118119
key_hook(self, argdict)
119120
if self.name in argdict:
120121
# this is the key step that we traverse into the tree
121-
self.traverse_value(argdict[self.name], key_hook, value_hook, sub_hook)
122+
self.traverse_value(argdict[self.name],
123+
key_hook, value_hook, sub_hook, variant_hook)
122124

123125
def traverse_value(self, value: Any,
124126
key_hook: Callable[["Argument", dict], None] = DUMMYHOOK,
125127
value_hook: Callable[["Argument", Any], None] = DUMMYHOOK,
126-
sub_hook: Callable[["Argument", dict], None] = DUMMYHOOK):
128+
sub_hook: Callable[["Argument", dict], None] = DUMMYHOOK,
129+
variant_hook: Callable[["Variant", dict], None] = DUMMYHOOK):
127130
# this is not private, and can be called directly
128131
# in the condition where there is no leading key
129132
value_hook(self, value)
130133
if isinstance(value, dict):
131134
sub_hook(self, value)
132-
self._traverse_subfield(value, key_hook, value_hook, sub_hook)
133-
self._traverse_subvariant(value, key_hook, value_hook, sub_hook)
135+
self._traverse_subfield(value,
136+
key_hook, value_hook, sub_hook, variant_hook)
137+
self._traverse_subvariant(value,
138+
key_hook, value_hook, sub_hook, variant_hook)
134139
if isinstance(value, list) and self.repeat:
135140
for item in value:
136141
sub_hook(self, item)
137-
self._traverse_subfield(item, key_hook, value_hook, sub_hook)
138-
self._traverse_subvariant(item, key_hook, value_hook, sub_hook)
142+
self._traverse_subfield(item,
143+
key_hook, value_hook, sub_hook, variant_hook)
144+
self._traverse_subvariant(item,
145+
key_hook, value_hook, sub_hook, variant_hook)
139146

140147
def _traverse_subfield(self, value: dict, *args, **kwargs):
141148
assert isinstance(value, dict)
@@ -202,9 +209,13 @@ def normalize(self, argdict: dict, inplace: bool = False,
202209
if not inplace:
203210
argdict = deepcopy(argdict)
204211
if do_alias:
205-
self.traverse(argdict, key_hook=Argument._convert_alias)
212+
self.traverse(argdict,
213+
key_hook=Argument._convert_alias,
214+
variant_hook=Variant._convert_alias)
206215
if do_default:
207-
self.traverse(argdict, key_hook=Argument._assign_default)
216+
self.traverse(argdict,
217+
key_hook=Argument._assign_default,
218+
variant_hook=Variant._assign_default)
208219
if trim_pattern is not None:
209220
self._trim_unrequired(argdict, trim_pattern, reserved=[self.name])
210221
self.traverse(argdict, sub_hook=lambda a, d:
@@ -217,9 +228,13 @@ def normalize_value(self, value: Any, inplace: bool = False,
217228
if not inplace:
218229
value = deepcopy(value)
219230
if do_alias:
220-
self.traverse_value(value, key_hook=Argument._convert_alias)
231+
self.traverse_value(value,
232+
key_hook=Argument._convert_alias,
233+
variant_hook=Variant._convert_alias)
221234
if do_default:
222-
self.traverse_value(value, key_hook=Argument._assign_default)
235+
self.traverse_value(value,
236+
key_hook=Argument._assign_default,
237+
variant_hook=Variant._assign_default)
223238
if trim_pattern is not None:
224239
self.traverse_value(value, sub_hook=lambda a, d:
225240
Argument._trim_unrequired(d, trim_pattern, a._get_allowed_sub(d)))
@@ -314,6 +329,7 @@ def __init__(self,
314329
doc: str = ""):
315330
self.flag_name = flag_name
316331
self.choice_dict = {}
332+
self.alias_dict = {}
317333
if choices is not None:
318334
self.extend_choices(choices)
319335
self.optional = optional
@@ -340,6 +356,13 @@ def extend_choices(self, choices: Iterable["Argument"]):
340356
raise ValueError(f"duplicate tag `{tag}` appears in "
341357
f"variant with flag `{self.flag_name}`")
342358
self.choice_dict[tag] = arg
359+
# also update alias here
360+
for atag in arg.alias:
361+
if atag in self.choice_dict or atag in self.alias_dict:
362+
raise ValueError(f"duplicate alias tag `{atag}` appears in "
363+
f"variant with flag `{self.flag_name}` "
364+
f"and choice name `{arg.name}`")
365+
self.alias_dict[atag] = arg.name
343366

344367
def set_default(self, default_tag : Union[bool, str]):
345368
if not default_tag:
@@ -363,10 +386,16 @@ def add_choice(self, tag: Union[str, "Argument"],
363386
# above are creation part
364387
# below are general traverse part
365388

366-
def traverse(self, argdict: dict, *args, **kwargs):
389+
def traverse(self, argdict: dict,
390+
key_hook: Callable[["Argument", dict], None] = DUMMYHOOK,
391+
value_hook: Callable[["Argument", Any], None] = DUMMYHOOK,
392+
sub_hook: Callable[["Argument", dict], None] = DUMMYHOOK,
393+
variant_hook: Callable[["Variant", dict], None] = DUMMYHOOK):
394+
variant_hook(self, argdict)
367395
choice = self._load_choice(argdict)
368396
# here we use check_value to flatten the tag
369-
choice.traverse_value(argdict, *args, **kwargs)
397+
choice.traverse_value(argdict,
398+
key_hook, value_hook, sub_hook, variant_hook)
370399

371400
def _load_choice(self, argdict: dict) -> "Argument":
372401
if self.flag_name in argdict:
@@ -384,6 +413,16 @@ def _get_allowed_sub(self, argdict: dict) -> List[str]:
384413
allowed.extend(choice._get_allowed_sub(argdict))
385414
return allowed
386415

416+
def _assign_default(self, argdict: dict):
417+
if self.flag_name not in argdict and self.optional:
418+
argdict[self.flag_name] = self.default_tag
419+
420+
def _convert_alias(self, argdict: dict):
421+
if self.flag_name in argdict:
422+
tag = argdict[self.flag_name]
423+
if tag not in self.choice_dict and tag in self.alias_dict:
424+
argdict[self.flag_name] = self.alias_dict[tag]
425+
387426
# above are type checking part
388427
# below are doc generation part
389428

tests/test_normalizer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,14 +79,15 @@ def test_complicated(self):
7979
Argument("type2", dict, [
8080
Argument("shared", int, optional=True, default=-2, alias=["sharedb"]),
8181
Argument("vnt2", int, optional=True, default=222, alias=["vnt2a"]),
82-
])
82+
], alias = ['type3'])
8383
], optional=True, default_tag="type1")
8484
])
8585
beg1 = {"base": {"sub2": [{}, {}]}}
8686
ref1 = {
8787
'base': {
8888
'sub1': 1,
8989
'sub2': [{'ss1': 21}, {'ss1': 21}],
90+
'vnt_flag': "type1",
9091
'shared': -1,
9192
'vnt1': 111}
9293
}
@@ -96,7 +97,7 @@ def test_complicated(self):
9697
"base": {
9798
"sub1a": 2,
9899
"sub2a": [{"ss1a":22}, {"_comment1": None}],
99-
"vnt_flag": "type2",
100+
"vnt_flag": "type3",
100101
"sharedb": -3,
101102
"vnt2a": 223,
102103
"_comment2": None}

0 commit comments

Comments
 (0)