2727
2828INDENT = " " # doc is indented by four spaces
2929DUMMYHOOK = lambda a ,x : None
30- class _Flags (Enum ): NONE = 0
30+ class _Flags (Enum ): NONE = 0 # for no value in dict
3131
3232class Argument :
3333
@@ -112,30 +112,37 @@ def add_subvariant(self, flag_name: Union[str, "Variant"],
112112 def traverse (self , argdict : dict ,
113113 key_hook : Callable [["Argument" , dict ], None ] = DUMMYHOOK ,
114114 value_hook : Callable [["Argument" , Any ], None ] = DUMMYHOOK ,
115- sub_hook : Callable [["Argument" , dict ], None ] = DUMMYHOOK ):
115+ sub_hook : Callable [["Argument" , dict ], None ] = DUMMYHOOK ,
116+ variant_hook : Callable [["Variant" , dict ], None ] = DUMMYHOOK ):
116117 # first, do something with the key
117118 # then, take out the vaule and do something with it
118119 key_hook (self , argdict )
119120 if self .name in argdict :
120121 # this is the key step that we traverse into the tree
121- self .traverse_value (argdict [self .name ], key_hook , value_hook , sub_hook )
122+ self .traverse_value (argdict [self .name ],
123+ key_hook , value_hook , sub_hook , variant_hook )
122124
123125 def traverse_value (self , value : Any ,
124126 key_hook : Callable [["Argument" , dict ], None ] = DUMMYHOOK ,
125127 value_hook : Callable [["Argument" , Any ], None ] = DUMMYHOOK ,
126- sub_hook : Callable [["Argument" , dict ], None ] = DUMMYHOOK ):
128+ sub_hook : Callable [["Argument" , dict ], None ] = DUMMYHOOK ,
129+ variant_hook : Callable [["Variant" , dict ], None ] = DUMMYHOOK ):
127130 # this is not private, and can be called directly
128131 # in the condition where there is no leading key
129132 value_hook (self , value )
130133 if isinstance (value , dict ):
131134 sub_hook (self , value )
132- self ._traverse_subfield (value , key_hook , value_hook , sub_hook )
133- self ._traverse_subvariant (value , key_hook , value_hook , sub_hook )
135+ self ._traverse_subfield (value ,
136+ key_hook , value_hook , sub_hook , variant_hook )
137+ self ._traverse_subvariant (value ,
138+ key_hook , value_hook , sub_hook , variant_hook )
134139 if isinstance (value , list ) and self .repeat :
135140 for item in value :
136141 sub_hook (self , item )
137- self ._traverse_subfield (item , key_hook , value_hook , sub_hook )
138- self ._traverse_subvariant (item , key_hook , value_hook , sub_hook )
142+ self ._traverse_subfield (item ,
143+ key_hook , value_hook , sub_hook , variant_hook )
144+ self ._traverse_subvariant (item ,
145+ key_hook , value_hook , sub_hook , variant_hook )
139146
140147 def _traverse_subfield (self , value : dict , * args , ** kwargs ):
141148 assert isinstance (value , dict )
@@ -202,9 +209,13 @@ def normalize(self, argdict: dict, inplace: bool = False,
202209 if not inplace :
203210 argdict = deepcopy (argdict )
204211 if do_alias :
205- self .traverse (argdict , key_hook = Argument ._convert_alias )
212+ self .traverse (argdict ,
213+ key_hook = Argument ._convert_alias ,
214+ variant_hook = Variant ._convert_alias )
206215 if do_default :
207- self .traverse (argdict , key_hook = Argument ._assign_default )
216+ self .traverse (argdict ,
217+ key_hook = Argument ._assign_default ,
218+ variant_hook = Variant ._assign_default )
208219 if trim_pattern is not None :
209220 self ._trim_unrequired (argdict , trim_pattern , reserved = [self .name ])
210221 self .traverse (argdict , sub_hook = lambda a , d :
@@ -217,9 +228,13 @@ def normalize_value(self, value: Any, inplace: bool = False,
217228 if not inplace :
218229 value = deepcopy (value )
219230 if do_alias :
220- self .traverse_value (value , key_hook = Argument ._convert_alias )
231+ self .traverse_value (value ,
232+ key_hook = Argument ._convert_alias ,
233+ variant_hook = Variant ._convert_alias )
221234 if do_default :
222- self .traverse_value (value , key_hook = Argument ._assign_default )
235+ self .traverse_value (value ,
236+ key_hook = Argument ._assign_default ,
237+ variant_hook = Variant ._assign_default )
223238 if trim_pattern is not None :
224239 self .traverse_value (value , sub_hook = lambda a , d :
225240 Argument ._trim_unrequired (d , trim_pattern , a ._get_allowed_sub (d )))
@@ -314,6 +329,7 @@ def __init__(self,
314329 doc : str = "" ):
315330 self .flag_name = flag_name
316331 self .choice_dict = {}
332+ self .alias_dict = {}
317333 if choices is not None :
318334 self .extend_choices (choices )
319335 self .optional = optional
@@ -340,6 +356,13 @@ def extend_choices(self, choices: Iterable["Argument"]):
340356 raise ValueError (f"duplicate tag `{ tag } ` appears in "
341357 f"variant with flag `{ self .flag_name } `" )
342358 self .choice_dict [tag ] = arg
359+ # also update alias here
360+ for atag in arg .alias :
361+ if atag in self .choice_dict or atag in self .alias_dict :
362+ raise ValueError (f"duplicate alias tag `{ atag } ` appears in "
363+ f"variant with flag `{ self .flag_name } ` "
364+ f"and choice name `{ arg .name } `" )
365+ self .alias_dict [atag ] = arg .name
343366
344367 def set_default (self , default_tag : Union [bool , str ]):
345368 if not default_tag :
@@ -363,10 +386,16 @@ def add_choice(self, tag: Union[str, "Argument"],
363386 # above are creation part
364387 # below are general traverse part
365388
366- def traverse (self , argdict : dict , * args , ** kwargs ):
389+ def traverse (self , argdict : dict ,
390+ key_hook : Callable [["Argument" , dict ], None ] = DUMMYHOOK ,
391+ value_hook : Callable [["Argument" , Any ], None ] = DUMMYHOOK ,
392+ sub_hook : Callable [["Argument" , dict ], None ] = DUMMYHOOK ,
393+ variant_hook : Callable [["Variant" , dict ], None ] = DUMMYHOOK ):
394+ variant_hook (self , argdict )
367395 choice = self ._load_choice (argdict )
368396 # here we use check_value to flatten the tag
369- choice .traverse_value (argdict , * args , ** kwargs )
397+ choice .traverse_value (argdict ,
398+ key_hook , value_hook , sub_hook , variant_hook )
370399
371400 def _load_choice (self , argdict : dict ) -> "Argument" :
372401 if self .flag_name in argdict :
@@ -384,6 +413,16 @@ def _get_allowed_sub(self, argdict: dict) -> List[str]:
384413 allowed .extend (choice ._get_allowed_sub (argdict ))
385414 return allowed
386415
416+ def _assign_default (self , argdict : dict ):
417+ if self .flag_name not in argdict and self .optional :
418+ argdict [self .flag_name ] = self .default_tag
419+
420+ def _convert_alias (self , argdict : dict ):
421+ if self .flag_name in argdict :
422+ tag = argdict [self .flag_name ]
423+ if tag not in self .choice_dict and tag in self .alias_dict :
424+ argdict [self .flag_name ] = self .alias_dict [tag ]
425+
387426 # above are type checking part
388427 # below are doc generation part
389428
0 commit comments