77import sys
88import ast
99import inspect
10+ import dataclasses
1011from collections import namedtuple
11- from typing import Dict
12+ from typing import Dict , Optional , Tuple , Type , Any
1213
1314from sklearn .neural_network import MLPClassifier
1415from sklearn .neighbors import KNeighborsClassifier
4344 Ridge ,
4445)
4546
47+ from dffml .base import make_config , field
4648from dffml .util .cli .arg import Arg
4749from dffml .util .entrypoint import entry_point
4850from dffml_model_scikit .scikit_base import Scikit , ScikitContext
@@ -70,6 +72,115 @@ class NoDefaultValue:
7072 pass
7173
7274
class ParameterNotInDocString(Exception):
    """
    Raised when a scikit class has a parameter in its ``__init__`` which was not
    present in its docstring. Therefore we have no typing information for it.
    """
80+
81+
def scikit_get_default(type_str):
    """
    Extract the default value mentioned in a docstring type description.

    Everything from the first occurrence of the word ``default`` onward is
    kept, the word itself and the characters ``)``, ``=``, ``"`` and ``'`` are
    removed, and the remainder is stripped of surrounding whitespace.

    Returns ``dataclasses.MISSING`` when ``default`` does not appear at all,
    ``None`` when the remaining text is exactly ``None``, and otherwise the
    remaining text as a string (no further type conversion is attempted).
    """
    keyword = "default"
    if keyword not in type_str:
        return dataclasses.MISSING

    # Drop the keyword first, then the punctuation, so that punctuation
    # sitting inside the keyword cannot create a new match.
    tail = type_str[type_str.index(keyword):].replace(keyword, "")
    punctuation = str.maketrans("", "", ")=\"'")
    value = tail.translate(punctuation).strip()

    return None if value == "None" else value
95+
96+
# Lower-case type names as they appear in docstring type descriptions, mapped
# to the Python type used for the generated config property.
# scikit_doc_to_field walks this mapping in insertion order and does not stop
# at the first hit, so when several names match the later entry wins.
SCIKIT_DOCS_TYPE_MAP = {
    "int": int,
    "integer": int,
    "str": str,
    "string": str,
    "float": float,
    "dict": dict,
    "bool": bool,
}
106+
107+
def scikit_doc_to_field(type_str, param):
    """
    Build a ``(type, field)`` pair for one ``__init__`` parameter.

    ``type_str`` is the type description taken from the class docstring and
    ``param`` is the matching ``inspect.Parameter``.

    The default comes from the signature; if the signature has none, it is
    parsed out of ``type_str`` via ``scikit_get_default`` (which may yield
    ``dataclasses.MISSING``).

    The type is chosen as follows:

    - ``{'a', 'b'}`` style (quoted choices) -> ``str``
    - ``{...}`` without quotes -> ``float`` if a ``.`` appears in the
      description, else ``int``
    - otherwise the words of the description are matched against
      ``SCIKIT_DOCS_TYPE_MAP``; the last matching entry wins
    - if nothing matched and there is a real default value, the type of
      that default is used; otherwise the type stays ``Any``

    Returns a tuple of the chosen type and ``field(type_str, default=...)``
    (``type_str`` is passed as the first positional argument, presumably the
    description -- confirm against ``dffml.base.field``).
    """
    default = param.default
    if default is inspect.Parameter.empty:
        default = scikit_get_default(type_str)

    type_cls = Any

    if "{'" in type_str and "'}" in type_str:
        # Set of string choices
        type_cls = str
    elif "{" in type_str and "}" in type_str:
        # Set of numeric choices
        type_cls = float if "." in type_str else int
    else:
        words = {word.lower() for word in type_str.replace(",", "").split()}
        # No break on purpose: later entries in the map take precedence.
        for scikit_type_name, python_type in SCIKIT_DOCS_TYPE_MAP.items():
            if scikit_type_name in words:
                type_cls = python_type

    # Only infer from a real value. The MISSING sentinel means "no default";
    # inferring from it would yield the private sentinel class as the type.
    if (
        type_cls is Any
        and default is not None
        and default is not dataclasses.MISSING
    ):
        type_cls = type(default)

    return type_cls, field(type_str, default=default)
134+
135+
def mkscikit_config_cls(
    name: str,
    cls: Type,
    properties: Optional[Dict[str, Tuple[Type, field]]] = None,
):
    """
    Given a scikit class, read its docstring and ``__init__`` parameters to
    generate a config class with properties containing the correct types,
    and default values.

    ``name`` is the name handed to ``make_config`` for the generated class.
    ``cls`` is the class whose signature and docstring are inspected.
    ``properties`` optionally supplies extra ``name -> (type, field)``
    entries; it is copied, so the caller's mapping is left untouched. An
    entry is replaced if ``cls`` has a parameter with the same name.

    Raises ``ParameterNotInDocString`` if a parameter of ``cls`` (other than
    ones named ``args`` / ``kwargs``) has no ``name : type`` line in the
    docstring. A class without any docstring is treated as having an empty
    one.

    Returns whatever ``make_config`` returns for the collected properties.
    """
    # Work on a copy so repeated calls never leak entries into a shared dict.
    properties = {} if properties is None else dict(properties)

    parameters = inspect.signature(cls).parameters
    # getdoc() yields None when there is no docstring at all.
    docstring = inspect.getdoc(cls) or ""

    # Parse parameters and their datatypes from docstring. Only the first
    # "name : type" line for each known parameter is kept.
    docparams = {}
    for line in docstring.splitlines():
        if ":" not in line:
            continue
        param_name, dtypes = line.split(":", maxsplit=1)
        param_name = param_name.strip()
        if param_name not in parameters or param_name in docparams:
            continue
        docparams[param_name] = dtypes.strip()

    # Ensure all required parameters are present in docstring
    for param_name, param in parameters.items():
        if param_name in ("args", "kwargs"):
            continue
        if param_name not in docparams:
            raise ParameterNotInDocString(
                f"{param_name} for {cls.__qualname__}"
            )
        properties[param_name] = scikit_doc_to_field(
            docparams[param_name], param
        )

    # Flatten each entry to (name, type, field) for make_config.
    return make_config(
        name, [(key, *value) for key, value in properties.items()]
    )
182+
183+
73184for entry_point_name , name , cls , applicable_features_function in [
74185 (
75186 "scikitknn" ,
@@ -129,15 +240,10 @@ class NoDefaultValue:
129240 ExtraTreesClassifier ,
130241 applicable_features ,
131242 ),
132- (
133- "scikitbgc" ,
134- "BaggingClassifier" ,
135- BaggingClassifier ,
136- applicable_features ,
137- ),
138- ("scikiteln" , "ElasticNet" , ElasticNet , applicable_features ,),
139- ("scikitbyr" , "BayesianRidge" , BayesianRidge , applicable_features ,),
140- ("scikitlas" , "Lasso" , Lasso , applicable_features ,),
243+ ("scikitbgc" , "BaggingClassifier" , BaggingClassifier , applicable_features ),
244+ ("scikiteln" , "ElasticNet" , ElasticNet , applicable_features ),
245+ ("scikitbyr" , "BayesianRidge" , BayesianRidge , applicable_features ),
246+ ("scikitlas" , "Lasso" , Lasso , applicable_features ),
141247 ("scikitard" , "ARDRegression" , ARDRegression , applicable_features ),
142248 ("scikitrsc" , "RANSACRegressor" , RANSACRegressor , applicable_features ),
143249 ("scikitbnb" , "BernoulliNB" , BernoulliNB , applicable_features ),
@@ -170,95 +276,26 @@ class NoDefaultValue:
170276 ("scikitlars" , "Lars" , Lars , applicable_features ),
171277]:
172278
173- parameters = inspect .signature (cls ).parameters
174- defaults = [
175- os .path .join (
176- os .path .expanduser ("~" ),
177- ".cache" ,
178- "dffml" ,
179- f"scikit-{ entry_point_name } " ,
180- ),
181- NoDefaultValue ,
182- ] + [
183- param .default
184- for name , param in parameters .items ()
185- if param .default != inspect ._empty
186- ]
187- dffml_config = namedtuple (
279+ dffml_config = mkscikit_config_cls (
188280 name + "ModelConfig" ,
189- ["directory" , "predict" , "features" ]
190- + [
191- param .name
192- for _ , param in parameters .items ()
193- if param .default != inspect ._empty
194- ],
195- defaults = defaults ,
196- )
197-
198- setattr (sys .modules [__name__ ], dffml_config .__qualname__ , dffml_config )
199-
200- @classmethod
201- def args (cls , args , * above ) -> Dict [str , Arg ]:
202- cls .config_set (
203- args ,
204- above ,
205- "directory" ,
206- Arg (
207- default = os .path .join (
208- os .path .expanduser ("~" ),
209- ".cache" ,
210- "dffml" ,
211- f"scikit-{ entry_point_name } " ,
281+ cls ,
282+ properties = {
283+ "directory" : (
284+ str ,
285+ field (
286+ "Directory where state should be saved" ,
287+ default = os .path .join (
288+ os .path .expanduser ("~" ),
289+ ".cache" ,
290+ "dffml" ,
291+ f"scikit-{ entry_point_name } " ,
292+ ),
212293 ),
213- help = "Directory where state should be saved" ,
214- ),
215- )
216- cls .config_set (
217- args ,
218- above ,
219- "predict" ,
220- Arg (type = str , help = "Label or the value to be predicted" ),
221- )
222-
223- cls .config_set (
224- args ,
225- above ,
226- "features" ,
227- Arg (
228- nargs = "+" ,
229- required = True ,
230- type = Feature .load ,
231- action = list_action (Features ),
232- help = "Features to train on" ,
233294 ),
234- )
235-
236- for param in inspect .signature (cls .SCIKIT_MODEL ).parameters .values ():
237- # TODO if param.default is an array then Args needs to get a
238- # nargs="+"
239- cls .config_set (
240- args ,
241- above ,
242- param .name ,
243- Arg (
244- type = cls .type_for (param ),
245- default = NoDefaultValue
246- if param .default == inspect ._empty
247- else param .default ,
248- ),
249- )
250- return args
251-
252- @classmethod
253- def config (cls , config , * above ):
254- params = dict (
255- directory = cls .config_get (config , above , "directory" ),
256- predict = cls .config_get (config , above , "predict" ),
257- features = cls .config_get (config , above , "features" ),
258- )
259- for name in inspect .signature (cls .SCIKIT_MODEL ).parameters .keys ():
260- params [name ] = cls .config_get (config , above , name )
261- return cls .CONFIG (** params )
295+ "predict" : (str , field ("Label or the value to be predicted" )),
296+ "features" : (Features , field ("Features to train on" )),
297+ },
298+ )
262299
263300 dffml_cls_ctx = type (
264301 name + "ModelContext" ,
@@ -273,12 +310,11 @@ def config(cls, config, *above):
273310 "CONFIG" : dffml_config ,
274311 "CONTEXT" : dffml_cls_ctx ,
275312 "SCIKIT_MODEL" : cls ,
276- "args" : args ,
277- "config" : config ,
278313 },
279314 )
280315 # Add the ENTRY_POINT_ORIG_LABEL
281316 dffml_cls = entry_point (entry_point_name )(dffml_cls )
282317
318+ setattr (sys .modules [__name__ ], dffml_config .__qualname__ , dffml_config )
283319 setattr (sys .modules [__name__ ], dffml_cls_ctx .__qualname__ , dffml_cls_ctx )
284320 setattr (sys .modules [__name__ ], dffml_cls .__qualname__ , dffml_cls )
0 commit comments