Skip to content
This repository was archived by the owner on Aug 25, 2024. It is now read-only.

Commit 6415473

Browse files
committed
model: scikit: Use auto args and config
Fixes: #285 Signed-off-by: John Andersen <[email protected]>
1 parent 2c606bf commit 6415473

File tree

2 files changed

+135
-98
lines changed

2 files changed

+135
-98
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
3333
- Function to create a config class dynamically, analogous to `make_dataclass`
3434
### Changed
3535
- CLI tests and integration tests derive from `AsyncExitStackTestCase`
36+
- SciKit models now use the auto args and config methods.
3637
### Fixed
3738
- Correctly identify when functions decorated with `op` use `self` to reference
3839
the `OperationImplementationContext`.

model/scikit/dffml_model_scikit/scikit_models.py

Lines changed: 134 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77
import sys
88
import ast
99
import inspect
10+
import dataclasses
1011
from collections import namedtuple
11-
from typing import Dict
12+
from typing import Dict, Optional, Tuple, Type, Any
1213

1314
from sklearn.neural_network import MLPClassifier
1415
from sklearn.neighbors import KNeighborsClassifier
@@ -43,6 +44,7 @@
4344
Ridge,
4445
)
4546

47+
from dffml.base import make_config, field
4648
from dffml.util.cli.arg import Arg
4749
from dffml.util.entrypoint import entry_point
4850
from dffml_model_scikit.scikit_base import Scikit, ScikitContext
@@ -70,6 +72,115 @@ class NoDefaultValue:
7072
pass
7173

7274

75+
class ParameterNotInDocString(Exception):
76+
"""
77+
Raised when a scikit class has a parameter in its ``__init__`` which was not
78+
present in it's docstring. Therefore we have no typing information for it.
79+
"""
80+
81+
82+
def scikit_get_default(type_str):
83+
if not "default" in type_str:
84+
return dataclasses.MISSING
85+
type_str = type_str[type_str.index("default") :]
86+
type_str = type_str.replace("default", "")
87+
type_str = type_str.replace(")", "")
88+
type_str = type_str.replace("=", "")
89+
type_str = type_str.replace('"', "")
90+
type_str = type_str.replace("'", "")
91+
type_str = type_str.strip()
92+
if type_str == "None":
93+
return None
94+
return type_str
95+
96+
97+
SCIKIT_DOCS_TYPE_MAP = {
98+
"int": int,
99+
"integer": int,
100+
"str": str,
101+
"string": str,
102+
"float": float,
103+
"dict": dict,
104+
"bool": bool,
105+
}
106+
107+
108+
def scikit_doc_to_field(type_str, param):
109+
default = param.default
110+
if default is inspect.Parameter.empty:
111+
default = scikit_get_default(type_str)
112+
113+
type_cls = Any
114+
115+
# Set of choices
116+
if "{'" in type_str and "'}" in type_str:
117+
type_cls = str
118+
elif "{" in type_str and "}" in type_str:
119+
type_cls = int
120+
if "." in type_str:
121+
type_cls = float
122+
else:
123+
type_split = list(
124+
map(lambda x: x.lower(), type_str.replace(",", "").split())
125+
)
126+
for scikit_type_name, python_type in SCIKIT_DOCS_TYPE_MAP.items():
127+
if scikit_type_name in type_split:
128+
type_cls = python_type
129+
130+
if type_cls == Any and default != None:
131+
type_cls = type(default)
132+
133+
return type_cls, field(type_str, default=default)
134+
135+
136+
def mkscikit_config_cls(
137+
name: str,
138+
cls: Type,
139+
properties: Optional[Dict[str, Tuple[Type, field]]] = None,
140+
):
141+
"""
142+
Given a scikit class, read its docstring and ``__init__`` parameters to
143+
generate a config class with properties containing the correct types,
144+
and default values.
145+
"""
146+
if properties is None:
147+
properties = {}
148+
149+
parameters = inspect.signature(cls).parameters
150+
docstring = inspect.getdoc(cls)
151+
152+
docparams = {}
153+
154+
# Parse parameters and their datatypes from docstring
155+
last_param_name = None
156+
for line in docstring.split("\n"):
157+
if not ":" in line:
158+
continue
159+
param_name, dtypes = line.split(":", maxsplit=1)
160+
param_name = param_name.strip()
161+
dtypes = dtypes.strip()
162+
if not param_name in parameters or param_name in docparams:
163+
continue
164+
docparams[param_name] = dtypes
165+
last_param_name = param_name
166+
167+
# Ensure all required parameters are present in docstring
168+
for param_name, param in parameters.items():
169+
if param_name in ["args", "kwargs"]:
170+
continue
171+
if not param_name in docparams:
172+
raise ParameterNotInDocString(
173+
f"{param_name} for {cls.__qualname__}"
174+
)
175+
properties[param_name] = scikit_doc_to_field(
176+
docparams[param_name], param
177+
)
178+
179+
return make_config(
180+
name, [tuple([key] + list(value)) for key, value in properties.items()]
181+
)
182+
183+
73184
for entry_point_name, name, cls, applicable_features_function in [
74185
(
75186
"scikitknn",
@@ -129,15 +240,10 @@ class NoDefaultValue:
129240
ExtraTreesClassifier,
130241
applicable_features,
131242
),
132-
(
133-
"scikitbgc",
134-
"BaggingClassifier",
135-
BaggingClassifier,
136-
applicable_features,
137-
),
138-
("scikiteln", "ElasticNet", ElasticNet, applicable_features,),
139-
("scikitbyr", "BayesianRidge", BayesianRidge, applicable_features,),
140-
("scikitlas", "Lasso", Lasso, applicable_features,),
243+
("scikitbgc", "BaggingClassifier", BaggingClassifier, applicable_features),
244+
("scikiteln", "ElasticNet", ElasticNet, applicable_features),
245+
("scikitbyr", "BayesianRidge", BayesianRidge, applicable_features),
246+
("scikitlas", "Lasso", Lasso, applicable_features),
141247
("scikitard", "ARDRegression", ARDRegression, applicable_features),
142248
("scikitrsc", "RANSACRegressor", RANSACRegressor, applicable_features),
143249
("scikitbnb", "BernoulliNB", BernoulliNB, applicable_features),
@@ -170,95 +276,26 @@ class NoDefaultValue:
170276
("scikitlars", "Lars", Lars, applicable_features),
171277
]:
172278

173-
parameters = inspect.signature(cls).parameters
174-
defaults = [
175-
os.path.join(
176-
os.path.expanduser("~"),
177-
".cache",
178-
"dffml",
179-
f"scikit-{entry_point_name}",
180-
),
181-
NoDefaultValue,
182-
] + [
183-
param.default
184-
for name, param in parameters.items()
185-
if param.default != inspect._empty
186-
]
187-
dffml_config = namedtuple(
279+
dffml_config = mkscikit_config_cls(
188280
name + "ModelConfig",
189-
["directory", "predict", "features"]
190-
+ [
191-
param.name
192-
for _, param in parameters.items()
193-
if param.default != inspect._empty
194-
],
195-
defaults=defaults,
196-
)
197-
198-
setattr(sys.modules[__name__], dffml_config.__qualname__, dffml_config)
199-
200-
@classmethod
201-
def args(cls, args, *above) -> Dict[str, Arg]:
202-
cls.config_set(
203-
args,
204-
above,
205-
"directory",
206-
Arg(
207-
default=os.path.join(
208-
os.path.expanduser("~"),
209-
".cache",
210-
"dffml",
211-
f"scikit-{entry_point_name}",
281+
cls,
282+
properties={
283+
"directory": (
284+
str,
285+
field(
286+
"Directory where state should be saved",
287+
default=os.path.join(
288+
os.path.expanduser("~"),
289+
".cache",
290+
"dffml",
291+
f"scikit-{entry_point_name}",
292+
),
212293
),
213-
help="Directory where state should be saved",
214-
),
215-
)
216-
cls.config_set(
217-
args,
218-
above,
219-
"predict",
220-
Arg(type=str, help="Label or the value to be predicted"),
221-
)
222-
223-
cls.config_set(
224-
args,
225-
above,
226-
"features",
227-
Arg(
228-
nargs="+",
229-
required=True,
230-
type=Feature.load,
231-
action=list_action(Features),
232-
help="Features to train on",
233294
),
234-
)
235-
236-
for param in inspect.signature(cls.SCIKIT_MODEL).parameters.values():
237-
# TODO if param.default is an array then Args needs to get a
238-
# nargs="+"
239-
cls.config_set(
240-
args,
241-
above,
242-
param.name,
243-
Arg(
244-
type=cls.type_for(param),
245-
default=NoDefaultValue
246-
if param.default == inspect._empty
247-
else param.default,
248-
),
249-
)
250-
return args
251-
252-
@classmethod
253-
def config(cls, config, *above):
254-
params = dict(
255-
directory=cls.config_get(config, above, "directory"),
256-
predict=cls.config_get(config, above, "predict"),
257-
features=cls.config_get(config, above, "features"),
258-
)
259-
for name in inspect.signature(cls.SCIKIT_MODEL).parameters.keys():
260-
params[name] = cls.config_get(config, above, name)
261-
return cls.CONFIG(**params)
295+
"predict": (str, field("Label or the value to be predicted")),
296+
"features": (Features, field("Features to train on")),
297+
},
298+
)
262299

263300
dffml_cls_ctx = type(
264301
name + "ModelContext",
@@ -273,12 +310,11 @@ def config(cls, config, *above):
273310
"CONFIG": dffml_config,
274311
"CONTEXT": dffml_cls_ctx,
275312
"SCIKIT_MODEL": cls,
276-
"args": args,
277-
"config": config,
278313
},
279314
)
280315
# Add the ENTRY_POINT_ORIG_LABEL
281316
dffml_cls = entry_point(entry_point_name)(dffml_cls)
282317

318+
setattr(sys.modules[__name__], dffml_config.__qualname__, dffml_config)
283319
setattr(sys.modules[__name__], dffml_cls_ctx.__qualname__, dffml_cls_ctx)
284320
setattr(sys.modules[__name__], dffml_cls.__qualname__, dffml_cls)

0 commit comments

Comments
 (0)