1818"""
1919
2020
21- from typing import Union , Any , List , Iterable , Optional , Callable
21+ from typing import Union , Any , List , Dict , Iterable , Optional , Callable
2222from textwrap import indent
2323from copy import deepcopy
2424from enum import Enum
@@ -45,35 +45,36 @@ def __init__(self,
4545 doc : str = "" ):
4646 self .name = name
4747 self .dtype = dtype
48- assert sub_fields is None or all (isinstance (s , Argument ) for s in sub_fields )
49- self .sub_fields = sub_fields if sub_fields is not None else []
50- assert sub_variants is None or all (isinstance (s , Variant ) for s in sub_variants )
51- self .sub_variants = sub_variants if sub_variants is not None else []
48+ self .sub_fields : Dict [str , "Argument" ] = {}
49+ self .sub_variants : Dict [str , "Variant" ] = {}
5250 self .repeat = repeat
5351 self .optional = optional
5452 self .default = default
5553 self .alias = alias if alias is not None else []
5654 self .extra_check = extra_check
5755 self .doc = doc
56+ # adding subfields and subvariants
57+ self .extend_subfields (sub_fields )
58+ self .extend_subvariants (sub_variants )
5859 # handle the format of dtype, makeit a tuple
59- self .reorg_dtype ()
60+ self ._reorg_dtype ()
6061
6162 def __eq__ (self , other : "Argument" ) -> bool :
6263 # do not compare doc and default
6364 # since they do not enter to the type checking
6465 fkey = lambda f : f .name
6566 vkey = lambda v : v .flag_name
66- return (self .name == other .name
67- and set (self .dtype ) == set (other .dtype )
68- and sorted ( self .sub_fields , key = fkey ) == sorted ( other .sub_fields , key = fkey )
69- and sorted ( self .sub_variants , key = vkey ) == sorted ( other .sub_variants , key = vkey )
70- and self .repeat == other .repeat
71- and self .optional == other .optional )
67+ return (self .name == other .name
68+ and set (self .dtype ) == set (other .dtype )
69+ and self .sub_fields == other .sub_fields
70+ and self .sub_variants == other .sub_variants
71+ and self .repeat == other .repeat
72+ and self .optional == other .optional )
7273
7374 def __repr__ (self ) -> str :
7475 return f"<Argument { self .name } : { ' | ' .join (dd .__name__ for dd in self .dtype )} >"
7576
76- def reorg_dtype (self ):
77+ def _reorg_dtype (self ):
7778 if isinstance (self .dtype , type ) or self .dtype is None :
7879 self .dtype = [self .dtype ]
7980 # remove duplicate
@@ -88,35 +89,58 @@ def reorg_dtype(self):
8889
8990 def set_dtype (self , dtype : Union [None , type , Iterable [type ]]):
9091 self .dtype = dtype
91- self .reorg_dtype ()
92+ self ._reorg_dtype ()
9293
9394 def set_repeat (self , repeat : bool = True ):
9495 self .repeat = repeat
95- self .reorg_dtype ()
96+ self ._reorg_dtype ()
97+
98+ def extend_subfields (self , sub_fields : Optional [Iterable ["Argument" ]]):
99+ if sub_fields is None :
100+ return
101+ assert all (isinstance (s , Argument ) for s in sub_fields )
102+ update_nodup (self .sub_fields , ((s .name , s ) for s in sub_fields ),
103+ err_msg = f"building Argument `{ self .name } `" )
104+ self ._reorg_dtype ()
96105
97106 def add_subfield (self , name : Union [str , "Argument" ],
98107 * args , ** kwargs ) -> "Argument" :
99108 if isinstance (name , Argument ):
100109 newarg = name
101110 else :
102111 newarg = Argument (name , * args , ** kwargs )
103- self .sub_fields .append (newarg )
104- self .reorg_dtype ()
112+ self .extend_subfields ([newarg ])
105113 return newarg
114+
115+ def extend_subvariants (self , sub_variants : Optional [Iterable ["Variant" ]]):
116+ if sub_variants is None :
117+ return
118+ assert all (isinstance (s , Variant ) for s in sub_variants )
119+ update_nodup (self .sub_variants , ((s .flag_name , s ) for s in sub_variants ),
120+ exclude = self .sub_fields .keys (),
121+ err_msg = f"building Argument `{ self .name } `" )
122+ self ._reorg_dtype ()
106123
107124 def add_subvariant (self , flag_name : Union [str , "Variant" ],
108125 * args , ** kwargs ) -> "Variant" :
109126 if isinstance (flag_name , Variant ):
110127 newvrnt = flag_name
111128 else :
112129 newvrnt = Variant (flag_name , * args , ** kwargs )
113- self .sub_variants .append (newvrnt )
114- self .reorg_dtype ()
130+ self .extend_subvariants ([newvrnt ])
115131 return newvrnt
116132
117133 # above are creation part
118134 # below are general traverse part
119135
136+ def flatten_sub (self , value : dict ) -> Dict [str , "Argument" ]:
137+ sub_dicts = [self .sub_fields ]
138+ sub_dicts .extend (vrnt .flatten_sub (value ) for vrnt in self .sub_variants .values ())
139+ flat_subs = {}
140+ update_nodup (flat_subs , * sub_dicts ,
141+ err_msg = f"flattening variants of { self .name } " )
142+ return flat_subs
143+
120144 def traverse (self , argdict : dict ,
121145 key_hook : Callable [["Argument" , dict ], None ] = DUMMYHOOK ,
122146 value_hook : Callable [["Argument" , Any ], None ] = DUMMYHOOK ,
@@ -139,28 +163,25 @@ def traverse_value(self, value: Any,
139163 # in the condition where there is no leading key
140164 value_hook (self , value )
141165 if isinstance (value , dict ):
142- sub_hook (self , value )
143- self ._traverse_subfield (value ,
144- key_hook , value_hook , sub_hook , variant_hook )
145- self ._traverse_subvariant (value ,
166+ self ._traverse_sub (value ,
146167 key_hook , value_hook , sub_hook , variant_hook )
147168 if isinstance (value , list ) and self .repeat :
148169 for item in value :
149- sub_hook (self , item )
150- self ._traverse_subfield (item ,
151- key_hook , value_hook , sub_hook , variant_hook )
152- self ._traverse_subvariant (item ,
170+ self ._traverse_sub (item ,
153171 key_hook , value_hook , sub_hook , variant_hook )
154172
155- def _traverse_subfield (self , value : dict , * args , ** kwargs ):
173+ def _traverse_sub (self , value : dict ,
174+ key_hook : Callable [["Argument" , dict ], None ] = DUMMYHOOK ,
175+ value_hook : Callable [["Argument" , Any ], None ] = DUMMYHOOK ,
176+ sub_hook : Callable [["Argument" , dict ], None ] = DUMMYHOOK ,
177+ variant_hook : Callable [["Variant" , dict ], None ] = DUMMYHOOK ):
156178 assert isinstance (value , dict )
157- for subarg in self .sub_fields :
158- subarg .traverse (value , * args , ** kwargs )
159-
160- def _traverse_subvariant (self , value : dict , * args , ** kwargs ):
161- assert isinstance (value , dict )
162- for subvrnt in self .sub_variants :
163- subvrnt .traverse (value , * args , ** kwargs )
179+ sub_hook (self , value )
180+ for subvrnt in self .sub_variants .values ():
181+ variant_hook (subvrnt , value )
182+ for subarg in self .flatten_sub (value ).values ():
183+ subarg .traverse (value ,
184+ key_hook , value_hook , sub_hook , variant_hook )
164185
165186 # above are general traverse part
166187 # below are type checking part
@@ -197,20 +218,12 @@ def _check_value(self, value: Any):
197218 "that fails to pass its extra checking" )
198219
199220 def _check_strict (self , value : dict ):
200- allowed = self ._get_allowed_sub (value )
201- allowed_set = set (allowed )
202- assert len (allowed ) == len (allowed_set ), "duplicated keys!"
221+ allowed_keys = self .flatten_sub (value ).keys ()
203222 for name in value .keys ():
204- if name not in allowed_set :
223+ if name not in allowed_keys :
205224 raise KeyError (f"undefined key `{ name } ` is "
206225 "not allowed in strict mode" )
207-
208- def _get_allowed_sub (self , value : dict ) -> List [str ]:
209- allowed = [subarg .name for subarg in self .sub_fields ]
210- for subvrnt in self .sub_variants :
211- allowed .extend (subvrnt ._get_allowed_sub (value ))
212- return allowed
213-
226+
214227 # above are type checking part
215228 # below are normalizing part
216229
@@ -222,15 +235,14 @@ def normalize(self, argdict: dict, inplace: bool = False,
222235 if do_alias :
223236 self .traverse (argdict ,
224237 key_hook = Argument ._convert_alias ,
225- variant_hook = Variant ._convert_alias )
238+ variant_hook = Variant ._convert_choice_alias )
226239 if do_default :
227240 self .traverse (argdict ,
228- key_hook = Argument ._assign_default ,
229- variant_hook = Variant ._assign_default )
241+ key_hook = Argument ._assign_default )
230242 if trim_pattern is not None :
231243 self ._trim_unrequired (argdict , trim_pattern , reserved = [self .name ])
232244 self .traverse (argdict , sub_hook = lambda a , d :
233- Argument ._trim_unrequired (d , trim_pattern , a ._get_allowed_sub ( d )))
245+ Argument ._trim_unrequired (d , trim_pattern , a .flatten_sub ( d ). keys ( )))
234246 return argdict
235247
236248 def normalize_value (self , value : Any , inplace : bool = False ,
@@ -241,14 +253,13 @@ def normalize_value(self, value: Any, inplace: bool = False,
241253 if do_alias :
242254 self .traverse_value (value ,
243255 key_hook = Argument ._convert_alias ,
244- variant_hook = Variant ._convert_alias )
256+ variant_hook = Variant ._convert_choice_alias )
245257 if do_default :
246258 self .traverse_value (value ,
247- key_hook = Argument ._assign_default ,
248- variant_hook = Variant ._assign_default )
259+ key_hook = Argument ._assign_default )
249260 if trim_pattern is not None :
250261 self .traverse_value (value , sub_hook = lambda a , d :
251- Argument ._trim_unrequired (d , trim_pattern , a ._get_allowed_sub ( d )))
262+ Argument ._trim_unrequired (d , trim_pattern , a .flatten_sub ( d ). keys ( )))
252263 return value
253264
254265 def _assign_default (self , argdict : dict ):
@@ -321,10 +332,10 @@ def gen_doc_body(self, paths: Optional[List[str]] = None, **kwargs) -> str:
321332 if self .sub_fields :
322333 # body_list.append("") # genetate a blank line
323334 # body_list.append("This argument accept the following sub arguments:")
324- for subarg in self .sub_fields :
335+ for subarg in self .sub_fields . values () :
325336 body_list .append (subarg .gen_doc (paths , ** kwargs ))
326337 if self .sub_variants :
327- for subvrnt in self .sub_variants :
338+ for subvrnt in self .sub_variants . values () :
328339 body_list .append (subvrnt .gen_doc (paths , ** kwargs ))
329340 body = "\n " .join (body_list )
330341 return body
@@ -340,10 +351,9 @@ def __init__(self,
340351 default_tag : str = "" , # this is indeed necessary in case of optional
341352 doc : str = "" ):
342353 self .flag_name = flag_name
343- self .choice_dict = {}
344- self .alias_dict = {}
345- if choices is not None :
346- self .extend_choices (choices )
354+ self .choice_dict : Dict [str , Argument ] = {}
355+ self .choice_alias : Dict [str , str ] = {}
356+ self .extend_choices (choices )
347357 self .optional = optional
348358 if optional and not default_tag :
349359 raise ValueError ("default_tag is needed if optional is set to be True" )
@@ -360,25 +370,6 @@ def __eq__(self, other: "Variant") -> bool:
360370 def __repr__ (self ) -> str :
361371 return f"<Variant { self .flag_name } in {{ { ', ' .join (self .choice_dict .keys ())} }}>"
362372
363- def extend_choices (self , choices : Iterable ["Argument" ]):
364- # choices is a list of arguments
365- # whose name is treated as the switch tag
366- # we convert it into a dict for better reference
367- # and avoid duplicate tags
368- for arg in choices :
369- tag = arg .name
370- if tag in self .choice_dict :
371- raise ValueError (f"duplicate tag `{ tag } ` appears in "
372- f"variant with flag `{ self .flag_name } `" )
373- self .choice_dict [tag ] = arg
374- # also update alias here
375- for atag in arg .alias :
376- if atag in self .choice_dict or atag in self .alias_dict :
377- raise ValueError (f"duplicate alias tag `{ atag } ` appears in "
378- f"variant with flag `{ self .flag_name } ` "
379- f"and choice name `{ arg .name } `" )
380- self .alias_dict [atag ] = arg .name
381-
382373 def set_default (self , default_tag : Union [bool , str ]):
383374 if not default_tag :
384375 self .optional = False
@@ -388,6 +379,21 @@ def set_default(self, default_tag : Union[bool, str]):
388379 self .optional = True
389380 self .default_tag = default_tag
390381
382+ def extend_choices (self , choices : Optional [Iterable ["Argument" ]]):
383+ # choices is a list of arguments
384+ # whose name is treated as the switch tag
385+ # we convert it into a dict for better reference
386+ # and avoid duplicate tags
387+ if choices is None :
388+ return
389+ update_nodup (self .choice_dict , ((c .name , c ) for c in choices ),
390+ exclude = {self .flag_name },
391+ err_msg = f"Variant with flag `{ self .flag_name } `" )
392+ update_nodup (self .choice_alias ,
393+ * (((a , c .name ) for a in c .alias ) for c in choices ),
394+ exclude = {self .flag_name , * self .choice_dict .keys ()},
395+ err_msg = f"building alias dict for Variant with flag `{ self .flag_name } `" )
396+
391397 def add_choice (self , tag : Union [str , "Argument" ],
392398 dtype : Union [None , type , Iterable [type ]] = dict ,
393399 * args , ** kwargs ) -> "Argument" :
@@ -397,22 +403,18 @@ def add_choice(self, tag: Union[str, "Argument"],
397403 newarg = Argument (tag , dtype , * args , ** kwargs )
398404 self .extend_choices ([newarg ])
399405 return newarg
406+
407+ def dummy_argument (self ):
408+ return Argument (name = self .flag_name , dtype = str ,
409+ optional = self .optional , default = self .default_tag ,
410+ sub_fields = None , sub_variants = None , repeat = False ,
411+ alias = None , extra_check = None ,
412+ doc = f"dummy Argument converted from Variant { self .flag_name } " )
400413
401414 # above are creation part
402- # below are general traverse part
403-
404- def traverse (self , argdict : dict ,
405- key_hook : Callable [["Argument" , dict ], None ] = DUMMYHOOK ,
406- value_hook : Callable [["Argument" , Any ], None ] = DUMMYHOOK ,
407- sub_hook : Callable [["Argument" , dict ], None ] = DUMMYHOOK ,
408- variant_hook : Callable [["Variant" , dict ], None ] = DUMMYHOOK ):
409- variant_hook (self , argdict )
410- choice = self ._load_choice (argdict )
411- # here we use traverse_value to flatten the tag
412- choice .traverse_value (argdict ,
413- key_hook , value_hook , sub_hook , variant_hook )
415+ # below are helpers for traversing
414416
415- def _load_choice (self , argdict : dict ) -> "Argument" :
417+ def get_choice (self , argdict : dict ) -> "Argument" :
416418 if self .flag_name in argdict :
417419 tag = argdict [self .flag_name ]
418420 return self .choice_dict [tag ]
@@ -421,24 +423,20 @@ def _load_choice(self, argdict: dict) -> "Argument":
421423 else :
422424 raise KeyError (f"key `{ self .flag_name } ` is required "
423425 "to choose variant but not found." )
426+
427+ def flatten_sub (self , argdict : dict ) -> Dict [str , "Argument" ]:
428+ choice = self .get_choice (argdict )
429+ fields = {self .flag_name : self .dummy_argument (), # as a placeholder
430+ ** choice .flatten_sub (argdict )}
431+ return fields
424432
425- def _get_allowed_sub (self , argdict : dict ) -> List [str ]:
426- allowed = [self .flag_name ]
427- choice = self ._load_choice (argdict )
428- allowed .extend (choice ._get_allowed_sub (argdict ))
429- return allowed
430-
431- def _assign_default (self , argdict : dict ):
432- if self .flag_name not in argdict and self .optional :
433- argdict [self .flag_name ] = self .default_tag
434-
435- def _convert_alias (self , argdict : dict ):
433+ def _convert_choice_alias (self , argdict : dict ):
436434 if self .flag_name in argdict :
437435 tag = argdict [self .flag_name ]
438- if tag not in self .choice_dict and tag in self .alias_dict :
439- argdict [self .flag_name ] = self .alias_dict [tag ]
436+ if tag not in self .choice_dict and tag in self .choice_alias :
437+ argdict [self .flag_name ] = self .choice_alias [tag ]
440438
441- # above are type checking part
439+ # above are traversing part
442440 # below are doc generation part
443441
444442 def gen_doc (self , paths : Optional [List [str ]] = None , ** kwargs ) -> str :
@@ -479,3 +477,18 @@ def make_rst_refid(name):
479477 if not isinstance (name , str ):
480478 name = '/' .join (name )
481479 return f'.. raw:: html\n \n <a id="{ name } "></a>'
480+
481+
482+ def update_nodup (this : dict ,
483+ * others : Union [dict , Iterable [tuple ]],
484+ exclude : Optional [Iterable ] = None ,
485+ err_msg : Optional [str ] = None ):
486+ for pair in others :
487+ if isinstance (pair , dict ):
488+ pair = pair .items ()
489+ for k , v in pair :
490+ if k in this or (exclude and k in exclude ):
491+ raise ValueError (f"duplicate key `{ k } ` when updating dict"
492+ + "" if err_msg is None else f"in { err_msg } " )
493+ this [k ] = v
494+ return this
0 commit comments