1+ import inspect
12import unittest
23
34import numpy as np
5+ import pkg_resources
46
57from Orange .classification import (
68 LogisticRegressionLearner ,
79 RandomForestLearner ,
810 SGDClassificationLearner ,
911 SVMLearner ,
1012 TreeLearner ,
13+ ThresholdLearner ,
1114)
1215from Orange .data import Table , Domain
1316from Orange .regression import LinearRegressionLearner
14- from Orange .tests .test_classification import LearnerAccessibility
15- from Orange .tests import test_regression
17+ from Orange .tests import test_regression , test_classification
1618from Orange .widgets .data import owcolor
1719from orangecontrib .explain .explainer import (
1820 compute_colors ,
@@ -159,12 +161,17 @@ def test_class_not_predicted(self):
159161 # missing class has all shap values 0
160162 self .assertTrue (not np .any (shap_values [2 ].sum ()))
161163
162- @unittest .skip ("Enable when learners fixed" )
163164 def test_all_classifiers (self ):
164165 """ Test explanation for all classifiers """
165- for learner in LearnerAccessibility .all_learners (None ):
166+ for learner in test_classification .all_learners ():
166167 with self .subTest (learner .name ):
167- model = learner (self .iris )
168+ if learner == ThresholdLearner :
169+ # ThresholdLearner require binary class
170+ continue
171+ kwargs = {}
172+ if "base_learner" in inspect .signature (learner ).parameters :
173+ kwargs = {"base_learner" : LogisticRegressionLearner ()}
174+ model = learner (** kwargs )(self .iris )
168175 shap_values , _ , _ , _ = compute_shap_values (
169176 model , self .iris , self .iris
170177 )
@@ -176,7 +183,7 @@ def test_all_classifiers(self):
176183
177184 @unittest .skipIf (
178185 not hasattr (test_regression , "all_learners" ),
179- "all_learners not available in Orange < 3.26"
186+ "all_learners not available in Orange < 3.26" ,
180187 )
181188 def test_all_regressors (self ):
182189 """ Test explanation for all regressors """
@@ -569,6 +576,16 @@ def test_no_class(self):
569576 self .assertTupleEqual (self .iris .X .shape , shap_values [0 ].shape )
570577 self .assertTupleEqual ((len (self .iris ),), sample_mask .shape )
571578
579+ def test_remove_calibration_workaround (self ):
580+ """
581+ When this test start to fail remove the workaround in
582+ explainer.py-207:220 if allready fixed - revert the pullrequest
583+ that adds those lines.
584+ """
585+ self .assertGreater (
586+ "3.29.0" , pkg_resources .get_distribution ("orange3" ).version
587+ )
588+
572589
573590if __name__ == "__main__" :
574591 unittest .main ()
0 commit comments