Skip to content

Commit 4081fbe

Browse files
authored
Merge pull request #10 from PrimozGodec/fix-calibration
Temporarily fix explainer to work with calibration modelom
2 parents 8766100 + ca754f0 commit 4081fbe

File tree

2 files changed

+47
-13
lines changed

2 files changed

+47
-13
lines changed

orangecontrib/explain/explainer.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,7 @@ def _explain_trees(
102102
for i in range(0, len(data_sample), batch_size):
103103
progress_callback(i / len(data_sample))
104104
batch = data_sample.X[i : i + batch_size]
105-
shap_values.append(
106-
explainer.shap_values(batch, check_additivity=False)
107-
)
105+
shap_values.append(explainer.shap_values(batch, check_additivity=False))
108106

109107
shap_values = _join_shap_values(shap_values)
110108
base_value = explainer.expected_value
@@ -152,7 +150,9 @@ def _explain_other_models(
152150
for i, row in enumerate(data_sample.X):
153151
progress_callback(i / len(data_sample))
154152
shap_values.append(
155-
explainer.shap_values(row, nsamples=100, silent=True, l1_reg="num_features(90)")
153+
explainer.shap_values(
154+
row, nsamples=100, silent=True, l1_reg="num_features(90)"
155+
)
156156
)
157157
return (
158158
_join_shap_values(shap_values),
@@ -205,8 +205,24 @@ def compute_shap_values(
205205
progress_callback = dummy_callback
206206
progress_callback(0, "Computing explanation ...")
207207

208-
data_transformed = model.data_to_model_domain(data)
209-
reference_data_transformed = model.data_to_model_domain(reference_data)
208+
#### workaround for bug with calibration
209+
#### remove when fixed
210+
from Orange.classification import (
211+
ThresholdClassifier,
212+
CalibratedClassifier,
213+
)
214+
215+
trans_model = model
216+
while isinstance(
217+
trans_model, (ThresholdClassifier, CalibratedClassifier)
218+
):
219+
trans_model = trans_model.base_model
220+
#### end of workaround for bug with calibration
221+
222+
data_transformed = trans_model.data_to_model_domain(data)
223+
reference_data_transformed = trans_model.data_to_model_domain(
224+
reference_data
225+
)
210226

211227
shap_values, sample_mask, base_value = _explain_trees(
212228
model,
@@ -422,7 +438,8 @@ def explain_predictions(
422438
Domain(data.domain.attributes, None, data.domain.metas)
423439
)
424440
predictions = model(
425-
classless_data, model.Probs if model.domain.class_var.is_discrete else model.Value
441+
classless_data,
442+
model.Probs if model.domain.class_var.is_discrete else model.Value,
426443
)
427444
# for regression - predictions array is 1d transform it shape N x 1
428445
if predictions.ndim == 1:

orangecontrib/explain/tests/test_explainer.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
1+
import inspect
12
import unittest
23

34
import numpy as np
5+
import pkg_resources
46

57
from Orange.classification import (
68
LogisticRegressionLearner,
79
RandomForestLearner,
810
SGDClassificationLearner,
911
SVMLearner,
1012
TreeLearner,
13+
ThresholdLearner,
1114
)
1215
from Orange.data import Table, Domain
1316
from Orange.regression import LinearRegressionLearner
14-
from Orange.tests.test_classification import LearnerAccessibility
15-
from Orange.tests import test_regression
17+
from Orange.tests import test_regression, test_classification
1618
from Orange.widgets.data import owcolor
1719
from orangecontrib.explain.explainer import (
1820
compute_colors,
@@ -159,12 +161,17 @@ def test_class_not_predicted(self):
159161
# missing class has all shap values 0
160162
self.assertTrue(not np.any(shap_values[2].sum()))
161163

162-
@unittest.skip("Enable when learners fixed")
163164
def test_all_classifiers(self):
164165
""" Test explanation for all classifiers """
165-
for learner in LearnerAccessibility.all_learners(None):
166+
for learner in test_classification.all_learners():
166167
with self.subTest(learner.name):
167-
model = learner(self.iris)
168+
if learner == ThresholdLearner:
169+
# ThresholdLearner require binary class
170+
continue
171+
kwargs = {}
172+
if "base_learner" in inspect.signature(learner).parameters:
173+
kwargs = {"base_learner": LogisticRegressionLearner()}
174+
model = learner(**kwargs)(self.iris)
168175
shap_values, _, _, _ = compute_shap_values(
169176
model, self.iris, self.iris
170177
)
@@ -176,7 +183,7 @@ def test_all_classifiers(self):
176183

177184
@unittest.skipIf(
178185
not hasattr(test_regression, "all_learners"),
179-
"all_learners not available in Orange < 3.26"
186+
"all_learners not available in Orange < 3.26",
180187
)
181188
def test_all_regressors(self):
182189
""" Test explanation for all regressors """
@@ -569,6 +576,16 @@ def test_no_class(self):
569576
self.assertTupleEqual(self.iris.X.shape, shap_values[0].shape)
570577
self.assertTupleEqual((len(self.iris),), sample_mask.shape)
571578

579+
def test_remove_calibration_workaround(self):
580+
"""
581+
When this test start to fail remove the workaround in
582+
explainer.py-207:220 if allready fixed - revert the pullrequest
583+
that adds those lines.
584+
"""
585+
self.assertGreater(
586+
"3.29.0", pkg_resources.get_distribution("orange3").version
587+
)
588+
572589

573590
if __name__ == "__main__":
574591
unittest.main()

0 commit comments

Comments
 (0)