44"""
55import abc
66import inspect
7+ import argparse
8+ import contextlib
9+ import dataclasses
710from argparse import ArgumentParser
8- from typing import Dict , Any , Tuple , NamedTuple
11+ from typing import Dict , Any , Tuple , NamedTuple , Type
912
13+ try :
14+ from typing import get_origin , get_args
15+ except ImportError :
16+ # Added in Python 3.8
17+ def get_origin (t ):
18+ return getattr (t , "__origin__" , None )
19+
20+ def get_args (t ):
21+ return getattr (t , "__args__" , None )
22+
23+
24+ from .util .cli .arg import Arg
1025from .util .data import traverse_config_set , traverse_config_get
1126
1227from .util .entrypoint import Entrypoint
1328
1429from .log import LOGGER
1530
1631
32+ class ParseExpandAction (argparse .Action ):
33+ def __call__ (self , parser , namespace , values , option_string = None ):
34+ if not isinstance (values , list ):
35+ values = [values ]
36+ setattr (namespace , self .dest , self .LIST_CLS (* values ))
37+
38+
39+ # Maps classes to their ParseClassNameAction
40+ LIST_ACTIONS : Dict [Type , Type ] = {}
41+
42+
43+ def list_action (list_cls ):
44+ """
45+ Action to take a list of values and make them values in the list of type
46+ list_class. Which will be a class descendent from AsyncContextManagerList.
47+ """
48+ LIST_ACTIONS .setdefault (
49+ list_cls ,
50+ type (
51+ f"Parse{ list_cls .__qualname__ } Action" ,
52+ (ParseExpandAction ,),
53+ {"LIST_CLS" : list_cls },
54+ ),
55+ )
56+ return LIST_ACTIONS [list_cls ]
57+
58+
1759class MissingArg (Exception ):
1860 """
1961 Raised when a BaseConfigurable is missing an argument from the args dict it
@@ -64,6 +106,13 @@ def __str__(self):
64106 return repr (self )
65107
66108
109+ def config (cls ):
110+ """
111+ Decorator to create a dataclass
112+ """
113+ return dataclasses .dataclass (eq = True , frozen = True )(cls )
114+
115+
67116class ConfigurableParsingNamespace (object ):
68117 def __init__ (self ):
69118 self .dest = None
@@ -144,9 +193,14 @@ def config_get(cls, config, above, *path) -> BaseConfig:
144193 args_above = cls .add_orig_label () + list (path )
145194 label_above = cls .add_label (* above ) + list (path )
146195 no_label_above = cls .add_label (* above )[:- 1 ] + list (path )
196+
197+ arg = None
147198 try :
148199 arg = traverse_config_get (args , * args_above )
149200 except KeyError as error :
201+ pass
202+
203+ if arg is None :
150204 raise MissingArg (
151205 "Arg %r missing from %s%s%s"
152206 % (
@@ -155,23 +209,30 @@ def config_get(cls, config, above, *path) -> BaseConfig:
155209 "." if args_above [:- 1 ] else "" ,
156210 "." .join (args_above [:- 1 ]),
157211 )
158- ) from error
159- try :
212+ )
213+
214+ value = None
215+ # Try to get the value specific to this label
216+ with contextlib .suppress (KeyError ):
160217 value = traverse_config_get (config , * label_above )
161- except KeyError as error :
162- try :
218+
219+ # Try to get the value specific to this plugin
220+ if value is None :
221+ with contextlib .suppress (KeyError ):
163222 value = traverse_config_get (config , * no_label_above )
164- except KeyError as error :
165- if "default" in arg :
166- return arg ["default" ]
167- raise MissingConfig (
168- "%s missing %r from %s"
169- % (
170- cls .__qualname__ ,
171- label_above [- 1 ],
172- "." .join (label_above [:- 1 ]),
173- )
174- ) from error
223+
224+ if value is None :
225+ # Return default if not found and available
226+ if "default" in arg :
227+ return arg ["default" ]
228+ raise MissingConfig (
229+ "%s missing %r from %s"
230+ % (
231+ cls .__qualname__ ,
232+ label_above [- 1 ],
233+ "." .join (label_above [:- 1 ]),
234+ )
235+ )
175236
176237 if value is None and "default" in arg :
177238 return arg ["default" ]
@@ -197,19 +258,67 @@ def config_get(cls, config, above, *path) -> BaseConfig:
197258 return value
198259
199260 @classmethod
200- @abc .abstractmethod
201- def args (cls , * above ) -> Dict [str , Any ]:
261+ def args (cls , args , * above ) -> Dict [str , Arg ]:
202262 """
203263 Return a dict containing arguments required for this class
204264 """
265+ if getattr (cls , "CONFIG" , None ) is None :
266+ raise AttributeError (
267+ f"{ cls .__qualname__ } requires CONFIG property or implementation of args() classmethod"
268+ )
269+ for field in dataclasses .fields (cls .CONFIG ):
270+ arg = Arg (type = field .type )
271+ # HACK For detecting dataclasses._MISSING_TYPE
272+ if "dataclasses._MISSING_TYPE" not in repr (field .default ):
273+ arg ["default" ] = field .default
274+ if field .type == bool :
275+ arg ["action" ] = "store_true"
276+ elif inspect .isclass (field .type ):
277+ if issubclass (field .type , list ):
278+ arg ["nargs" ] = "+"
279+ if not hasattr (field .type , "SINGLETON" ):
280+ raise AttributeError (
281+ f"{ field .type .__qualname__ } missing attribute SINGLETON"
282+ )
283+ arg ["action" ] = list_action (field .type )
284+ arg ["type" ] = field .type .SINGLETON
285+ if hasattr (arg ["type" ], "load" ):
286+ # TODO (python3.8) Use Protocol
287+ arg ["type" ] = arg ["type" ].load
288+ elif get_origin (field .type ) is list :
289+ arg ["type" ] = get_args (field .type )[0 ]
290+ arg ["nargs" ] = "+"
291+ if "help" in field .metadata :
292+ arg ["help" ] = field .metadata ["help" ]
293+ cls .config_set (args , above , field .name , arg )
294+ return args
205295
206296 @classmethod
207- @abc .abstractmethod
208297 def config (cls , config , * above ):
209298 """
210299 Create the BaseConfig required to instantiate this class by parsing the
211300 config dict.
212301 """
302+ if getattr (cls , "CONFIG" , None ) is None :
303+ raise AttributeError (
304+ f"{ cls .__qualname__ } requires CONFIG property or implementation of config() classmethod"
305+ )
306+ # Build the arguments to the CONFIG class
307+ kwargs : Dict [str , Any ] = {}
308+ for field in dataclasses .fields (cls .CONFIG ):
309+ kwargs [field .name ] = got = cls .config_get (
310+ config , above , field .name
311+ )
312+ if inspect .isclass (got ) and issubclass (got , BaseConfigurable ):
313+ try :
314+ kwargs [field .name ] = got .withconfig (
315+ config , * above , * cls .add_label ()
316+ )
317+ except MissingConfig :
318+ kwargs [field .name ] = got .withconfig (
319+ config , * above , * cls .add_label ()[:- 1 ]
320+ )
321+ return cls .CONFIG (** kwargs )
213322
214323 @classmethod
215324 def withconfig (cls , config , * above ):
0 commit comments