Skip to content

Commit 7fc1a9f

Browse files
committed
refactor code to fix bug
1 parent d32eae0 commit 7fc1a9f

File tree

3 files changed

+170
-115
lines changed

3 files changed

+170
-115
lines changed

dargs/dargs.py

Lines changed: 119 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
"""
1919

2020

21-
from typing import Union, Any, List, Iterable, Optional, Callable
21+
from typing import Union, Any, List, Dict, Iterable, Optional, Callable
2222
from textwrap import indent
2323
from copy import deepcopy
2424
from enum import Enum
@@ -45,35 +45,36 @@ def __init__(self,
4545
doc: str = ""):
4646
self.name = name
4747
self.dtype = dtype
48-
assert sub_fields is None or all(isinstance(s, Argument) for s in sub_fields)
49-
self.sub_fields = sub_fields if sub_fields is not None else []
50-
assert sub_variants is None or all(isinstance(s, Variant) for s in sub_variants)
51-
self.sub_variants = sub_variants if sub_variants is not None else []
48+
self.sub_fields : Dict[str, "Argument"] = {}
49+
self.sub_variants : Dict[str, "Variant"] = {}
5250
self.repeat = repeat
5351
self.optional = optional
5452
self.default = default
5553
self.alias = alias if alias is not None else []
5654
self.extra_check = extra_check
5755
self.doc = doc
56+
# adding subfields and subvariants
57+
self.extend_subfields(sub_fields)
58+
self.extend_subvariants(sub_variants)
5859
# handle the format of dtype, makeit a tuple
59-
self.reorg_dtype()
60+
self._reorg_dtype()
6061

6162
def __eq__(self, other: "Argument") -> bool:
6263
# do not compare doc and default
6364
# since they do not enter to the type checking
6465
fkey = lambda f: f.name
6566
vkey = lambda v: v.flag_name
66-
return (self.name == other.name
67-
and set(self.dtype) == set(other.dtype)
68-
and sorted(self.sub_fields, key=fkey) == sorted(other.sub_fields, key=fkey)
69-
and sorted(self.sub_variants, key=vkey) == sorted(other.sub_variants, key=vkey)
70-
and self.repeat == other.repeat
71-
and self.optional == other.optional)
67+
return (self.name == other.name
68+
and set(self.dtype) == set(other.dtype)
69+
and self.sub_fields == other.sub_fields
70+
and self.sub_variants == other.sub_variants
71+
and self.repeat == other.repeat
72+
and self.optional == other.optional)
7273

7374
def __repr__(self) -> str:
7475
return f"<Argument {self.name}: {' | '.join(dd.__name__ for dd in self.dtype)}>"
7576

76-
def reorg_dtype(self):
77+
def _reorg_dtype(self):
7778
if isinstance(self.dtype, type) or self.dtype is None:
7879
self.dtype = [self.dtype]
7980
# remove duplicate
@@ -88,35 +89,58 @@ def reorg_dtype(self):
8889

8990
def set_dtype(self, dtype: Union[None, type, Iterable[type]]):
9091
self.dtype = dtype
91-
self.reorg_dtype()
92+
self._reorg_dtype()
9293

9394
def set_repeat(self, repeat: bool = True):
9495
self.repeat = repeat
95-
self.reorg_dtype()
96+
self._reorg_dtype()
97+
98+
def extend_subfields(self, sub_fields: Optional[Iterable["Argument"]]):
99+
if sub_fields is None:
100+
return
101+
assert all(isinstance(s, Argument) for s in sub_fields)
102+
update_nodup(self.sub_fields, ((s.name, s) for s in sub_fields),
103+
err_msg=f"building Argument `{self.name}`")
104+
self._reorg_dtype()
96105

97106
def add_subfield(self, name: Union[str, "Argument"],
98107
*args, **kwargs) -> "Argument":
99108
if isinstance(name, Argument):
100109
newarg = name
101110
else:
102111
newarg = Argument(name, *args, **kwargs)
103-
self.sub_fields.append(newarg)
104-
self.reorg_dtype()
112+
self.extend_subfields([newarg])
105113
return newarg
114+
115+
def extend_subvariants(self, sub_variants: Optional[Iterable["Variant"]]):
116+
if sub_variants is None:
117+
return
118+
assert all(isinstance(s, Variant) for s in sub_variants)
119+
update_nodup(self.sub_variants, ((s.flag_name, s) for s in sub_variants),
120+
exclude=self.sub_fields.keys(),
121+
err_msg=f"building Argument `{self.name}`")
122+
self._reorg_dtype()
106123

107124
def add_subvariant(self, flag_name: Union[str, "Variant"],
108125
*args, **kwargs) -> "Variant":
109126
if isinstance(flag_name, Variant):
110127
newvrnt = flag_name
111128
else:
112129
newvrnt = Variant(flag_name, *args, **kwargs)
113-
self.sub_variants.append(newvrnt)
114-
self.reorg_dtype()
130+
self.extend_subvariants([newvrnt])
115131
return newvrnt
116132

117133
# above are creation part
118134
# below are general traverse part
119135

136+
def flatten_sub(self, value: dict) -> Dict[str, "Argument"]:
137+
sub_dicts = [self.sub_fields]
138+
sub_dicts.extend(vrnt.flatten_sub(value) for vrnt in self.sub_variants.values())
139+
flat_subs = {}
140+
update_nodup(flat_subs, *sub_dicts,
141+
err_msg=f"flattening variants of {self.name}")
142+
return flat_subs
143+
120144
def traverse(self, argdict: dict,
121145
key_hook: Callable[["Argument", dict], None] = DUMMYHOOK,
122146
value_hook: Callable[["Argument", Any], None] = DUMMYHOOK,
@@ -139,28 +163,25 @@ def traverse_value(self, value: Any,
139163
# in the condition where there is no leading key
140164
value_hook(self, value)
141165
if isinstance(value, dict):
142-
sub_hook(self, value)
143-
self._traverse_subfield(value,
144-
key_hook, value_hook, sub_hook, variant_hook)
145-
self._traverse_subvariant(value,
166+
self._traverse_sub(value,
146167
key_hook, value_hook, sub_hook, variant_hook)
147168
if isinstance(value, list) and self.repeat:
148169
for item in value:
149-
sub_hook(self, item)
150-
self._traverse_subfield(item,
151-
key_hook, value_hook, sub_hook, variant_hook)
152-
self._traverse_subvariant(item,
170+
self._traverse_sub(item,
153171
key_hook, value_hook, sub_hook, variant_hook)
154172

155-
def _traverse_subfield(self, value: dict, *args, **kwargs):
173+
def _traverse_sub(self, value: dict,
174+
key_hook: Callable[["Argument", dict], None] = DUMMYHOOK,
175+
value_hook: Callable[["Argument", Any], None] = DUMMYHOOK,
176+
sub_hook: Callable[["Argument", dict], None] = DUMMYHOOK,
177+
variant_hook: Callable[["Variant", dict], None] = DUMMYHOOK):
156178
assert isinstance(value, dict)
157-
for subarg in self.sub_fields:
158-
subarg.traverse(value, *args, **kwargs)
159-
160-
def _traverse_subvariant(self, value: dict, *args, **kwargs):
161-
assert isinstance(value, dict)
162-
for subvrnt in self.sub_variants:
163-
subvrnt.traverse(value, *args, **kwargs)
179+
sub_hook(self, value)
180+
for subvrnt in self.sub_variants.values():
181+
variant_hook(subvrnt, value)
182+
for subarg in self.flatten_sub(value).values():
183+
subarg.traverse(value,
184+
key_hook, value_hook, sub_hook, variant_hook)
164185

165186
# above are general traverse part
166187
# below are type checking part
@@ -197,20 +218,12 @@ def _check_value(self, value: Any):
197218
"that fails to pass its extra checking")
198219

199220
def _check_strict(self, value: dict):
200-
allowed = self._get_allowed_sub(value)
201-
allowed_set = set(allowed)
202-
assert len(allowed) == len(allowed_set), "duplicated keys!"
221+
allowed_keys = self.flatten_sub(value).keys()
203222
for name in value.keys():
204-
if name not in allowed_set:
223+
if name not in allowed_keys:
205224
raise KeyError(f"undefined key `{name}` is "
206225
"not allowed in strict mode")
207-
208-
def _get_allowed_sub(self, value: dict) -> List[str]:
209-
allowed = [subarg.name for subarg in self.sub_fields]
210-
for subvrnt in self.sub_variants:
211-
allowed.extend(subvrnt._get_allowed_sub(value))
212-
return allowed
213-
226+
214227
# above are type checking part
215228
# below are normalizing part
216229

@@ -222,15 +235,14 @@ def normalize(self, argdict: dict, inplace: bool = False,
222235
if do_alias:
223236
self.traverse(argdict,
224237
key_hook=Argument._convert_alias,
225-
variant_hook=Variant._convert_alias)
238+
variant_hook=Variant._convert_choice_alias)
226239
if do_default:
227240
self.traverse(argdict,
228-
key_hook=Argument._assign_default,
229-
variant_hook=Variant._assign_default)
241+
key_hook=Argument._assign_default)
230242
if trim_pattern is not None:
231243
self._trim_unrequired(argdict, trim_pattern, reserved=[self.name])
232244
self.traverse(argdict, sub_hook=lambda a, d:
233-
Argument._trim_unrequired(d, trim_pattern, a._get_allowed_sub(d)))
245+
Argument._trim_unrequired(d, trim_pattern, a.flatten_sub(d).keys()))
234246
return argdict
235247

236248
def normalize_value(self, value: Any, inplace: bool = False,
@@ -241,14 +253,13 @@ def normalize_value(self, value: Any, inplace: bool = False,
241253
if do_alias:
242254
self.traverse_value(value,
243255
key_hook=Argument._convert_alias,
244-
variant_hook=Variant._convert_alias)
256+
variant_hook=Variant._convert_choice_alias)
245257
if do_default:
246258
self.traverse_value(value,
247-
key_hook=Argument._assign_default,
248-
variant_hook=Variant._assign_default)
259+
key_hook=Argument._assign_default)
249260
if trim_pattern is not None:
250261
self.traverse_value(value, sub_hook=lambda a, d:
251-
Argument._trim_unrequired(d, trim_pattern, a._get_allowed_sub(d)))
262+
Argument._trim_unrequired(d, trim_pattern, a.flatten_sub(d).keys()))
252263
return value
253264

254265
def _assign_default(self, argdict: dict):
@@ -321,10 +332,10 @@ def gen_doc_body(self, paths: Optional[List[str]] = None, **kwargs) -> str:
321332
if self.sub_fields:
322333
# body_list.append("") # genetate a blank line
323334
# body_list.append("This argument accept the following sub arguments:")
324-
for subarg in self.sub_fields:
335+
for subarg in self.sub_fields.values():
325336
body_list.append(subarg.gen_doc(paths, **kwargs))
326337
if self.sub_variants:
327-
for subvrnt in self.sub_variants:
338+
for subvrnt in self.sub_variants.values():
328339
body_list.append(subvrnt.gen_doc(paths, **kwargs))
329340
body = "\n".join(body_list)
330341
return body
@@ -340,10 +351,9 @@ def __init__(self,
340351
default_tag: str = "", # this is indeed necessary in case of optional
341352
doc: str = ""):
342353
self.flag_name = flag_name
343-
self.choice_dict = {}
344-
self.alias_dict = {}
345-
if choices is not None:
346-
self.extend_choices(choices)
354+
self.choice_dict : Dict[str, Argument] = {}
355+
self.choice_alias : Dict[str, str] = {}
356+
self.extend_choices(choices)
347357
self.optional = optional
348358
if optional and not default_tag:
349359
raise ValueError("default_tag is needed if optional is set to be True")
@@ -360,25 +370,6 @@ def __eq__(self, other: "Variant") -> bool:
360370
def __repr__(self) -> str:
361371
return f"<Variant {self.flag_name} in {{ {', '.join(self.choice_dict.keys())} }}>"
362372

363-
def extend_choices(self, choices: Iterable["Argument"]):
364-
# choices is a list of arguments
365-
# whose name is treated as the switch tag
366-
# we convert it into a dict for better reference
367-
# and avoid duplicate tags
368-
for arg in choices:
369-
tag = arg.name
370-
if tag in self.choice_dict:
371-
raise ValueError(f"duplicate tag `{tag}` appears in "
372-
f"variant with flag `{self.flag_name}`")
373-
self.choice_dict[tag] = arg
374-
# also update alias here
375-
for atag in arg.alias:
376-
if atag in self.choice_dict or atag in self.alias_dict:
377-
raise ValueError(f"duplicate alias tag `{atag}` appears in "
378-
f"variant with flag `{self.flag_name}` "
379-
f"and choice name `{arg.name}`")
380-
self.alias_dict[atag] = arg.name
381-
382373
def set_default(self, default_tag : Union[bool, str]):
383374
if not default_tag:
384375
self.optional = False
@@ -388,6 +379,21 @@ def set_default(self, default_tag : Union[bool, str]):
388379
self.optional = True
389380
self.default_tag = default_tag
390381

382+
def extend_choices(self, choices: Optional[Iterable["Argument"]]):
383+
# choices is a list of arguments
384+
# whose name is treated as the switch tag
385+
# we convert it into a dict for better reference
386+
# and avoid duplicate tags
387+
if choices is None:
388+
return
389+
update_nodup(self.choice_dict, ((c.name, c) for c in choices),
390+
exclude={self.flag_name},
391+
err_msg=f"Variant with flag `{self.flag_name}`")
392+
update_nodup(self.choice_alias,
393+
*(((a, c.name) for a in c.alias) for c in choices),
394+
exclude={self.flag_name, *self.choice_dict.keys()},
395+
err_msg=f"building alias dict for Variant with flag `{self.flag_name}`")
396+
391397
def add_choice(self, tag: Union[str, "Argument"],
392398
dtype: Union[None, type, Iterable[type]] = dict,
393399
*args, **kwargs) -> "Argument":
@@ -397,22 +403,18 @@ def add_choice(self, tag: Union[str, "Argument"],
397403
newarg = Argument(tag, dtype, *args, **kwargs)
398404
self.extend_choices([newarg])
399405
return newarg
406+
407+
def dummy_argument(self):
408+
return Argument(name=self.flag_name, dtype=str,
409+
optional=self.optional, default=self.default_tag,
410+
sub_fields=None, sub_variants=None, repeat=False,
411+
alias=None, extra_check=None,
412+
doc=f"dummy Argument converted from Variant {self.flag_name}")
400413

401414
# above are creation part
402-
# below are general traverse part
403-
404-
def traverse(self, argdict: dict,
405-
key_hook: Callable[["Argument", dict], None] = DUMMYHOOK,
406-
value_hook: Callable[["Argument", Any], None] = DUMMYHOOK,
407-
sub_hook: Callable[["Argument", dict], None] = DUMMYHOOK,
408-
variant_hook: Callable[["Variant", dict], None] = DUMMYHOOK):
409-
variant_hook(self, argdict)
410-
choice = self._load_choice(argdict)
411-
# here we use traverse_value to flatten the tag
412-
choice.traverse_value(argdict,
413-
key_hook, value_hook, sub_hook, variant_hook)
415+
# below are helpers for traversing
414416

415-
def _load_choice(self, argdict: dict) -> "Argument":
417+
def get_choice(self, argdict: dict) -> "Argument":
416418
if self.flag_name in argdict:
417419
tag = argdict[self.flag_name]
418420
return self.choice_dict[tag]
@@ -421,24 +423,20 @@ def _load_choice(self, argdict: dict) -> "Argument":
421423
else:
422424
raise KeyError(f"key `{self.flag_name}` is required "
423425
"to choose variant but not found.")
426+
427+
def flatten_sub(self, argdict: dict) -> Dict[str, "Argument"]:
428+
choice = self.get_choice(argdict)
429+
fields = {self.flag_name: self.dummy_argument(), # as a placeholder
430+
**choice.flatten_sub(argdict)}
431+
return fields
424432

425-
def _get_allowed_sub(self, argdict: dict) -> List[str]:
426-
allowed = [self.flag_name]
427-
choice = self._load_choice(argdict)
428-
allowed.extend(choice._get_allowed_sub(argdict))
429-
return allowed
430-
431-
def _assign_default(self, argdict: dict):
432-
if self.flag_name not in argdict and self.optional:
433-
argdict[self.flag_name] = self.default_tag
434-
435-
def _convert_alias(self, argdict: dict):
433+
def _convert_choice_alias(self, argdict: dict):
436434
if self.flag_name in argdict:
437435
tag = argdict[self.flag_name]
438-
if tag not in self.choice_dict and tag in self.alias_dict:
439-
argdict[self.flag_name] = self.alias_dict[tag]
436+
if tag not in self.choice_dict and tag in self.choice_alias:
437+
argdict[self.flag_name] = self.choice_alias[tag]
440438

441-
# above are type checking part
439+
# above are traversing part
442440
# below are doc generation part
443441

444442
def gen_doc(self, paths: Optional[List[str]] = None, **kwargs) -> str:
@@ -479,3 +477,18 @@ def make_rst_refid(name):
479477
if not isinstance(name, str):
480478
name = '/'.join(name)
481479
return f'.. raw:: html\n\n <a id="{name}"></a>'
480+
481+
482+
def update_nodup(this : dict,
483+
*others : Union[dict, Iterable[tuple]],
484+
exclude : Optional[Iterable] = None,
485+
err_msg : Optional[str] = None):
486+
for pair in others:
487+
if isinstance(pair, dict):
488+
pair = pair.items()
489+
for k, v in pair:
490+
if k in this or (exclude and k in exclude):
491+
raise ValueError(f"duplicate key `{k}` when updating dict"
492+
+ "" if err_msg is None else f"in {err_msg}")
493+
this[k] = v
494+
return this

0 commit comments

Comments
 (0)