|
| 1 | +from copy import copy |
| 2 | + |
| 3 | +import matplotlib.pyplot as plt |
| 4 | +import numpy as np |
| 5 | +import seaborn as sns |
| 6 | +from sklearn.inspection import partial_dependence |
| 7 | + |
| 8 | + |
class PDP:
    r"""
    Partial Dependence Plot (PDP) visualization.

    This class relies on `sklearn.inspection.partial_dependence` to compute the
    partial dependence values and provides methods to plot 1D and 2D PDPs. For
    each realization of a feature or pair of features :math:`x_S`, the partial
    dependence :math:`f_S(x_S)` is defined as
    :math:`f_S(x_S) = \mathbb{E}_{X_{-S}}[ f(x_S, X_{-S})]`,
    where :math:`X_{-S}` denotes all features except those in :math:`S`.

    Parameters
    ----------
    estimator : object
        A fitted scikit-learn estimator implementing `predict` or `predict_proba`.
    feature_names : list of str, optional
        Names of the features. If None, X0, X1, ... will be used.
    """

    def __init__(self, estimator, feature_names=None):
        self.estimator = estimator
        self.feature_names = feature_names

    def plot(self, X, features, cmap="viridis", **kwargs):
        """
        Plot the Partial Dependence Plot for the specified feature (1D) or pair
        of features (2D). The marginal distribution of the feature(s), taken
        from the corresponding column(s) of `X`, is displayed alongside.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            The input data used to compute the partial dependence.
        features : int or list of int
            The feature index (for 1D PDP) or list of two feature indices (for
            2D PDP). A tuple is accepted as well, and a single-element sequence
            is plotted as a 1D PDP.
        cmap : str, optional
            The colormap to use for the plot (only for 2D PDP). Default is
            "viridis".
        **kwargs : additional keyword arguments
            Additional keyword arguments passed to:
            - `sns.lineplot` for 1D PDP
            - `ax.contour` for 2D PDP

        Returns
        -------
        axes : numpy.ndarray of matplotlib Axes
            The array returned by `plt.subplots` (shape (2,) for 1D, (2, 2) for
            2D).

        Raises
        ------
        TypeError
            If `features` is neither an integer nor a list/tuple.
        ValueError
            If `features` is empty or contains more than two entries.
        """
        # Validate first so that bad input fails before any computation.
        feature_ids = self._normalize_features(features)

        if self.feature_names is not None:
            feature_names = [self.feature_names[idx] for idx in feature_ids]
        else:
            feature_names = [f"X{idx}" for idx in feature_ids]

        marginals = [self._feature_column(X, idx) for idx in feature_ids]
        pd = partial_dependence(self.estimator, X, features=feature_ids)

        plotting_func = self._plot_1d if len(feature_ids) == 1 else self._plot_2d
        return plotting_func(
            pd, feature_names, cmap=cmap, marginals=marginals, **kwargs
        )

    @staticmethod
    def _normalize_features(features):
        """Return `features` as a list of one or two entries, or raise."""
        if isinstance(features, (int, np.integer)):
            return [int(features)]
        if not isinstance(features, (list, tuple)):
            raise TypeError(
                "features must be an int or a list of one or two ints, "
                f"got {type(features).__name__}"
            )
        if len(features) > 2:
            raise ValueError("Only 1D and 2D PDP plots are supported")
        if not features:
            raise ValueError("features must contain at least one feature index")
        # Copy so the caller's sequence is never shared with downstream code.
        return list(features)

    @staticmethod
    def _feature_column(X, feature):
        """Extract the observed values of one feature from `X` as a 1D array."""
        if hasattr(X, "iloc"):
            # DataFrame-like input: labels go through [], positions through iloc.
            column = X[feature] if isinstance(feature, str) else X.iloc[:, feature]
            return np.asarray(column)
        if hasattr(X, "tocsc"):
            # Presumably a scipy sparse container; densify only the one column.
            return X.tocsc()[:, [feature]].toarray().ravel()
        return np.asarray(X)[:, feature]

    @staticmethod
    def _plot_1d(pd, feature_names, cmap=None, marginals=None, **kwargs):
        """
        Draw a 1D PDP with the feature's marginal distribution on top.

        `marginals`, when given, is a one-element list holding the observed
        feature values; when None the PDP grid values are used instead. `cmap`
        is accepted for signature parity with `_plot_2d` and is unused.
        """
        grid = pd["grid_values"][0]
        sample = grid if marginals is None else marginals[0]

        _, axes = plt.subplots(2, 1, height_ratios=[0.2, 1])
        marginal_ax, main_ax = axes

        # Plot partial dependence. Target the lower axes explicitly instead of
        # relying on pyplot's "current axes"; a caller-supplied `ax` still wins.
        kwargs.setdefault("ax", main_ax)
        sns.lineplot(x=grid, y=pd["average"][0], **kwargs)
        main_ax.set_xlabel(feature_names[0])
        main_ax.set_ylabel("Partial Dependence")
        sns.despine(ax=main_ax)

        # Plot the marginal distribution of the feature.
        sns.kdeplot(x=sample, ax=marginal_ax, legend=False, fill=True)
        sns.rugplot(x=sample, ax=marginal_ax, height=0.25, legend=False)
        sns.despine(ax=marginal_ax, left=True)
        marginal_ax.spines["bottom"].set_visible(True)
        marginal_ax.xaxis.set_ticks([])
        marginal_ax.yaxis.set_visible(False)
        # The marginal has no tick labels of its own, so it must span exactly
        # the same x-range as the PDP below to be readable against it.
        marginal_ax.set_xlim(main_ax.get_xlim())

        plt.tight_layout()
        return axes

    @staticmethod
    def _plot_2d(pd, feature_names, cmap="viridis", marginals=None, **kwargs):
        """
        Draw a 2D PDP as labelled contours with both marginals in the margins.

        `marginals`, when given, is a two-element list holding the observed
        values of the x- and y-feature; when None the PDP grid values are used
        instead.
        """
        x = pd["grid_values"][0]
        y = pd["grid_values"][1]
        z = pd["average"][0]
        x_sample, y_sample = (x, y) if marginals is None else marginals

        # "ij" indexing keeps the first axis of the mesh aligned with `x`, the
        # same layout as the first axis of `z`.
        xx, yy = np.meshgrid(x, y, indexing="ij")

        _, axes = plt.subplots(
            2, 2, figsize=(8, 6), height_ratios=[0.2, 1], width_ratios=[1, 0.2]
        )

        main_ax = axes[1, 0]
        contour = main_ax.contour(xx, yy, z, cmap=cmap, **kwargs)
        main_ax.set_xlabel(feature_names[0])
        main_ax.set_ylabel(feature_names[1])
        main_ax.clabel(contour, inline=True, fontsize=10)
        sns.despine(ax=main_ax)

        # Marginal of the x-feature, above the contour plot.
        top_ax = axes[0, 0]
        sns.kdeplot(x=x_sample, ax=top_ax, legend=False, fill=True)
        sns.rugplot(x=x_sample, ax=top_ax, height=0.25, legend=False)
        sns.despine(ax=top_ax)
        top_ax.spines["left"].set_visible(False)
        top_ax.spines["bottom"].set_visible(True)
        top_ax.xaxis.set_ticks([])
        top_ax.yaxis.set_visible(False)
        top_ax.set_xlim(main_ax.get_xlim())

        # Marginal of the y-feature, to the right of the contour plot.
        right_ax = axes[1, 1]
        sns.kdeplot(y=y_sample, ax=right_ax, legend=False, fill=True)
        sns.rugplot(y=y_sample, ax=right_ax, height=0.25, legend=False)
        sns.despine(ax=right_ax)
        right_ax.spines["bottom"].set_visible(False)
        right_ax.yaxis.set_ticks([])
        right_ax.xaxis.set_visible(False)
        right_ax.set_ylim(main_ax.get_ylim())

        # The top-right cell has nothing to show.
        axes[0, 1].remove()
        plt.tight_layout()
        return axes
0 commit comments