Skip to content

Commit c00ac06

Browse files
authored
Fix FeatureInputComponent bool range crash (#328)
1 parent 36c7011 commit c00ac06

File tree

5 files changed

+41
-8
lines changed

5 files changed

+41
-8
lines changed

RELEASE_NOTES.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
# Release Notes
22

33

4+
## Version 0.5.6:
5+
6+
### Bug Fixes
7+
- Fix FeatureInputComponent range calculation for boolean columns (avoid np.round on bools) and add a regression test.
8+
49
## Version 0.5.5:
510

611
### Bug Fixes

explainerdashboard/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "0.5.5"
1+
__version__ = "0.5.6"
22

33
from .explainers import ClassifierExplainer, RegressionExplainer # noqa
44
from .dashboards import ExplainerDashboard, ExplainerHub, InlineExplainer # noqa

explainerdashboard/dashboard_components/overview_components.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from math import ceil
99

1010
import numpy as np
11+
from pandas.api.types import is_bool_dtype
1112

1213
from dash import html, dcc, Input, Output
1314
from dash.exceptions import PreventUpdate
@@ -1212,12 +1213,13 @@ def _generate_dash_input(self, col, onehot_cols, onehot_dict, cat_dict):
12121213
]
12131214
)
12141215
else:
1215-
min_range = np.round(
1216-
self.explainer.X[col][lambda x: x != self.explainer.na_fill].min(), 2
1217-
)
1218-
max_range = np.round(
1219-
self.explainer.X[col][lambda x: x != self.explainer.na_fill].max(), 2
1220-
)
1216+
col_values = self.explainer.X[col][lambda x: x != self.explainer.na_fill]
1217+
if is_bool_dtype(col_values):
1218+
min_range = int(col_values.min())
1219+
max_range = int(col_values.max())
1220+
else:
1221+
min_range = np.round(col_values.min(), 2)
1222+
max_range = np.round(col_values.max(), 2)
12211223
return html.Div(
12221224
[
12231225
dbc.Label(col),

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
44

55
[project]
66
name = "explainerdashboard"
7-
version = "0.5.5"
7+
version = "0.5.6"
88
description = "Quickly build Explainable AI dashboards that show the inner workings of so-called \"blackbox\" machine learning models."
99
readme = "README.md"
1010
requires-python = ">=3.10"
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from sklearn.ensemble import RandomForestClassifier
2+
3+
from explainerdashboard import ClassifierExplainer
4+
from explainerdashboard.dashboard_components.overview_components import (
5+
FeatureInputComponent,
6+
)
7+
8+
9+
def test_feature_input_component_handles_bool_columns(classifier_data):
10+
X_train, y_train, X_test, y_test = classifier_data
11+
12+
X_train = X_train.copy()
13+
X_test = X_test.copy()
14+
15+
cutoff = X_train["Age"].median()
16+
X_train["is_older"] = X_train["Age"] > cutoff
17+
X_test["is_older"] = X_test["Age"] > cutoff
18+
19+
model = RandomForestClassifier(n_estimators=5, max_depth=2)
20+
model.fit(X_train, y_train)
21+
22+
explainer = ClassifierExplainer(model, X_test, y_test)
23+
component = FeatureInputComponent(explainer)
24+
25+
layout = component.layout()
26+
assert layout is not None

0 commit comments

Comments
 (0)