Commit f1cf295

[API 2] add PDP (#517)
* minimal PDP
* docstring
* plot at end of cells
* remove print, try fix coverage
* add API page
* fix docstring
* typo
* remove left right
* fix plot
* clean example
1 parent db3ccc6 commit f1cf295

4 files changed: +232 -0 lines changed

docs/src/api.rst (10 additions, 0 deletions):

@@ -45,6 +45,16 @@ Feature Importance functions
    ensemble_clustered_inference
    ensemble_clustered_inference_pvalue

+Visualization
+=============
+
+.. autosummary::
+   :toctree: ./generated/api/class/
+   :template: class.rst
+
+   ~visualization.PDP
+
 Samplers
 ========

New file (92 additions, 0 deletions):

@@ -0,0 +1,92 @@
"""
Visualization with Partial Dependence Plots
===========================================

This example demonstrates how to create Partial Dependence Plots (PDPs). This
visualization method lets you examine a model's dependence on a single feature
or on a pair of features. The underlying implementation is built upon
sklearn.inspection.partial_dependence, which calculates the dependence by
taking the average response of an estimator across all possible values of the
target feature(s). We'll use the circles dataset to illustrate the basic
usage.
"""

# %%
# Loading the circles dataset
# ---------------------------
# We start by sampling a synthetic dataset using the `make_circles` function
# from `sklearn.datasets`.

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.datasets import make_circles

X, y = make_circles(n_samples=500, noise=0.1, factor=0.7, random_state=0)

# Visualizing the dataset
_, ax = plt.subplots()
sns.scatterplot(x=X[:, 0], y=X[:, 1], hue=y, ax=ax)
ax.set_xlabel("X0")
ax.set_ylabel("X1")
sns.despine(ax=ax)
c1 = plt.Circle((0, 0), 0.85, color="k", ls="--", fill=False, label="class boundary")
ax.add_patch(c1)
_ = ax.legend(loc="upper right")


# %%
# Training a classifier
# ---------------------
# Next, we train a model to solve the binary classification task presented by
# the non-linearly separable circles dataset. For this example, we'll use a
# gradient boosted tree ensemble, specifically the
# HistGradientBoostingClassifier from scikit-learn.

from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)
model = HistGradientBoostingClassifier(random_state=0)

model.fit(X_train, y_train)
y_pred = model.predict_proba(X_test)

auc = roc_auc_score(y_true=y_test, y_score=y_pred[:, 1])
print(f"ROC AUC on the test set: {auc:.2f}")


# %%
# Partial Dependence for an Individual Feature
# --------------------------------------------
# Once the model is fitted, we use the Partial Dependence Plot (PDP) to
# visualize its dependence on a single input feature (e.g., the first feature,
# :math:`X_0`). The resulting plot shows the average response of the model (on
# the :math:`y`-axis) for each possible value of the selected feature (on the
# :math:`x`-axis), with the averaging performed over all other features in the
# dataset.
#
# The plot also includes the marginal distribution of the selected feature
# along the :math:`x`-axis. This feature distribution is essential for
# identifying low-density regions in the data, where model predictions and the
# estimated partial dependence can be less reliable or extrapolated.


from hidimstat.visualization import PDP

# sphinx_gallery_thumbnail_number = 2
pdp = PDP(model)
_ = pdp.plot(X_test, features=0)


# %%
# Partial Dependence on a Pair of Features
# ----------------------------------------
# We can similarly visualize the dependence of the model on a pair of features
# (e.g., :math:`X_0` and :math:`X_1`). Here, the partial dependence is encoded
# by contour lines (level lines) across the 2D plot. The marginal distribution
# of each feature is also shown along the axes to help identify regions where
# the estimated dependence might be unreliable due to a low density of
# training data.

axes = pdp.plot(X_test, features=[0, 1], cmap="RdBu_r")
c1 = plt.Circle((0, 0), 0.85, color="k", ls="--", fill=False, zorder=10)
_ = axes[1, 0].add_patch(c1)
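
Not part of the commit, but to make the averaging concrete: the sketch below recomputes the 1D partial dependence by hand, pinning the feature of interest to each grid value and averaging the predicted probability over the dataset. It forces `method="brute"` so that scikit-learn also averages `predict_proba`; with the default method, gradient-boosted trees use the faster recursion algorithm, which works on the decision function and would not match.

import numpy as np
from sklearn.datasets import make_circles
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.inspection import partial_dependence

X, y = make_circles(n_samples=500, noise=0.1, factor=0.7, random_state=0)
model = HistGradientBoostingClassifier(random_state=0).fit(X, y)

# method="brute" makes sklearn average predict_proba over the rows of X,
# which is exactly what the loop below does.
pd_sklearn = partial_dependence(model, X, features=[0], method="brute")
grid = pd_sklearn["grid_values"][0]

manual = np.empty_like(grid)
for i, value in enumerate(grid):
    X_pinned = X.copy()
    X_pinned[:, 0] = value  # pin feature 0 to the current grid value
    # average response over the empirical distribution of the other features
    manual[i] = model.predict_proba(X_pinned)[:, 1].mean()

print(np.allclose(manual, pd_sklearn["average"][0]))  # True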
New file (3 additions, 0 deletions):

@@ -0,0 +1,3 @@
from .partial_dependence_plot import PDP

__all__ = ["PDP"]
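
The implementation that follows indexes the result of `sklearn.inspection.partial_dependence` via the keys `grid_values` and `average`. As a quick orientation, here is a minimal probe of that return structure, assuming scikit-learn >= 1.3 (older releases exposed the grid under `values` instead of `grid_values`); the dataset and classifier are arbitrary:

from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import partial_dependence

X, y = make_classification(n_samples=200, n_features=4, random_state=0)
clf = RandomForestClassifier(random_state=0).fit(X, y)

pd_1d = partial_dependence(clf, X, features=[0])
print(pd_1d["grid_values"][0].shape)  # (100,): grid of feature-0 values
print(pd_1d["average"].shape)         # (1, 100): one output, one value per grid point

pd_2d = partial_dependence(clf, X, features=[0, 1])
print(len(pd_2d["grid_values"]))      # 2: one grid per feature
print(pd_2d["average"][0].shape)      # (100, 100): PD surface over the two grids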
New file (127 additions, 0 deletions):

@@ -0,0 +1,127 @@
from copy import copy

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.inspection import partial_dependence


class PDP:
    r"""
    Partial Dependence Plot (PDP) visualization. This class is based on
    `sklearn.inspection.partial_dependence` to compute the partial dependence
    values and provides methods to plot 1D and 2D PDPs. For each realization
    of a feature or pair of features :math:`x_S`, the partial dependence
    :math:`f_S(x_S)` is defined as
    :math:`f_S(x_S) = \mathbb{E}_{X_{-S}}[f(x_S, X_{-S})]`,
    where :math:`X_{-S}` denotes all features except those in :math:`S`.

    Parameters
    ----------
    estimator : object
        A fitted scikit-learn estimator implementing `predict` or
        `predict_proba`.
    feature_names : list of str, optional
        Names of the features. If None, X0, X1, ... will be used.
    """

    def __init__(self, estimator, feature_names=None):
        self.estimator = estimator
        self.feature_names = feature_names

    def plot(self, X, features, cmap="viridis", **kwargs):
        """
        Plot the Partial Dependence Plot for the specified feature (1D) or
        pair of features (2D). The marginal distribution of the feature(s) is
        also displayed.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            The input data used to compute the partial dependence.
        features : int or list of int
            The feature index (for a 1D PDP) or list of two feature indices
            (for a 2D PDP).
        cmap : str, optional
            The colormap to use for the plot (2D PDP only). Default is
            "viridis".
        **kwargs : additional keyword arguments
            Additional keyword arguments passed to `sns.lineplot` for a 1D PDP
            or to `ax.contour` for a 2D PDP.
        """
        if isinstance(features, int):
            feature_ids = [features]
            plotting_func = self._plot_1d
        elif isinstance(features, list):
            if len(features) != 2:
                raise ValueError("Only 1D and 2D PDP plots are supported")
            feature_ids = copy(features)
            plotting_func = self._plot_2d
        else:
            raise TypeError("features must be an int or a list of two ints")

        if self.feature_names is not None:
            feature_names = [self.feature_names[idx] for idx in feature_ids]
        else:
            feature_names = [f"X{idx}" for idx in feature_ids]

        pd = partial_dependence(self.estimator, X, features=features)
        return plotting_func(pd, feature_names, cmap=cmap, **kwargs)

    @staticmethod
    def _plot_1d(pd, feature_names, cmap=None, **kwargs):
        # cmap is accepted for API symmetry with _plot_2d but is unused here.
        _, axes = plt.subplots(2, 1, height_ratios=[0.2, 1])

        # Marginal distribution of the feature (top)
        ax = axes[0]
        sns.kdeplot(pd["grid_values"][0], ax=ax, legend=False, fill=True)
        sns.rugplot(pd["grid_values"][0], ax=ax, height=0.25, legend=False)
        sns.despine(ax=ax, left=True)
        ax.spines["left"].set_visible(False)
        ax.spines["bottom"].set_visible(True)
        ax.xaxis.set_ticks([])
        ax.yaxis.set_visible(False)

        # Partial dependence curve (bottom)
        ax = axes[1]
        sns.lineplot(x=pd["grid_values"][0], y=pd["average"][0], ax=ax, **kwargs)
        ax.set_xlabel(feature_names[0])
        ax.set_ylabel("Partial Dependence")
        sns.despine(ax=ax)
        plt.tight_layout()
        return axes

    @staticmethod
    def _plot_2d(pd, feature_names, cmap="viridis", **kwargs):
        x = pd["grid_values"][0]
        y = pd["grid_values"][1]
        z = pd["average"][0]

        xx, yy = np.meshgrid(x, y, indexing="ij")

        _, axes = plt.subplots(
            2, 2, figsize=(8, 6), height_ratios=[0.2, 1], width_ratios=[1, 0.2]
        )
        # Contour plot of the partial dependence over the feature grid
        ax = axes[1, 0]
        contour = ax.contour(xx, yy, z, cmap=cmap, **kwargs)
        ax.set_xlabel(feature_names[0])
        ax.set_ylabel(feature_names[1])
        ax.clabel(contour, inline=True, fontsize=10)
        sns.despine(ax=ax)

        # Marginal distribution of the first feature (top)
        ax = axes[0, 0]
        sns.kdeplot(x, ax=ax, legend=False, fill=True)
        sns.rugplot(x, ax=ax, height=0.25, legend=False)
        sns.despine(ax=ax)
        ax.spines["left"].set_visible(False)
        ax.spines["bottom"].set_visible(True)
        ax.xaxis.set_ticks([])
        ax.yaxis.set_visible(False)

        # Marginal distribution of the second feature (right)
        ax = axes[1, 1]
        sns.kdeplot(y=y, ax=ax, legend=False, fill=True)
        sns.rugplot(y=y, ax=ax, height=0.25, legend=False)
        sns.despine(ax=ax)
        ax.spines["bottom"].set_visible(False)
        ax.yaxis.set_ticks([])
        ax.xaxis.set_visible(False)

        # Hide the unused corner axes
        axes[0, 1].remove()
        plt.tight_layout()
        return axes
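
To round off, a hypothetical usage sketch of the class above, showing the `feature_names` labels and the keyword pass-through described in the docstring; the feature names are illustrative, not from the commit:

import matplotlib.pyplot as plt
from sklearn.datasets import make_circles
from sklearn.ensemble import HistGradientBoostingClassifier

from hidimstat.visualization import PDP

X, y = make_circles(n_samples=500, noise=0.1, factor=0.7, random_state=0)
model = HistGradientBoostingClassifier(random_state=0).fit(X, y)

# Custom names replace the default "X0"/"X1" axis labels (illustrative names).
pdp = PDP(model, feature_names=["radius_x", "radius_y"])

# 1D: extra keyword arguments are forwarded to sns.lineplot.
pdp.plot(X, features=0, linewidth=2)

# 2D: extra keyword arguments are forwarded to ax.contour.
pdp.plot(X, features=[0, 1], cmap="RdBu_r", levels=10)
plt.show()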
