Skip to content

Commit 89d9913

Browse files
authored
Fix dtreeviz decisiontree_view input handling (#321)
1 parent 74ac0ff commit 89d9913

File tree

5 files changed

+37
-2
lines changed

5 files changed

+37
-2
lines changed

RELEASE_NOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
- **XGBoost 3.1+ compatibility**: Fixed handling of string-formatted predictions and `base_score` values returned by XGBoost 3.1+. Added robust string-to-numeric conversion with proper regex fallback to handle various string formats (e.g., `'[3.2967056E1]'`, `'[8.563135E-2,7.169811E-1,1.9738752E-1]'`)
2222
- **XGBoost SHAP initialization**: Fixed `base_score` conversion in both `get_params()` and booster's internal JSON configuration to ensure SHAP TreeExplainer initializes correctly with XGBoost 3.1+
2323
- **RandomForest dtreeviz compatibility**: Fixed dtype handling for `y_train` (now uses `int` instead of `int16`) and observation array conversion for `predict_path()` to work with newer dtreeviz versions
24+
- **Dtreeviz decisiontree_view**: Ensure observations are passed as numpy arrays to avoid pandas label lookup errors when dtreeviz indexes features by integer position
2425
- **Pandas deprecation warnings**: Removed deprecated `pd.option_context("future.no_silent_downcasting")` and `copy=False` parameter from `.infer_objects()` calls
2526
- **Runtime warnings**: Fixed divide-by-zero warnings in classification plots and residuals plots (log-ratio calculations) by adding proper zero checks and using `np.divide()` with `where` parameter
2627

explainerdashboard/explainers.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4869,8 +4869,13 @@ def decisiontree_view(self, tree_idx, index, show_just_path=False):
48694869

48704870
viz = DTreeVizAPI(self.shadow_trees[tree_idx])
48714871

4872+
x_row = self.get_X_row(index).squeeze()
4873+
if isinstance(x_row, pd.Series):
4874+
x = x_row.to_numpy()
4875+
else:
4876+
x = np.atleast_1d(np.asarray(x_row))
48724877
return viz.view(
4873-
x=self.get_X_row(index).squeeze(),
4878+
x=x,
48744879
fancy=False,
48754880
show_node_labels=False,
48764881
show_just_path=show_just_path,
@@ -5145,8 +5150,13 @@ def decisiontree_view(self, tree_idx, index, show_just_path=False, pos_label=Non
51455150

51465151
viz = DTreeVizAPI(self.shadow_trees[tree_idx])
51475152

5153+
x_row = self.get_X_row(index).squeeze()
5154+
if isinstance(x_row, pd.Series):
5155+
x = x_row.to_numpy()
5156+
else:
5157+
x = np.atleast_1d(np.asarray(x_row))
51485158
return viz.view(
5149-
x=self.get_X_row(index).squeeze(),
5159+
x=x,
51505160
fancy=False,
51515161
show_node_labels=False,
51525162
show_just_path=show_just_path,

tests/test_dtreeviz_contracts.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import inspect
2+
3+
import dtreeviz.trees as trees
4+
5+
6+
def test_dtreeviz_view_signature_includes_x():
7+
sig = inspect.signature(trees.DTreeVizAPI.view)
8+
assert "x" in sig.parameters

tests/test_randomforest_explainer.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,14 @@ def test_rfclas_decisionpath_df(precalculated_rf_classifier_explainer, test_name
2424
assert isinstance(df, pd.DataFrame)
2525

2626

27+
def test_rfclas_decisiontree_view_contract(precalculated_rf_classifier_explainer):
28+
precalculated_rf_classifier_explainer._graphviz_available = True
29+
render = precalculated_rf_classifier_explainer.decisiontree_view(
30+
tree_idx=0, index=0
31+
)
32+
assert isinstance(render, dtreeviz.utils.DTreeVizRender)
33+
34+
2735
def test_rfclas_plot_trees(precalculated_rf_classifier_explainer, test_names):
2836
fig = precalculated_rf_classifier_explainer.plot_trees(index=0)
2937
assert isinstance(fig, go.Figure)

tests/test_xgboost_treeviz.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,14 @@ def test_xgbclas_decisionpath_df(precalculated_xgb_classifier_explainer, test_na
2727
assert isinstance(df, pd.DataFrame)
2828

2929

30+
def test_xgbclas_decisiontree_view_contract(precalculated_xgb_classifier_explainer):
31+
precalculated_xgb_classifier_explainer._graphviz_available = True
32+
render = precalculated_xgb_classifier_explainer.decisiontree_view(
33+
tree_idx=0, index=0
34+
)
35+
assert isinstance(render, dtreeviz.utils.DTreeVizRender)
36+
37+
3038
def test_xgbclas_plot_trees(precalculated_xgb_classifier_explainer, test_names):
3139
fig = precalculated_xgb_classifier_explainer.plot_trees(index=0)
3240
assert isinstance(fig, go.Figure)

0 commit comments

Comments
 (0)