Commit 25da586

Merge pull request #14 from artefactory/dev_paper_submission
Add last modifications on paper before submission
2 parents 6147623 + 560e7d3 commit 25da586

File tree: 7 files changed (+118, -68 lines)

README.md

Lines changed: 1 addition & 2 deletions

@@ -42,8 +42,7 @@ pip install woodtapper
 
 **From this repository, within a pip/conda/mamba environment (python=3.12)**:
 ```bash
-pip install -r requirements.txt
-pip install -e '.[dev]'
+pip install -e .[dev,docs]
 ```
 
 ## 🌿 WoodTapper RulesExtraction module

docs/2_tutorials_example_exp.md

Lines changed: 1 addition & 1 deletion

@@ -54,5 +54,5 @@ RF.fit(X_train, y_train)
 
 # Load an existing RandomForestClassifier into the explainer
 RFExplained = RandomForestClassifierExplained.load_forest(RandomForestClassifierExplained,RF,X_train,y_train)
-X_explain, y_explain = RFExplained.explanation(X_test)
+Xy_explain = RFExplained.explanation(X_test)
 ```

docs/installation.md

Lines changed: 1 addition & 5 deletions

@@ -16,11 +16,7 @@ git clone https://github.com/artefactory/woodtapper.git
 ```
 And install the required packages into your environment (conda, mamba or pip):
 ```bash
-pip install -r requirements.txt
-```
-Then run the following command from the repository root directory :
-```
-pip install -e .[dev]
+pip install -e .[dev,docs]
 ```
 
 ## Dependencies
Binary file (2.39 KB) not shown.

paper/paper.md

Lines changed: 9 additions & 9 deletions
@@ -1,5 +1,5 @@
 ---
-title: 'WoodTapper: a Python package for tapping decision tree ensembles'
+title: 'WoodTapper: a Python package for explaining decision tree ensembles'
 tags:
 - Python
 - Machine Learning
@@ -55,28 +55,28 @@ In a tree $\mathcal{T}$, we denote the path of successive splits from the root n
 $$
 \mathcal{P} = \{(j_k,r_k,s_k), k=1, \dots, K\},
 $$
-where $K$ is the path length, $j_k$ is the selected feature at depth $k$, $r_k$ the selected splitting position along $X^{(j_k)}$ and $s_k$ the corresponding sign (either $\leq$ corresponding to the left node or $>$ corresponding to the right node).
-Thus, each path defines a hyperrectangle in the input space, denoted $\hat{H}(\mathcal{P}) \subset \mathbb{R}^p$. Hence, each path can be associated with a rule function $\hat{g}_{\mathcal{D},\mathcal{P}}$, that returns the mean of $Y$ from the training sample inside and outside of $\hat{H}(\mathcal{P})$:
+where $K$ is the path length, $j_k \in \{1, \dots, p\}$ is the selected feature at depth $k$, $r_k \in \mathbb{R}$ the selected splitting position along $x^{(j_k)}$, and $s_k$ the corresponding sign (either $\leq$, corresponding to the left node, or $>$, corresponding to the right node).
+Thus, each path defines a hyperrectangle in the input space, denoted $\hat{H}(\mathcal{P}) \subset \mathbb{R}^p$. Hence, each path can be associated with a rule function $\hat{g}_{\mathcal{P}}$ that returns the mean of $Y$ over the training samples inside and outside of $\hat{H}(\mathcal{P})$:
 $$
 \hat{g}_{\mathcal{P}}(x) =
 \begin{cases}
 \frac{\sum_{i=1}^{n}y_i \mathbb{I}_{\{x_i \in \hat{H}(\mathcal{P})\}}}{\sum_{i=1}^{n} \mathbb{I}_{\{x_i \in \hat{H}(\mathcal{P})\}}} \text{ if } x \in \hat{H}(\mathcal{P})\\
 \frac{\sum_{i=1}^{n}y_i \mathbb{I}_{\{x_i \not\in \hat{H}(\mathcal{P})\}}}{\sum_{i=1}^{n} \mathbb{I}_{\{x_i \not\in \hat{H}(\mathcal{P})\}}} \text{ otherwise }.
 \end{cases}
 $$
-We suppose we have a set of trees $\{\mathcal{T}_m, m=1, \dots, M \}$ from a random forest, each grown with randomness $\Theta_m$. For a path $\mathcal{P}$, we estimate the rule probability $p\left(\mathcal{P}\right)$ via Monte-Carlo sampling with $\hat{p}$,
+We suppose we have a set of trees $\{\mathcal{T}_m, m=1, \dots, M \}$ from a tree-ensemble procedure, each grown with randomness $\Theta_m$. We denote by $\Pi$ the set of all possible paths from $\{\mathcal{T}_m, m=1, \dots, M \}$. For a path $\mathcal{P} \in \Pi$, we estimate the rule probability $p\left(\mathcal{P}\right)$ via Monte-Carlo sampling with $\hat{p}\left(\mathcal{P}\right)$:
 $$
-\hat{p}_{}\left(\mathcal{P}\right) = \frac{1}{M} \sum_{m=1}^{M} \mathbb{1}_{\{\mathcal{P} \in \mathcal{T}(\Theta_m,\mathcal{D}_n)\}},
+\hat{p}\left(\mathcal{P}\right) = \frac{1}{M} \sum_{m=1}^{M} \mathbb{1}_{\{\mathcal{P} \in \mathcal{T}(\Theta_m,\mathcal{D}_n)\}},
 $$
-which corresponds to the probability that the path $\mathcal{P}$ belongs to the set of trees $\{\mathcal{T}_m, m=1, \dots, M \}$. We denote by $\Pi$ the set of extracted rules from $\{\mathcal{T}_m, m=1, \dots, M \}$.
+which corresponds to the empirical probability that the path $\mathcal{P} \in \Pi$ belongs to the set of trees $\{\mathcal{T}_m, m=1, \dots, M \}$.
 
 The set of final rules is $\{\hat{g}_{\mathcal{P}}, \mathcal{P} \in \hat{\mathcal{P}}_{p_0}\}$ where $\hat{\mathcal{P}}_{p_0} = \left\{ \mathcal{P} \in \Pi, \, \hat{p}(\mathcal{P}) > p_0\right\}$ with $p_0 \in [0,1)$. The final rules are aggregated as follows to build the final estimator:
 $$
 \hat{\eta}_{p_0}(x) = \frac{1}{|\hat{\mathcal{P}}_{p_0}|} \sum_{\mathcal{P} \in \hat{\mathcal{P}}_{p_0}} \hat{g}_{\mathcal{P}}(x).
 $$
 
 So far, we have focused on binary classification for clarity.
-We also implemented SIRUS for regression, where final rules are aggregated using weights learned via ridge regression. Our implementation extends SIRUS to multiclass classification (not available in the original R version) as well as regression. It also leverages scikit-learn's implementations for tree-based models fitting.
+We also implemented the rule extractor for regression, where final rules are aggregated using weights learned via ridge regression. Our implementation extends SIRUS, i.e. rules extracted from a random forest, to multiclass classification (not available in the original R version). Finally, our implementation also leverages scikit-learn's implementations for fitting tree-based models.
 
 ## Implementation and running time
 WoodTapper adheres to the scikit-learn [@pedregosa2011scikit] estimator interface, providing familiar methods such as $fit$, $predict$, and $get\_params$. This design enables smooth integration with existing workflows involving pipelines, cross-validation, and model selection (see Table \ref{tab:comparison}).
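The rule function and rule probability defined above can be checked on a toy example. The following sketch is plain Python and not part of the WoodTapper API; the data, the single-split path, and the per-tree booleans are all made up for illustration.

```python
import random

random.seed(0)
# Toy training sample: y depends only on whether x[0] > 0.3.
X = [[random.gauss(0, 1), random.gauss(0, 1)] for _ in range(200)]
y = [1 if x[0] > 0.3 else 0 for x in X]

# A path P = {(j_k, r_k, s_k)}: here a single split "feature 0 > 0.3".
path = [(0, 0.3, ">")]

def in_H(x, path):
    # Membership in the hyperrectangle H(P) defined by the path.
    return all((x[j] > r) if s == ">" else (x[j] <= r) for j, r, s in path)

inside = [in_H(x, path) for x in X]

def g_hat(x):
    # Rule function: mean of y inside (resp. outside) H(P).
    members = [yi for yi, flag in zip(y, inside) if flag == in_H(x, path)]
    return sum(members) / len(members)

# Monte-Carlo rule probability: fraction of the M trees whose structure
# contains the path (booleans are invented for the illustration).
contains_path = [True, True, False, True]  # hypothetical M = 4 trees
p_hat = sum(contains_path) / len(contains_path)

print(g_hat([1.0, 0.0]), g_hat([-1.0, 0.0]), p_hat)  # 1.0 0.0 0.75
```

Because y here is exactly the indicator of the hyperrectangle, the rule means are 1 inside and 0 outside, which makes the two branches of the case distinction easy to verify by eye.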
@@ -126,7 +126,7 @@ We compare the rules produced by the original SIRUS (R) and our Python implement
 ## Formulation
 
 The $\texttt{ExampleExplanation}$ module of WoodTapper is independent of the rule extraction module and provides example-based explainability.
-It enables tree-based models to identify the $l \in \mathbb{N}$ most similar training samples to $x$, using the similarity measure induced by random forests [@breiman2001random;@grf].
+It enables tree-based models to identify the $l \in \mathbb{N}$ most similar training samples to $x$, using the similarity measure induced by generalized random forests [@breiman2001random;@grf].
 For a new sample $x$ with unknown label and $\mathcal{T}_m$ a decision tree, let $\mathcal{L}_m(x)$ denote the set of training samples that share the same leaf as $x$ in tree $\mathcal{T}_m$, for $m = 1, \dots, M$.
 Letting $w(x,x_i)$ be the similarity between $x$ and $x_i$, we have
 $$
@@ -138,7 +138,7 @@ Finally, the $l$ training samples with the highest $w(x,x_i)$ values, along with
 The $\textit{skgrf}$ [@skgrf] package is an interface for using the R implementation of generalized random forests in Python. $\textit{skgrf}$ provides specific classifiers for specific learning tasks (causal inference, quantile regression, ...). For each task, the user can compute the kernel weights, which are equivalent to our leaf frequency match introduced above. Thus, we compare the kernel-weight derivation of $\textit{skgrf}$ to our $\texttt{ExampleExplanation}$ module. We stress that our $\texttt{ExampleExplanation}$ is designed for usual tree-based models such as random forests or extra trees, and not specifically for causal inference or quantile regression. Thus, the tree building (splitting criterion) of our forests differs from that of $\textit{skgrf}$.
 
 ## Implementation and running time
-As for SIRUS, our Python implementation of $\texttt{ExampleExplanation}$ adheres to the scikit-learn interface. Our $\texttt{ExampleExplanation}$ module is agnostic to the underlying tree ensemble, and can be used with random forests or extra trees (\ref{tab:comparison-grf}). For each ensemble type, a subclass inherits both the original scikit-learn class and our implemented class. The standard $\texttt{fit}$ and $\texttt{predict}$ methods remain unchanged, while an additional $\texttt{explain}$ method provides example-based explanations for new samples. This allows users to train and predict using standard scikit-learn workflows, while enabling access to $\texttt{ExampleExplanation}$ for interpretability analyses. We also have imlemented a method to load an already trained tree-basedd model into an $\texttt{ExampleExplanation}$ classifier.
+As for SIRUS, our Python implementation of $\texttt{ExampleExplanation}$ adheres to the scikit-learn interface. Our $\texttt{ExampleExplanation}$ module is agnostic to the underlying tree ensemble, and can be used with random forests or extra trees (\ref{tab:comparison-grf}). The standard $\texttt{fit}$ and $\texttt{predict}$ methods remain unchanged, while an additional $\texttt{explain}$ method provides example-based explanations for new samples. This allows users to train and predict using standard scikit-learn workflows, while enabling access to $\texttt{ExampleExplanation}$ for interpretability analyses. We have also implemented a method to load an already trained tree-based model into an $\texttt{ExampleExplanation}$ classifier.
 
 : **Comparison of GRF weight computations in several Python packages.**\label{tab:comparison-grf}
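The leaf-based similarity described in the Formulation section can be hand-rolled on toy leaf assignments. This sketch is not the WoodTapper or skgrf API; the leaf ids and the leaf-size normalization are illustrative assumptions in the spirit of generalized random forests.

```python
# leaves[m][i] = leaf id of training sample i in tree m (toy values).
leaves = [
    [0, 0, 1, 1],  # tree 1
    [2, 2, 2, 3],  # tree 2
]
x_leaves = [0, 2]  # leaf reached by the new sample x in each tree

M, n = len(leaves), len(leaves[0])
w = [0.0] * n
for m in range(M):
    # Training samples sharing x's leaf in tree m: the set L_m(x).
    members = [i for i in range(n) if leaves[m][i] == x_leaves[m]]
    for i in members:
        w[i] += 1.0 / (M * len(members))  # leaf-size-normalized co-occurrence

l = 2  # number of most similar training samples to return
top = sorted(range(n), key=lambda i: w[i], reverse=True)[:l]
print(top, [round(v, 3) for v in w])  # [0, 1] [0.417, 0.417, 0.167, 0.0]
```

With this normalization the weights for a given x sum to one over the training set, which is the property that makes them usable as kernel weights.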
144144

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ authors = [
     {name = "Abdoulaye SAKHO", email = "abdoulaye7020@gmail.com"},
     {name = "artefactory", email = "abdoulaye.sakho@artefact.com"},
 ]
-version = "0.0.11"
+version = "0.0.12"
 description = "A Python toolbox for interpretable and explainable tree ensembles."
 readme = "README.md"
 license = "MIT"

woodtapper/extract_rules/visualization.py

Lines changed: 105 additions & 50 deletions
@@ -8,34 +8,38 @@
 
 
 def show_rules(
-    RulesExtractorModel, max_rules=9, target_class_index=1, is_regression=False
+    RulesExtractorModel,
+    max_rules=9,
+    target_class_index=1,
+    is_regression=False,
+    value_mappings=None,
 ):
     """
-    Display the rules in a structured format, showing the conditions and associated probabilities for a specified target class.
+    Display the rules in a structured format.
+
     Parameters
     ----------
     RulesExtractorModel : object
-        The fitted rules extraction model containing the rules and probabilities.
-    max_rules : int, optional (default=9)
-        The maximum number of rules to display.
-    target_class_index : int, optional (default=1)
-        The index of the target class for which to display probabilities.
-    list_indices_features_bin : list of int, optional (default=None)
-        List of feature indices that are binary (0/1) for special formatting.
-    Returns
-    ----------
-    None
-    1. Validate the presence of necessary attributes in the model.
-    2. Extract rules and their associated probabilities.
-    3. Format and display the rules in a tabular format.
-    4. Include estimated average rates for the specified target class.
-    5. Handle feature names for better readability, using provided mappings if available.
-    6. Adjust formatting for binary features if specified.
-    7. Ensure that the display is clear and informative, with appropriate headers and alignment.
-    8. If the model lacks the required attributes, print an error message and exit.
-    9. If there are no rules to display, print a corresponding message and exit.
-    10. Calculate and display the estimated average probability for the target class based on 'else' clauses.
-    11. Print the rules along with their conditions, 'then' probabilities, and 'else' probabilities in a structured table.
+        Fitted rules extraction model.
+    max_rules : int, default=9
+        Max number of rules to display.
+    target_class_index : int, default=1
+        Class index whose probability to show (classification).
+    is_regression : bool, default=False
+        Switch to regression formatting.
+    value_mappings : dict, optional
+        {
+            <feature_index or feature_name>: {
+                <raw_value>: <display_string>,
+                ...
+            },
+            ...
+        }
+        For binary features with both 0 and 1 mapped, rules become:
+            FeatureName is <mapped_1>   (if sign_internal == "R")
+            FeatureName is <mapped_0>   (if sign_internal == "L")
+        (Instead of using negations.)
+
     """
     if (
         not hasattr(RulesExtractorModel, "rules_")
@@ -51,7 +55,10 @@ def show_rules(
         raise ValueError(
             "For regression, model must have 'list_probas_by_rules_without_coefficients' attribute."
         )
-    list_indices_features_bin = RulesExtractorModel._list_categorical_indexes
+
+    list_indices_features_bin = getattr(
+        RulesExtractorModel, "_list_categorical_indexes", None
+    )
 
     rules_all = RulesExtractorModel.rules_
     if is_regression:
@@ -76,27 +83,19 @@ def show_rules(
             "No rules to display. try to increase the number of rules extracted or check model fitting."
         )
 
-    # Attempt to build/use feature mapping
+    # Feature name mapping
     feature_mapping = None
-    if hasattr(
-        RulesExtractorModel, "feature_names_in_"
-    ):  # Standard scikit-learn attribute
-        # Create a mapping from index to name if feature_names_in_ is a list
+    if hasattr(RulesExtractorModel, "feature_names_in_"):
         feature_mapping = {
             i: name for i, name in enumerate(RulesExtractorModel.feature_names_in_)
         }
-    elif hasattr(
-        RulesExtractorModel, "feature_names_"
-    ):  # Custom attribute for feature names
+    elif hasattr(RulesExtractorModel, "feature_names_"):
         if isinstance(RulesExtractorModel.feature_names_, dict):
-            feature_mapping = (
-                RulesExtractorModel.feature_names_
-            )  # Assumes it's already index:name
+            feature_mapping = RulesExtractorModel.feature_names_
        elif isinstance(RulesExtractorModel.feature_names_, list):
             feature_mapping = {
                 i: name for i, name in enumerate(RulesExtractorModel.feature_names_)
             }
-    # If no mapping, column_name will default to using indices.
 
     base_ps_text = ""
     if not is_regression:
@@ -125,6 +124,51 @@ def show_rules(
     max_condition_len = 0
     condition_strings_for_rules = []
 
+    def _map_value(dim, dim_name, raw_val):
+        if value_mappings is None:
+            return None
+        candidates = [dim]
+        if dim_name is not None:
+            candidates.append(dim_name)
+        for c in candidates:
+            if c in value_mappings:
+                nested = value_mappings[c]
+                if raw_val in nested:
+                    return nested[raw_val]
+                if isinstance(raw_val, (float, np.floating)) and int(raw_val) in nested:
+                    return nested[int(raw_val)]
+        return None
+
+    def _format_binary_condition(dimension, column_name, sign_internal):
+        # Determine which side of binary (0 or 1) the rule represents.
+        positive_val = 1
+        negative_val = 0
+        # Try to map both
+        mapped_pos = _map_value(dimension, column_name, positive_val)
+        mapped_neg = _map_value(dimension, column_name, negative_val)
+
+        # If both mapped, choose directly
+        if mapped_pos is not None and mapped_neg is not None:
+            if sign_internal == "R":  # >
+                return f"{column_name} is {mapped_pos}"
+            else:  # "<=" side
+                return f"{column_name} is {mapped_neg}"
+        # If only one mapped
+        if mapped_pos is not None:
+            if sign_internal == "R":
+                return f"{column_name} is {mapped_pos}"
+            else:
+                return f"{column_name} is not {mapped_pos}"
+        if mapped_neg is not None:
+            if sign_internal == "L":
+                return f"{column_name} is {mapped_neg}"
+            else:
+                return f"{column_name} is not {mapped_neg}"
+
+        # Fallback numeric
+        raw_indicator = 0 if sign_internal == "L" else 1
+        return f"{column_name} is {raw_indicator}"
+
     for i in range(num_rules_to_show):
         current_rule_conditions = rules_all[i]
         condition_parts_str = []
@@ -133,31 +177,44 @@ def show_rules(
                 rule=current_rule_conditions[j]
             )
 
-            column_name = f"Feature[{dimension}]"  # Default if no mapping
+            column_name = f"Feature[{dimension}]"
             if feature_mapping and dimension in feature_mapping:
                 column_name = feature_mapping[dimension]
             elif (
                 feature_mapping
                 and isinstance(dimension, str)
                 and dimension in feature_mapping.values()
             ):
-                # If dimension is already a name that's in the mapping's values (less common for index)
                 column_name = dimension
-            if (
+
+            is_binary = (
                 list_indices_features_bin is not None
                 and dimension in list_indices_features_bin
-            ):
-                sign_display = "is"  # if sign_internal == "L" else "is not"
-                # treshold_display = str(treshold)
-                treshold_display = str(0) if sign_internal == "L" else str(1)
+            )
+
+            if is_binary:
+                condition_parts_str.append(
+                    _format_binary_condition(dimension, column_name, sign_internal)
+                )
             else:
                 sign_display = "<=" if sign_internal == "L" else ">"
+                if isinstance(treshold, float):
+                    treshold_display_raw = float(f"{treshold:.2f}")
+                else:
+                    treshold_display_raw = treshold
+                mapped = _map_value(dimension, column_name, treshold_display_raw)
                 treshold_display = (
-                    f"{treshold:.2f}" if isinstance(treshold, float) else str(treshold)
+                    mapped
+                    if mapped is not None
+                    else (
+                        f"{treshold:.2f}"
+                        if isinstance(treshold, float)
+                        else str(treshold)
+                    )
                 )
-            condition_parts_str.append(
-                f"{column_name} {sign_display} {treshold_display}"
-            )
+                condition_parts_str.append(
+                    f"{column_name} {sign_display} {treshold_display}"
+                )
 
         full_condition_str = " & ".join(condition_parts_str)
         condition_strings_for_rules.append(full_condition_str)
@@ -187,12 +244,10 @@ def show_rules(
             then_val_str = f"{p_s_if_true:.2f}"
             p_s_if_false = prob_if_false_list
             else_val_str = f"{p_s_if_false:.2f} | coeff={coefficients_all[i]:.2f}"
-
-        else:  # classification
+        else:
             if prob_if_true_list and len(prob_if_true_list) > target_class_index:
                 p_s_if_true = prob_if_true_list[target_class_index] * 100
                 then_val_str = f"{p_s_if_true:.0f}%"
-
             if prob_if_false_list and len(prob_if_false_list) > target_class_index:
                 p_s_if_false = prob_if_false_list[target_class_index] * 100
                 else_val_str = f"{p_s_if_false:.0f}%"
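The binary-condition formatting added in this commit can be illustrated standalone. The function below re-implements the core of `_format_binary_condition` outside woodtapper for demonstration; it is a simplified sketch with illustrative names, not the shipped code.

```python
# Toy value_mappings: feature name -> {raw value -> display string}.
value_mappings = {"sex": {0: "female", 1: "male"}}

def format_binary_condition(column_name, sign_internal, mappings):
    # "R" means the rule took the "> threshold" branch, i.e. the value 1;
    # "L" is the "<= threshold" branch, i.e. the value 0.
    nested = mappings.get(column_name, {})
    mapped_pos, mapped_neg = nested.get(1), nested.get(0)
    if mapped_pos is not None and mapped_neg is not None:
        return f"{column_name} is {mapped_pos if sign_internal == 'R' else mapped_neg}"
    # Fallback: raw 0/1 indicator, mirroring the unmapped numeric branch.
    return f"{column_name} is {1 if sign_internal == 'R' else 0}"

print(format_binary_condition("sex", "R", value_mappings))     # sex is male
print(format_binary_condition("sex", "L", value_mappings))     # sex is female
print(format_binary_condition("smoker", "L", value_mappings))  # smoker is 0
```

Mapping both 0 and 1 lets the rule read "sex is female" rather than the negation "sex is not male", which is the readability improvement this commit targets.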
