Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: "3.10"
python-version: "3.11"

- name: Install dependencies
run: |
Expand Down
11 changes: 8 additions & 3 deletions orangecontrib/explain/explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,17 @@ def _join_shap_values(
of lists with the np.ndarray for each class, when explaining regression,
the result is the list of one np.ndarrays.
"""
if isinstance(shap_values[0], np.ndarray):
# regression
shape = shap_values[0].shape
if len(shape) == 1 or (len(shape) == 2 and shape[0] == 1):
# regression and xgb with two classes
return [np.vstack(shap_values)]
else:
# classification
return [np.vstack(s) for s in zip(*shap_values)]
if len(shape) == 3:
transformed = [(np.squeeze(v, axis=0)).T for v in shap_values]
else:
transformed = [v.T for v in shap_values]
return [np.vstack(s) for s in zip(*transformed)]


def _explain_trees(
Expand Down
7 changes: 5 additions & 2 deletions orangecontrib/explain/inspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import scipy.sparse as sp
from sklearn.inspection import partial_dependence
from sklearn.utils import Tags, TargetTags

from Orange.base import Model
from Orange.classification import Model as ClsModel
Expand Down Expand Up @@ -202,11 +203,13 @@ def dummy_fit(*_, **__):
model.fit = dummy_fit
model.fit_ = dummy_fit
if model.domain.class_var.is_discrete:
model._estimator_type = "classifier"
model.classes_ = np.array(model.domain.class_var.values)
estimator_type = "classifier"
else:
model._estimator_type = "regressor"
estimator_type = "regressor"

model.__sklearn_tags__ = lambda: Tags(estimator_type=estimator_type,
target_tags=TargetTags(required=True))
progress_callback(0.1)

dep = partial_dependence(model,
Expand Down
7 changes: 2 additions & 5 deletions orangecontrib/explain/widgets/owice.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
create_annotated_table
from Orange.widgets.utils.concurrent import TaskState, ConcurrentWidgetMixin
from Orange.widgets.utils.itemmodels import VariableListModel, DomainModel
from Orange.widgets.utils.multi_target import check_multiple_targets_input
from Orange.widgets.utils.sql import check_sql_input
from Orange.widgets.utils.widgetpreview import WidgetPreview
from Orange.widgets.visualize.owdistributions import LegendItem
Expand All @@ -37,10 +38,6 @@

from orangecontrib.explain.inspection import individual_condition_expectation
from orangewidget.utils.visual_settings_dlg import VisualSettingsDialog
try:
from Orange.widgets.utils.multi_target import check_multiple_targets_input
except ImportError:
check_multiple_targets_input = lambda f: f


class RunnerResults(SimpleNamespace):
Expand Down Expand Up @@ -734,7 +731,7 @@ def _apply_feature_sorting(self):
if self.order_by_importance:
def compute_score(feature):
values = self.__results_avgs[feature][self.target_index]
return -np.sum(np.abs(values - np.mean(values)))
return float(-np.sum(np.abs(values - np.mean(values))))

try:
if self.__results_avgs is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -281,19 +281,19 @@ def test_x_label(self):
self.send_signal(self.widget.Inputs.model, self.rf_cls)
self.wait_until_finished()
label: QGraphicsTextItem = self.widget.plot.bottom_axis.label
self.assertEqual(label.toPlainText(), "Decrease in AUC ")
self.assertIn("Decrease in AUC", label.toPlainText())

self.send_signal(self.widget.Inputs.data, self.housing)
self.send_signal(self.widget.Inputs.model, self.rf_reg)
self.wait_until_finished()
label: QGraphicsTextItem = self.widget.plot.bottom_axis.label
self.assertEqual(label.toPlainText(), "Decrease in R2 ")
self.assertIn("Decrease in R2", label.toPlainText())

score_cb: QComboBox = self.widget._score_combo
simulate.combobox_activate_item(score_cb, "MSE")
self.wait_until_finished()
label: QGraphicsTextItem = self.widget.plot.bottom_axis.label
self.assertEqual(label.toPlainText(), "Increase in MSE ")
self.assertIn("Increase in MSE", label.toPlainText())

@unittest.mock.patch("orangecontrib.explain.widgets."
"owpermutationimportance.OWPermutationImportance.run")
Expand Down
22 changes: 10 additions & 12 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,16 @@
]

INSTALL_REQUIRES = [
"AnyQt",
# shap's requirement, force users for numba to get updated because compatibility
# issues with numpy - completely remove this pin after october 2024
"numba >=0.58",
"numpy",
"Orange3 >=3.36.2",
"orange-canvas-core >=0.1.30",
"orange-widget-base >=4.22.0",
"pyqtgraph",
"scipy",
"shap==0.42.1",
"scikit-learn>=1.3.0",
"AnyQt>=0.2.0",
"Orange3>=3.39.0",
"orange-canvas-core>=0.2.5",
"orange-widget-base>=4.25.0",
"pandas>=2.2.2",
"scikit-learn>=1.6.0",
"scipy>=1.13.0",
"pyqtgraph>=0.13.1",
"numpy>=2.0.0",
"shap>=0.50.0",
]

EXTRAS_REQUIRE = {
Expand Down
15 changes: 9 additions & 6 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,16 @@ deps =
{env:PYQT_PYPI_NAME:PyQt5}=={env:PYQT_PYPI_VERSION:5.15.*}
{env:WEBENGINE_PYPI_NAME:PyQtWebEngine}=={env:WEBENGINE_PYPI_VERSION:5.15.*}
xgboost
oldest: orange3==3.36.2
oldest: orange-canvas-core==0.1.30
oldest: orange-widget-base==4.22.0
oldest: pandas==1.4.0
oldest: scikit-learn==1.3.0
oldest: scipy==1.9.0
oldest: orange3==3.39.0
oldest: orange-canvas-core==0.2.5
oldest: orange-widget-base==4.25.0
oldest: pandas~=2.2.2
oldest: scikit-learn~=1.6.0
oldest: scipy~=1.13.0
oldest: xgboost==2.0.0
oldest: pyqtgraph==0.13.1
oldest: numpy~=2.0.0
oldest: shap==0.50.0
latest: https://github.com/biolab/orange3/archive/refs/heads/master.zip#egg=orange3
latest: https://github.com/biolab/orange-canvas-core/archive/refs/heads/master.zip#egg=orange-canvas-core
latest: https://github.com/biolab/orange-widget-base/archive/refs/heads/master.zip#egg=orange-widget-base
Expand Down
Loading