Skip to content

Commit 64e351e

Browse files
authored
Merge pull request #59 from nipype/enh/pydra23
initial attempt to work with pydra 0.23+
2 parents a3722fc + 0f3df04 commit 64e351e

File tree

6 files changed

+56
-15
lines changed

6 files changed

+56
-15
lines changed

.gitignore

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,3 +133,9 @@ dmypy.json
133133

134134
# pycharm
135135
.idea/
136+
137+
# Venvs
138+
*.venv
139+
140+
# Generated messages
141+
/messages

pydra_ml/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,7 @@ def set_logger_level(lgr, level):
3535
set_logger_level(lgr, os.environ.get("PYDRAML_LOG_LEVEL", logging.INFO))
3636
FORMAT = "%(asctime)-15s [%(levelname)8s] %(message)s"
3737
logging.basicConfig(format=FORMAT)
38+
39+
from . import _version
40+
41+
__version__ = _version.get_versions()["version"]

pydra_ml/classifier.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def gen_workflow(inputs, cache_dir=None, cache_locations=None):
7171
messengers=FileMessenger(),
7272
messenger_args={"message_dir": os.path.join(os.getcwd(), "messages")},
7373
)
74-
wf.split(["clf_info", "permute"])
74+
wf.split(clf_info=inputs["clf_info"], permute=inputs["permute"])
7575
wf.add(
7676
read_file_pdt(
7777
name="readcsv",
@@ -102,7 +102,7 @@ def gen_workflow(inputs, cache_dir=None, cache_locations=None):
102102
permute=wf.lzin.permute,
103103
)
104104
)
105-
wf.fit_clf.split("split_index")
105+
wf.fit_clf.split(split_index=wf.gensplit.lzout.split_indices)
106106
wf.add(
107107
calc_metric_pdt(
108108
name="metric", output=wf.fit_clf.lzout.output, metrics=wf.lzin.metrics

pydra_ml/report.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,15 @@
33
import pickle
44
import warnings
55

6+
import matplotlib
67
import matplotlib.pyplot as plt
78
import numpy as np
89
import pandas as pd
910
import seaborn as sns
1011
from sklearn.metrics import accuracy_score, explained_variance_score
1112

13+
matplotlib.use("Agg")
14+
1215

1316
def save_obj(obj, path):
1417
with open(path, "wb") as f:
@@ -97,9 +100,9 @@ def plot_summary(summary, output_dir=None, filename="shap_plot", plot_top_n_shap
97100
# plot without all bootstrapping values
98101
summary = summary[["mean", "std", "min", "max"]]
99102
num_features = len(list(summary.index))
100-
if (plot_top_n_shap != 1 and type(plot_top_n_shap) == float) or type(
103+
if (plot_top_n_shap != 1 and type(plot_top_n_shap) is float) or type(
101104
plot_top_n_shap
102-
) == int:
105+
) is int:
103106
# if plot_top_n_shap != 1.0 but includes 1 (int)
104107
if plot_top_n_shap <= 0:
105108
raise ValueError(
@@ -223,7 +226,7 @@ def gen_report_shap_class(results, output_dir="./", plot_top_n_shap=16):
223226
f"""There were no {quadrant.upper()}s, this will output NaNs
224227
in the csv and figure for this split column"""
225228
)
226-
shaps_i_quadrant = shaps_i[
229+
shaps_i_quadrant = np.array(shaps_i)[
227230
indexes.get(quadrant)
228231
] # shape (P, F) P prediction x F feature_names
229232
abs_weighted_shap_values = np.abs(shaps_i_quadrant) * split_performance
@@ -325,7 +328,7 @@ def gen_report_shap_regres(results, output_dir="./", plot_top_n_shap=16):
325328
f"""There were no {quadrant.upper()}s, this will
326329
output NaNs in the csv and figure for this split column"""
327330
)
328-
shaps_i_quadrant = shaps_i[
331+
shaps_i_quadrant = np.array(shaps_i)[
329332
indexes.get(quadrant)
330333
] # shape (P, F) P prediction x F feature_names
331334
abs_weighted_shap_values = np.abs(shaps_i_quadrant) * split_performance

pydra_ml/tasks.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
11
#!/usr/bin/env python
22

3+
import typing as ty
4+
5+
from pydra.utils.hash import Cache, register_serializer
6+
from sklearn.pipeline import Pipeline
7+
8+
9+
@register_serializer
10+
def bytes_repr_Pipeline(obj: Pipeline, cache: Cache):
11+
yield str(obj).encode()
12+
313

414
def read_file(filename, x_indices=None, target_vars=None, group=None):
515
"""Read a CSV data file
@@ -92,7 +102,7 @@ def to_instance(clf_info):
92102

93103
train_index, test_index = train_test_split[split_index]
94104
y = y.ravel()
95-
if type(X[0][0]) == str:
105+
if type(X[0][0]) is str:
96106
# it's loaded as bytes, so we need to decode as utf-8
97107
X = np.array([str.encode(n[0]).decode("utf-8") for n in X])
98108
if permute:
@@ -126,7 +136,27 @@ def calc_metric(output, metrics):
126136
return score, output
127137

128138

129-
def get_feature_importance(permute, model, gen_feature_importance=True):
139+
def get_feature_importance(
140+
*,
141+
permute: bool,
142+
model: ty.Tuple[Pipeline, list, list],
143+
gen_feature_importance: bool = True,
144+
):
145+
"""Compute feature importance for the model
146+
147+
Parameters
148+
----------
149+
permute : bool
150+
Whether or not to run the model in permuted mode
151+
model : tuple(sklearn.pipeline.Pipeline, list, list)
152+
The model to compute feature importance for
153+
gen_feature_importance : bool
154+
Whether or not to generate the feature importance
155+
Returns
156+
-------
157+
list
158+
List of feature importance
159+
"""
130160
if permute or not gen_feature_importance:
131161
return []
132162
pipeline, train_index, test_index = model
@@ -172,7 +202,7 @@ def get_feature_importance(permute, model, gen_feature_importance=True):
172202
pipeline_steps.coefs_
173203
pipeline_steps.coef_
174204
175-
Please add correct method in tasks.py or if inexistent,
205+
Please add correct method in tasks.py or if non-existent,
176206
set gen_feature_importance to false in the spec file.
177207
178208
This is the error that was returned by sklearn:\n\t{e}\n
@@ -224,7 +254,9 @@ def get_shap(X, permute, model, gen_shap=False, nsamples="auto", l1_reg="aic"):
224254
import shap
225255

226256
explainer = shap.KernelExplainer(pipe.predict, shap.kmeans(X[train_index], 5))
227-
shaps = explainer.shap_values(X[test_index], nsamples=nsamples, l1_reg=l1_reg)
257+
shaps = explainer.shap_values(
258+
X[test_index], nsamples=nsamples, l1_reg=l1_reg, silent=True
259+
)
228260
return shaps
229261

230262

setup.cfg

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ classifiers =
2626
[options]
2727
python_requires = >= 3.8
2828
install_requires =
29-
pydra == 0.22.0
29+
pydra >= 0.23.0-alpha
3030
psutil
3131
scikit-learn
3232
seaborn
@@ -35,11 +35,9 @@ install_requires =
3535

3636
test_requires =
3737
pytest >= 4.4.0
38-
pytest-cov
3938
pytest-env
4039
pytest-xdist
4140
pytest-rerunfailures
42-
codecov
4341
packages = find:
4442
include_package_data = True
4543

@@ -58,11 +56,9 @@ docs =
5856
%(doc)s
5957
test =
6058
pytest >= 4.4.0
61-
pytest-cov
6259
pytest-env
6360
pytest-xdist
6461
pytest-rerunfailures
65-
codecov
6662
tests =
6763
%(test)s
6864
dev =

0 commit comments

Comments
 (0)