Skip to content

Commit 6950db9

Browse files
authored
Customize what-if input ranges and rounding (#333)
1 parent 1a5fc74 commit 6950db9

File tree

6 files changed

+58
-8
lines changed

6 files changed

+58
-8
lines changed

RELEASE_NOTES.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
# Release Notes
22

33

4+
## Version 0.5.7:
5+
6+
### Bug Fixes
7+
- Allow FeatureInputComponent (what-if inputs) to customize numeric ranges and rounding, and apply min/max/step to inputs.
8+
49
## Version 0.5.6:
510

611
### Bug Fixes

TODO.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,9 @@
66
- Rules: link an issue when possible; include size S/M/L; mark blockers.
77

88
**Now**
9-
- [S/M][Components][#277] whatif input range/rounding customization.
9+
- [M][Explainers][#273] categorical columns with NaNs: sorting and column preservation.
1010

1111
**Next**
12-
- [M][Explainers][#273] categorical columns with NaNs: sorting and column preservation.
1312
- [S][Explainers][#270] Autogluon integration (coerce predict_proba to ndarray).
1413
- [M][Hub][#269] add_dashboard endpoint fails after first request (Flask blueprint lifecycle).
1514
- [M/L][Components][#262] add filters for random transaction selection in whatif tab.

explainerdashboard/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "0.5.6"
1+
__version__ = "0.5.7"
22

33
import logging
44
import sys

explainerdashboard/dashboard_components/overview_components.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from math import ceil
99

1010
import numpy as np
11+
import pandas as pd
1112
from pandas.api.types import is_bool_dtype
1213

1314
from dash import html, dcc, Input, Output
@@ -1091,6 +1092,8 @@ def __init__(
10911092
n_input_cols=4,
10921093
sort_features="shap",
10931094
fill_row_first=True,
1095+
feature_input_ranges=None,
1096+
round=2,
10941097
description=None,
10951098
**kwargs,
10961099
):
@@ -1116,6 +1119,9 @@ def __init__(
11161119
is 'shap' to sort by mean absolute shap value.
11171120
fill_row_first (bool, optional): if True most important features will
11181121
be on top row, if False they will be in most left column.
1122+
feature_input_ranges (dict, optional): dict mapping feature names to
1123+
(min, max) numeric ranges for input fields.
1124+
round (int, optional): number of decimals to round numeric ranges to.
11191125
description (str, optional): Tooltip to display when hover over
11201126
component title. When None default text is shown.
11211127
@@ -1131,6 +1137,8 @@ def __init__(
11311137
explainer, name="feature-input-index-" + self.name, **kwargs
11321138
)
11331139
self.index_name = "feature-input-index-" + self.name
1140+
self.feature_input_ranges = feature_input_ranges or {}
1141+
self.round = round
11341142

11351143
self._feature_callback_inputs = [
11361144
Input("feature-input-" + feature + "-input-" + self.name, "value")
@@ -1214,17 +1222,28 @@ def _generate_dash_input(self, col, onehot_cols, onehot_dict, cat_dict):
12141222
)
12151223
else:
12161224
col_values = self.explainer.X[col][lambda x: x != self.explainer.na_fill]
1217-
if is_bool_dtype(col_values):
1225+
if col in self.feature_input_ranges:
1226+
min_range, max_range = self.feature_input_ranges[col]
1227+
elif is_bool_dtype(col_values):
12181228
min_range = int(col_values.min())
12191229
max_range = int(col_values.max())
12201230
else:
1221-
min_range = np.round(col_values.min(), 2)
1222-
max_range = np.round(col_values.max(), 2)
1231+
min_range = np.round(col_values.min(), self.round)
1232+
max_range = np.round(col_values.max(), self.round)
1233+
1234+
if is_bool_dtype(col_values) or pd.api.types.is_integer_dtype(col_values):
1235+
step = 1
1236+
else:
1237+
step = 10 ** (-self.round)
12231238
return html.Div(
12241239
[
12251240
dbc.Label(col),
12261241
dbc.Input(
1227-
id="feature-input-" + col + "-input-" + self.name, type="number"
1242+
id="feature-input-" + col + "-input-" + self.name,
1243+
type="number",
1244+
min=min_range,
1245+
max=max_range,
1246+
step=step,
12281247
),
12291248
dbc.FormText(f"Range: {min_range}-{max_range}")
12301249
if not self.hide_range

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
44

55
[project]
66
name = "explainerdashboard"
7-
version = "0.5.6"
7+
version = "0.5.7"
88
description = "Quickly build Explainable AI dashboards that show the inner workings of so-called \"blackbox\" machine learning models."
99
readme = "README.md"
1010
requires-python = ">=3.10"

tests/test_feature_input_component.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,30 @@ def test_feature_input_component_handles_bool_columns(classifier_data):
2424

2525
layout = component.layout()
2626
assert layout is not None
27+
28+
29+
def test_feature_input_component_respects_custom_range_and_rounding(classifier_data):
30+
X_train, y_train, X_test, y_test = classifier_data
31+
32+
model = RandomForestClassifier(n_estimators=5, max_depth=2)
33+
model.fit(X_train, y_train)
34+
35+
explainer = ClassifierExplainer(model, X_test, y_test)
36+
component = FeatureInputComponent(
37+
explainer, feature_input_ranges={"Age": (0, 50)}, round=1
38+
)
39+
40+
age_div = next(
41+
div
42+
for div in component._feature_inputs
43+
if getattr(div.children[0], "children", None) == "Age"
44+
)
45+
age_input = age_div.children[1]
46+
range_text = age_div.children[2].children
47+
48+
props = age_input.to_plotly_json()["props"]
49+
50+
assert props.get("min") == 0
51+
assert props.get("max") == 50
52+
assert props.get("step") == 0.1
53+
assert range_text == "Range: 0-50"

0 commit comments

Comments
 (0)