Skip to content
45 changes: 45 additions & 0 deletions skore/src/skore/_sklearn/_cross_validation/inspection_accessor.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,27 @@
from __future__ import annotations

from collections.abc import Callable

import pandas as pd
from numpy.typing import ArrayLike
from sklearn.utils.metaestimators import available_if

from skore._externals._pandas_accessors import DirNamesMixin
from skore._sklearn._base import _BaseAccessor
from skore._sklearn._cross_validation.report import CrossValidationReport
from skore._sklearn._plot.inspection.coefficients import CoefficientsDisplay
from skore._sklearn._plot.inspection.impurity_decrease import ImpurityDecreaseDisplay
from skore._sklearn._plot.inspection.permutation_importance import (
PermutationImportanceDisplay,
)
from skore._sklearn.types import DataSource
from skore._utils._accessor import (
_check_cross_validation_sub_estimator_has_coef,
_check_cross_validation_sub_estimator_has_feature_importances,
)

Metric = str | Callable | list[str] | tuple[str] | dict[str, Callable] | None


class _InspectionAccessor(_BaseAccessor[CrossValidationReport], DirNamesMixin):
"""Accessor for model inspection related operations.
Expand Down Expand Up @@ -77,6 +87,41 @@ def coefficients(self) -> CoefficientsDisplay:
report_type="cross-validation",
)

def permutation_importance(
    self,
    *,
    data_source: DataSource = "test",
    X: ArrayLike | None = None,
    y: ArrayLike | None = None,
    at_step: int | str = 0,
    metric: Metric = None,
    n_repeats: int = 5,
    max_samples: float = 1.0,
    n_jobs: int | None = None,
    seed: int | None = None,
) -> PermutationImportanceDisplay:
    """Report the permutation feature importance across cross-validation splits.

    Delegates to ``inspection.permutation_importance`` of each underlying
    estimator report, tags every per-split result with its split index, and
    concatenates them into a single display.

    Parameters
    ----------
    data_source : DataSource, default="test"
        The data source on which to compute the importance.
    X : array-like or None, default=None
        Input features, forwarded to each sub-report.
    y : array-like or None, default=None
        Targets, forwarded to each sub-report.
    at_step : int or str, default=0
        Pipeline step at which the importance is computed.
    metric : str, callable, list, tuple, dict or None, default=None
        Metric(s) used to measure the importance.
    n_repeats : int, default=5
        Number of times each feature is permuted.
    max_samples : float, default=1.0
        Fraction of samples used in each repetition.
    n_jobs : int or None, default=None
        Number of parallel jobs.
    seed : int or None, default=None
        Seed controlling the permutations.

    Returns
    -------
    PermutationImportanceDisplay
        Display of the importances with the ``split`` column populated.
    """
    importances = []
    for report_idx, report in enumerate(self._parent.estimator_reports_):
        display = report.inspection.permutation_importance(
            data_source=data_source,
            X=X,
            y=y,
            at_step=at_step,
            metric=metric,
            n_repeats=n_repeats,
            max_samples=max_samples,
            n_jobs=n_jobs,
            seed=seed,
        )
        # `assign` returns a new frame: the original in-place
        # `df["split"] = report_idx` mutated the sub-display's own
        # `importances` DataFrame as a side effect.
        importances.append(display.importances.assign(split=report_idx))

    return PermutationImportanceDisplay(
        importances=pd.concat(importances, axis="index"),
        report_type="cross-validation",
    )

@available_if(_check_cross_validation_sub_estimator_has_feature_importances())
def impurity_decrease(self) -> ImpurityDecreaseDisplay:
"""Retrieve the Mean Decrease in Impurity (MDI) across splits.
Expand Down
3 changes: 1 addition & 2 deletions skore/src/skore/_sklearn/_estimator/inspection_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import joblib
import numpy as np
import pandas as pd
from numpy.typing import ArrayLike
from scipy.sparse import issparse
from sklearn.pipeline import Pipeline
Expand Down Expand Up @@ -135,7 +134,7 @@ def permutation_importance(
max_samples: float = 1.0,
n_jobs: int | None = None,
seed: int | None = None,
) -> pd.DataFrame:
) -> PermutationImportanceDisplay:
"""Report the permutation feature importance.

This computes the permutation importance using sklearn's
Expand Down
173 changes: 82 additions & 91 deletions skore/src/skore/_sklearn/_plot/inspection/permutation_importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@ class PermutationImportanceDisplay(DisplayMixin):
- `estimator`
- `data_source`
- `metric`
- `split`
- `feature`
- `label` or `output` (classification vs. regression)
- `repetition`
- `value`

report_type : {"estimator"}
report_type : {"estimator", "cross-validation"}
Report type from which the display is created.

Attributes
Expand Down Expand Up @@ -134,6 +135,7 @@ def _compute_data_for_display(
"estimator",
"data_source",
"metric",
"split",
"feature",
"label",
"output",
Expand All @@ -143,6 +145,7 @@ def _compute_data_for_display(
df_importances = pd.concat(df_importances, axis="index")
df_importances["data_source"] = data_source
df_importances["estimator"] = estimator_name
df_importances["split"] = np.nan

return PermutationImportanceDisplay(
importances=df_importances[ordered_columns], report_type=report_type
Expand All @@ -152,25 +155,28 @@ def _compute_data_for_display(
def _get_columns_to_groupby(*, frame: pd.DataFrame) -> list[str]:
"""Get the available columns from which to group by."""
columns_to_groupby = list[str]()
if "metric" in frame.columns and frame["metric"].nunique() > 1:
columns_to_groupby.append("metric")
if "label" in frame.columns and frame["label"].nunique() > 1:
columns_to_groupby.append("label")
if "output" in frame.columns and frame["output"].nunique() > 1:
columns_to_groupby.append("output")
if "split" in frame.columns and frame["split"].nunique() > 1:
columns_to_groupby.append("split")
return columns_to_groupby

@DisplayMixin.style_plot
def plot(
self,
*,
subplot_by: str | tuple[str, str] | None = "auto",
metric: str | list[str] | None = None,
metric: str,
subplot_by: str | tuple[str, ...] | None = "auto",
) -> None:
"""Plot the permutation importance.

Parameters
----------
metric : str
Metric to plot.

Comment on lines +221 to +223
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's give it a None default and raise an error asking to choose a metric when multiple metrics were passed to the constructor. It's counterintuitive to have to choose a metric when I did not specify any in the constructor.

subplot_by : str, tuple of str or None, default="auto"
Column(s) to use for subplotting. The possible values are:

Expand All @@ -179,40 +185,26 @@ def plot(
- if a string, the corresponding column of the dataframe is used to create
several subplots. Those plots will be organized in a grid of a single
row and several columns.
- if a tuple of strings, the corresponding columns of the dataframe are used
to create several subplots. Those plots will be organized in a grid of
several rows and columns. The first element of the tuple is the row and
the second element is the column.
- if a tuple of 2 strings, the corresponding columns are used to create
subplots in a grid. The first element is the row, the second is the
column.
- if `None`, all information is plotted on a single plot. An error is raised
if there is too much information to plot on a single plot.

metric : str or list of str, default=None
Filter the importances by metric. If `None`, all importances associated with
each metric are plotted.
"""
return self._plot(subplot_by=subplot_by, metric=metric)

def _plot_matplotlib(
self,
*,
subplot_by: str | tuple[str, str] | None = "auto",
metric: str | list[str] | None = None,
metric: str,
subplot_by: str | tuple[str, ...] | None = "auto",
) -> None:
"""Dispatch the plotting function for matplotlib backend."""
boxplot_kwargs = self._default_boxplot_kwargs.copy()
stripplot_kwargs = self._default_stripplot_kwargs.copy()
frame = self.frame(metric=metric, aggregate=None)

err_msg = (
"You try to plot the permutation importance of metrics averaged over {} "
"and other without averaging. This setting is not supported. Please filter "
"a group of consistent metrics using the `metric` parameter."
)
if "label" in frame.columns and frame["label"].isna().any():
raise ValueError(err_msg.format("labels"))
elif "output" in frame.columns and frame["output"].isna().any():
raise ValueError(err_msg.format("outputs"))

self._plot_single_estimator(
subplot_by=subplot_by,
frame=frame,
Expand All @@ -221,74 +213,79 @@ def _plot_matplotlib(
stripplot_kwargs=stripplot_kwargs,
)

@staticmethod
def _aggregate_over_split(*, frame: pd.DataFrame) -> pd.DataFrame:
"""Compute the averaged scores over the splits."""
group_by = frame.columns.difference(["repetition", "value"]).tolist()
return (
frame.drop(columns=["repetition"])
.groupby(group_by, sort=False, dropna=False)
.aggregate("mean")
.reset_index()
)

def _plot_single_estimator(
self,
*,
subplot_by: str | tuple[str, str] | None,
subplot_by: str | tuple[str, ...] | None,
frame: pd.DataFrame,
estimator_name: str,
boxplot_kwargs: dict[str, Any],
stripplot_kwargs: dict[str, Any],
) -> None:
"""Plot the permutation importance for an `EstimatorReport`."""
if subplot_by == "auto":
is_multi_metric = frame["metric"].nunique() > 1
is_multi_target = any(name in frame.columns for name in ["label", "output"])
if is_multi_metric and is_multi_target:
hue, col, row = (
"label" if "label" in frame.columns else "output",
"metric",
None,
)
elif is_multi_metric:
hue, col, row = None, "metric", None
elif is_multi_target:
hue, col, row = (
"label" if "label" in frame.columns else "output",
None,
None,
)
aggregate_title = ""
if subplot_by == "auto" or subplot_by is None:
columns_to_groupby = self._get_columns_to_groupby(frame=frame)

if "split" in columns_to_groupby:
frame = self._aggregate_over_split(frame=frame)
columns_to_groupby.remove("split")
aggregate_title = "averaged over splits"

if subplot_by is not None:
col = columns_to_groupby[0] if columns_to_groupby else None
hue, row = None, None
else:
hue, col, row = None, None, None
elif subplot_by is None:
# Possible accepted values: {"metric"}, {"label"}, {"output"}
hue = columns_to_groupby[0] if columns_to_groupby else None
col, row = None, None

else:
columns_to_groupby = self._get_columns_to_groupby(frame=frame)
n_columns_to_groupby = len(columns_to_groupby)
if n_columns_to_groupby > 1:
subplot_cols = (subplot_by,) if isinstance(subplot_by, str) else subplot_by
invalid = [c for c in subplot_cols if c not in columns_to_groupby]
if invalid:
raise ValueError(
"Cannot plot all the available information available on a single "
"plot. Please set `subplot_by` to a string or a tuple of strings. "
"You can use the following values to create subplots: "
f"The column(s) {invalid} are not available. You can use the "
"following values to create subplots: "
f"{', '.join(columns_to_groupby)}"
)
elif n_columns_to_groupby == 1:
hue, col, row = columns_to_groupby[0], None, None
else:
hue, col, row = None, None, None
else:
# Possible accepted values: {"metric"}, {"metric", "label"},
# {"metric", "output"}
columns_to_groupby = self._get_columns_to_groupby(frame=frame)
if isinstance(subplot_by, str):
if subplot_by not in columns_to_groupby:
raise ValueError(
f"The column {subplot_by} is not available. You can use the "
"following values to create subplots: "
f"{', '.join(columns_to_groupby)}"

remaining = set(columns_to_groupby) - set(subplot_cols)
if "split" in remaining:
frame = self._aggregate_over_split(frame=frame)
remaining.remove("split")
aggregate_title = "averaged over splits"

match len(subplot_cols):
case 1:
col, row, hue = (
subplot_cols[0],
None,
(next(iter(remaining)) if remaining else None),
)
col, row = subplot_by, None
if remaining_column := set(columns_to_groupby) - {subplot_by}:
hue = next(iter(remaining_column))
else:
hue = None
else:
if not all(item in columns_to_groupby for item in subplot_by):
case 2:
row, col, hue = (
subplot_cols[0],
subplot_cols[1],
(next(iter(remaining)) if remaining else None),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit

Suggested change
(next(iter(remaining)) if remaining else None),
next(iter(remaining), None),

Comment on lines +315 to +325
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both row and hue are the same

)
case _:
raise ValueError(
f"The columns {subplot_by} are not available. You can use the "
"following values to create subplots: "
"Expected 1 to 2 columns for subplot_by, got "
f"{len(subplot_cols)}. You can use the following values: "
f"{', '.join(columns_to_groupby)}"
)
(row, col), hue = subplot_by, None

if hue is None:
# we don't need the palette and we are at risk of raising an error or
Expand Down Expand Up @@ -316,32 +313,23 @@ def _plot_single_estimator(
)
add_background_features = hue is not None

metrics = frame["metric"].unique()
metric_name = frame["metric"].unique()[0]
self.figure_, self.ax_ = self.facet_.figure, self.facet_.axes.squeeze()
for row_index, row_axes in enumerate(self.facet_.axes):
for col_index, ax in enumerate(row_axes):
if len(metrics) > 1:
if row == "metric":
xlabel = f"Decrease in {metrics[row_index]}"
elif col == "metric":
xlabel = f"Decrease in {metrics[col_index]}"
else:
xlabel = "Decrease in metric"
else:
xlabel = f"Decrease in {metrics[0]}"

for row_axes in self.facet_.axes:
for ax in row_axes:
_decorate_matplotlib_axis(
ax=ax,
add_background_features=add_background_features,
n_features=frame["feature"].nunique(),
xlabel=xlabel,
xlabel=f"Decrease in {metric_name}",
ylabel="",
)
if len(self.ax_.flatten()) == 1:
self.ax_ = self.ax_.flatten()[0]
data_source = frame["data_source"].unique()[0]
self.figure_.suptitle(
f"Permutation importance of {estimator_name} on {data_source} set"
f"Permutation importance {aggregate_title} \nof {estimator_name} "
f"on {data_source} set"
Comment on lines +377 to +378
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
f"Permutation importance {aggregate_title} \nof {estimator_name} "
f"on {data_source} set"
f"Permutation importance {aggregate_title}\n"
f"of {estimator_name} on {data_source} set"

)

def frame(
Expand All @@ -367,8 +355,11 @@ def frame(
Dataframe containing the importances.
"""
if self.report_type == "estimator":
columns_to_drop = ["estimator"]
columns_to_drop = ["estimator", "split"]
group_by = ["data_source", "metric", "feature"]
elif self.report_type == "cross-validation":
columns_to_drop = ["estimator"]
group_by = ["data_source", "metric", "split", "feature"]
else:
raise TypeError(f"Unexpected report type: {self.report_type!r}")

Expand All @@ -393,7 +384,7 @@ def frame(
frame = (
frame.drop(columns=["repetition"])
# avoid sorting the features by name and do not drop NA from
# output or labels in case of mixed metrics (i.e. averaged vs\
# output or labels in case of mixed metrics (i.e. averaged vs.
# non-averaged)
.groupby(group_by, sort=False, dropna=False)
.aggregate(aggregate)
Expand Down
Loading