Skip to content

Commit a44e535

Browse files
authored
Merge pull request #29 from satra/enh-pipeline
Add support for basic scikit-learn pipelines
2 parents 273e2b0 + 640d8fa commit a44e535

File tree

6 files changed

+79
-27
lines changed

6 files changed

+79
-27
lines changed

README.md

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,10 @@ scale across a set of classifiers and metrics. It will also use Pydra's caching
1111
to not redo model training and evaluation when new metrics are added, or when
1212
number of iterations (`n_splits`) is increased.
1313

14-
Upcoming features:
15-
1. Improve output report containing [SHAP](https://github.com/slundberg/shap)
14+
1. Output report contains [SHAP](https://github.com/slundberg/shap)
1615
feature analysis.
17-
2. Allow for comparing scikit-learn pipelines.
18-
3. Test on scikit-learn compatible classifiers
16+
2. Allows for comparing *some* scikit-learn pipelines in addition to base
17+
classifiers.
1918

2019
### Installation
2120

@@ -109,6 +108,16 @@ This is a list of classifiers from scikit learn and uses an array to encode:
109108
when param grid is provided and default classifier parameters are not changed,
110109
then an empty dictionary **MUST** be provided as parameter 3.
111110

111+
This can also be embedded as a list indicating a scikit-learn Pipeline. For
112+
example:
113+
114+
```
115+
[ ["sklearn.impute", "SimpleImputer"],
116+
["sklearn.preprocessing", "StandardScaler"],
117+
["sklearn.tree", "DecisionTreeClassifier", {"max_depth": 5}]
118+
]
119+
```
120+
112121
## Example specification:
113122

114123
```
@@ -121,17 +130,17 @@ then an empty dictionary **MUST** be provided as parameter 3.
121130
"test_size": 0.2,
122131
"clf_info": [
123132
["sklearn.ensemble", "AdaBoostClassifier"],
124-
["sklearn.naive_bayes", "GaussianNB"],
125133
["sklearn.tree", "DecisionTreeClassifier", {"max_depth": 5}],
126-
["sklearn.ensemble", "RandomForestClassifier", {"n_estimators": 100}],
127-
["sklearn.ensemble", "ExtraTreesClassifier", {"n_estimators": 100, "class_weight": "balanced"}],
128-
["sklearn.linear_model", "LogisticRegressionCV", {"solver": "liblinear", "penalty": "l1"}],
129134
["sklearn.neural_network", "MLPClassifier", {"alpha": 1, "max_iter": 1000}],
130135
["sklearn.svm", "SVC", {"probability": true},
131136
[{"kernel": ["rbf", "linear"], "C": [1, 10, 100, 1000]}]],
132137
["sklearn.neighbors", "KNeighborsClassifier", {},
133138
[{"n_neighbors": [3, 5, 7, 9, 11, 13, 15, 17, 19],
134-
"weights": ["uniform", "distance"]}]]
139+
"weights": ["uniform", "distance"]}]],
140+
[ ["sklearn.impute", "SimpleImputer"],
141+
["sklearn.preprocessing", "StandardScaler"],
142+
["sklearn.tree", "DecisionTreeClassifier", {"max_depth": 5}]
143+
]
135144
],
136145
"permute": [true, false],
137146
"gen_shap": true,

long-spec.json.sample

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
"clf_info": [
99
["sklearn.ensemble", "AdaBoostClassifier"],
1010
["sklearn.naive_bayes", "GaussianNB"],
11-
["sklearn.tree", "DecisionTreeClassifier", {"max_depth": 5}],
11+
[ ["sklearn.impute", "SimpleImputer"],
12+
["sklearn.preprocessing", "StandardScaler"],
13+
["sklearn.tree", "DecisionTreeClassifier", {"max_depth": 5}]],
1214
["sklearn.ensemble", "RandomForestClassifier", {"n_estimators": 100}],
1315
["sklearn.ensemble", "ExtraTreesClassifier", {"n_estimators": 100, "class_weight": "balanced"}],
1416
["sklearn.linear_model", "LogisticRegressionCV", {"solver": "liblinear", "penalty": "l1"}],

pydra_ml/report.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,10 @@ def gen_report_shap_class(results, output_dir="./", plot_top_n_shap=16):
8787
indexes_all = {}
8888

8989
for model_results in results:
90-
model_name = model_results[0].get("ml_wf.clf_info")[1]
90+
model_name = model_results[0].get("ml_wf.clf_info")
91+
if isinstance(model_name[0], list):
92+
model_name = model_name[-1]
93+
model_name = model_name[1]
9194
indexes_all[model_name] = []
9295
shaps = model_results[
9396
1
@@ -179,7 +182,10 @@ def gen_report_shap_regres(results, output_dir="./", plot_top_n_shap=16):
179182
indexes_all = {}
180183

181184
for model_results in results:
182-
model_name = model_results[0].get("ml_wf.clf_info")[1]
185+
model_name = model_results[0].get("ml_wf.clf_info")
186+
if isinstance(model_name[0], list):
187+
model_name = model_name[-1]
188+
model_name = model_name[1]
183189
indexes_all[model_name] = []
184190
shaps = model_results[
185191
1
@@ -308,7 +314,17 @@ def gen_report(
308314
score = val[1].output.score
309315
if not isinstance(score, list):
310316
score = [score]
311-
name = val[0][prefix + ".clf_info"][1].split("Classifier")[0]
317+
318+
clf = val[0][prefix + ".clf_info"]
319+
if isinstance(clf[0], list):
320+
clf = clf[-1][1]
321+
else:
322+
clf = clf[1]
323+
if "Classifier" in clf:
324+
name = clf.split("Classifier")[0]
325+
else:
326+
name = clf.split("Regressor")[0]
327+
name = name.split("CV")[0]
312328
permute = val[0][prefix + ".permute"]
313329
for scoreval in score:
314330
for idx, metric in enumerate(metrics):

pydra_ml/tasks.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -62,21 +62,35 @@ def train_test_kernel(X, y, train_test_split, split_index, clf_info, permute):
6262
:param permute: whether to run it in permuted mode or not
6363
:return: outputs, trained classifier with sample indices
6464
"""
65-
from sklearn.preprocessing import StandardScaler
6665
from sklearn.pipeline import Pipeline
6766
import numpy as np
6867

69-
mod = __import__(clf_info[0], fromlist=[clf_info[1]])
70-
params = {}
71-
if len(clf_info) > 2:
72-
params = clf_info[2]
73-
clf = getattr(mod, clf_info[1])(**params)
74-
if len(clf_info) == 4:
75-
from sklearn.model_selection import GridSearchCV
68+
def to_instance(clf_info):
69+
mod = __import__(clf_info[0], fromlist=[clf_info[1]])
70+
params = {}
71+
if len(clf_info) > 2:
72+
params = clf_info[2]
73+
clf = getattr(mod, clf_info[1])(**params)
74+
if len(clf_info) == 4:
75+
from sklearn.model_selection import GridSearchCV
76+
77+
clf = GridSearchCV(clf, param_grid=clf_info[3])
78+
return clf
79+
80+
if isinstance(clf_info[0], list):
81+
# Process as a pipeline constructor
82+
steps = []
83+
for val in clf_info:
84+
step = to_instance(val)
85+
steps.append((val[1], step))
86+
pipe = Pipeline(steps)
87+
else:
88+
clf = to_instance(clf_info)
89+
from sklearn.preprocessing import StandardScaler
90+
91+
pipe = Pipeline([("std", StandardScaler()), (clf_info[1], clf)])
7692

77-
clf = GridSearchCV(clf, param_grid=clf_info[3])
7893
train_index, test_index = train_test_split[split_index]
79-
pipe = Pipeline([("std", StandardScaler()), (clf_info[1], clf)])
8094
y = y.ravel()
8195
if permute:
8296
pipe.fit(X[train_index], y[np.random.permutation(train_index)])

pydra_ml/tests/test_classifier.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55
def test_classifier(tmpdir):
66
clfs = [
77
("sklearn.neural_network", "MLPClassifier", {"alpha": 1, "max_iter": 1000}),
8-
("sklearn.naive_bayes", "GaussianNB", {}),
8+
[
9+
["sklearn.impute", "SimpleImputer"],
10+
["sklearn.preprocessing", "StandardScaler"],
11+
["sklearn.naive_bayes", "GaussianNB", {}],
12+
],
913
]
1014
csv_file = os.path.join(os.path.dirname(__file__), "data", "breast_cancer.csv")
1115
inputs = {
@@ -32,7 +36,11 @@ def test_classifier(tmpdir):
3236

3337
def test_regressor(tmpdir):
3438
clfs = [
35-
("sklearn.neural_network", "MLPRegressor", {"alpha": 1, "max_iter": 1000}),
39+
[
40+
["sklearn.impute", "SimpleImputer"],
41+
["sklearn.preprocessing", "StandardScaler"],
42+
["sklearn.neural_network", "MLPRegressor", {"alpha": 1, "max_iter": 1000}],
43+
],
3644
(
3745
"sklearn.linear_model",
3846
"LinearRegression",
@@ -58,6 +66,6 @@ def test_regressor(tmpdir):
5866

5967
wf = gen_workflow(inputs, cache_dir=tmpdir)
6068
results = run_workflow(wf, "cf", {"n_procs": 1})
61-
assert results[0][0]["ml_wf.clf_info"][1] == "MLPRegressor"
69+
assert results[0][0]["ml_wf.clf_info"][-1][1] == "MLPRegressor"
6270
assert results[0][0]["ml_wf.permute"]
6371
assert results[0][1].output.score[0][0] < results[1][1].output.score[0][0]

short-spec.json.sample

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@
77
"test_size": 0.2,
88
"clf_info": [
99
["sklearn.neural_network", "MLPClassifier", {"alpha": 1, "max_iter": 1000}],
10-
["sklearn.tree", "DecisionTreeClassifier", {"max_depth": 5}]
10+
[ ["sklearn.impute", "SimpleImputer"],
11+
["sklearn.preprocessing", "StandardScaler"],
12+
["sklearn.tree", "DecisionTreeClassifier", {"max_depth": 5}]
13+
]
1114
],
1215
"permute": [false, true],
1316
"gen_shap": true,

0 commit comments

Comments
 (0)