-
Notifications
You must be signed in to change notification settings - Fork 118
feat(skore): Add cross-validation support for permutation importance #2370
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 9 commits
0585afa
ec18274
b4ba8ac
3452928
08f3e29
de13664
52748eb
a50868b
3d5b296
bff5b2d
6d0de91
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -26,12 +26,13 @@ class PermutationImportanceDisplay(DisplayMixin): | |||||||||
| - `estimator` | ||||||||||
| - `data_source` | ||||||||||
| - `metric` | ||||||||||
| - `split` | ||||||||||
| - `feature` | ||||||||||
| - `label` or `output` (classification vs. regression) | ||||||||||
| - `repetition` | ||||||||||
| - `value` | ||||||||||
|
|
||||||||||
| report_type : {"estimator"} | ||||||||||
| report_type : {"estimator", "cross-validation"} | ||||||||||
| Report type from which the display is created. | ||||||||||
|
|
||||||||||
| Attributes | ||||||||||
|
|
@@ -134,6 +135,7 @@ def _compute_data_for_display( | |||||||||
| "estimator", | ||||||||||
| "data_source", | ||||||||||
| "metric", | ||||||||||
| "split", | ||||||||||
| "feature", | ||||||||||
| "label", | ||||||||||
| "output", | ||||||||||
|
|
@@ -143,6 +145,7 @@ def _compute_data_for_display( | |||||||||
| df_importances = pd.concat(df_importances, axis="index") | ||||||||||
| df_importances["data_source"] = data_source | ||||||||||
| df_importances["estimator"] = estimator_name | ||||||||||
| df_importances["split"] = np.nan | ||||||||||
|
|
||||||||||
| return PermutationImportanceDisplay( | ||||||||||
| importances=df_importances[ordered_columns], report_type=report_type | ||||||||||
|
|
@@ -152,25 +155,28 @@ def _compute_data_for_display( | |||||||||
| def _get_columns_to_groupby(*, frame: pd.DataFrame) -> list[str]: | ||||||||||
| """Get the available columns from which to group by.""" | ||||||||||
| columns_to_groupby = list[str]() | ||||||||||
| if "metric" in frame.columns and frame["metric"].nunique() > 1: | ||||||||||
| columns_to_groupby.append("metric") | ||||||||||
| if "label" in frame.columns and frame["label"].nunique() > 1: | ||||||||||
| columns_to_groupby.append("label") | ||||||||||
| if "output" in frame.columns and frame["output"].nunique() > 1: | ||||||||||
| columns_to_groupby.append("output") | ||||||||||
| if "split" in frame.columns and frame["split"].nunique() > 1: | ||||||||||
| columns_to_groupby.append("split") | ||||||||||
| return columns_to_groupby | ||||||||||
|
|
||||||||||
| @DisplayMixin.style_plot | ||||||||||
| def plot( | ||||||||||
| self, | ||||||||||
| *, | ||||||||||
| subplot_by: str | tuple[str, str] | None = "auto", | ||||||||||
| metric: str | list[str] | None = None, | ||||||||||
| metric: str, | ||||||||||
| subplot_by: str | tuple[str, ...] | None = "auto", | ||||||||||
| ) -> None: | ||||||||||
| """Plot the permutation importance. | ||||||||||
|
|
||||||||||
| Parameters | ||||||||||
| ---------- | ||||||||||
| metric : str | ||||||||||
| Metric to plot. | ||||||||||
|
|
||||||||||
|
Comment on lines
+221
to
+223
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Let's give it a |
||||||||||
| subplot_by : str, tuple of str or None, default="auto" | ||||||||||
| Column(s) to use for subplotting. The possible values are: | ||||||||||
|
|
||||||||||
|
|
@@ -179,40 +185,26 @@ def plot( | |||||||||
| - if a string, the corresponding column of the dataframe is used to create | ||||||||||
| several subplots. Those plots will be organized in a grid of a single | ||||||||||
| row and several columns. | ||||||||||
| - if a tuple of strings, the corresponding columns of the dataframe are used | ||||||||||
| to create several subplots. Those plots will be a organized in a grid of | ||||||||||
| several rows and columns. The first element of the tuple is the row and | ||||||||||
| the second element is the column. | ||||||||||
| - if a tuple of 2 strings, the corresponding columns are used to create | ||||||||||
| subplots in a grid. The first element is the row, the second is the | ||||||||||
| column. | ||||||||||
| - if `None`, all information is plotted on a single plot. An error is raised | ||||||||||
| if there is too much information to plot on a single plot. | ||||||||||
|
|
||||||||||
| metric : str or list of str, default=None | ||||||||||
| Filter the importances by metric. If `None`, all importances associated with | ||||||||||
| each metric are plotted. | ||||||||||
| """ | ||||||||||
| return self._plot(subplot_by=subplot_by, metric=metric) | ||||||||||
|
|
||||||||||
| def _plot_matplotlib( | ||||||||||
| self, | ||||||||||
| *, | ||||||||||
| subplot_by: str | tuple[str, str] | None = "auto", | ||||||||||
| metric: str | list[str] | None = None, | ||||||||||
| metric: str, | ||||||||||
| subplot_by: str | tuple[str, ...] | None = "auto", | ||||||||||
| ) -> None: | ||||||||||
| """Dispatch the plotting function for matplotlib backend.""" | ||||||||||
| boxplot_kwargs = self._default_boxplot_kwargs.copy() | ||||||||||
| stripplot_kwargs = self._default_stripplot_kwargs.copy() | ||||||||||
| frame = self.frame(metric=metric, aggregate=None) | ||||||||||
|
|
||||||||||
| err_msg = ( | ||||||||||
| "You try to plot the permutation importance of metrics averaged over {} " | ||||||||||
| "and other without averaging. This setting is not supported. Please filter " | ||||||||||
| "a group of consistent metrics using the `metric` parameter." | ||||||||||
| ) | ||||||||||
| if "label" in frame.columns and frame["label"].isna().any(): | ||||||||||
| raise ValueError(err_msg.format("labels")) | ||||||||||
| elif "output" in frame.columns and frame["output"].isna().any(): | ||||||||||
| raise ValueError(err_msg.format("outputs")) | ||||||||||
|
|
||||||||||
| self._plot_single_estimator( | ||||||||||
| subplot_by=subplot_by, | ||||||||||
| frame=frame, | ||||||||||
|
|
@@ -221,74 +213,79 @@ def _plot_matplotlib( | |||||||||
| stripplot_kwargs=stripplot_kwargs, | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| @staticmethod | ||||||||||
| def _aggregate_over_split(*, frame: pd.DataFrame) -> pd.DataFrame: | ||||||||||
| """Compute the averaged scores over the splits.""" | ||||||||||
| group_by = frame.columns.difference(["repetition", "value"]).tolist() | ||||||||||
| return ( | ||||||||||
| frame.drop(columns=["repetition"]) | ||||||||||
| .groupby(group_by, sort=False, dropna=False) | ||||||||||
| .aggregate("mean") | ||||||||||
| .reset_index() | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| def _plot_single_estimator( | ||||||||||
| self, | ||||||||||
| *, | ||||||||||
| subplot_by: str | tuple[str, str] | None, | ||||||||||
| subplot_by: str | tuple[str, ...] | None, | ||||||||||
| frame: pd.DataFrame, | ||||||||||
| estimator_name: str, | ||||||||||
| boxplot_kwargs: dict[str, Any], | ||||||||||
| stripplot_kwargs: dict[str, Any], | ||||||||||
| ) -> None: | ||||||||||
| """Plot the permutation importance for an `EstimatorReport`.""" | ||||||||||
| if subplot_by == "auto": | ||||||||||
| is_multi_metric = frame["metric"].nunique() > 1 | ||||||||||
| is_multi_target = any(name in frame.columns for name in ["label", "output"]) | ||||||||||
| if is_multi_metric and is_multi_target: | ||||||||||
| hue, col, row = ( | ||||||||||
| "label" if "label" in frame.columns else "output", | ||||||||||
| "metric", | ||||||||||
| None, | ||||||||||
| ) | ||||||||||
| elif is_multi_metric: | ||||||||||
| hue, col, row = None, "metric", None | ||||||||||
| elif is_multi_target: | ||||||||||
| hue, col, row = ( | ||||||||||
| "label" if "label" in frame.columns else "output", | ||||||||||
| None, | ||||||||||
| None, | ||||||||||
| ) | ||||||||||
| aggregate_title = "" | ||||||||||
| if subplot_by == "auto" or subplot_by is None: | ||||||||||
| columns_to_groupby = self._get_columns_to_groupby(frame=frame) | ||||||||||
|
|
||||||||||
| if "split" in columns_to_groupby: | ||||||||||
| frame = self._aggregate_over_split(frame=frame) | ||||||||||
| columns_to_groupby.remove("split") | ||||||||||
| aggregate_title = "averaged over splits" | ||||||||||
|
|
||||||||||
| if subplot_by is not None: | ||||||||||
| col = columns_to_groupby[0] if columns_to_groupby else None | ||||||||||
| hue, row = None, None | ||||||||||
| else: | ||||||||||
| hue, col, row = None, None, None | ||||||||||
| elif subplot_by is None: | ||||||||||
| # Possible accepted values: {"metric"}, {"label"}, {"output"} | ||||||||||
| hue = columns_to_groupby[0] if columns_to_groupby else None | ||||||||||
| col, row = None, None | ||||||||||
|
|
||||||||||
| else: | ||||||||||
| columns_to_groupby = self._get_columns_to_groupby(frame=frame) | ||||||||||
| n_columns_to_groupby = len(columns_to_groupby) | ||||||||||
| if n_columns_to_groupby > 1: | ||||||||||
| subplot_cols = (subplot_by,) if isinstance(subplot_by, str) else subplot_by | ||||||||||
| invalid = [c for c in subplot_cols if c not in columns_to_groupby] | ||||||||||
| if invalid: | ||||||||||
| raise ValueError( | ||||||||||
| "Cannot plot all the available information available on a single " | ||||||||||
| "plot. Please set `subplot_by` to a string or a tuple of strings. " | ||||||||||
| "You can use the following values to create subplots: " | ||||||||||
| f"The column(s) {invalid} are not available. You can use the " | ||||||||||
| "following values to create subplots: " | ||||||||||
| f"{', '.join(columns_to_groupby)}" | ||||||||||
| ) | ||||||||||
| elif n_columns_to_groupby == 1: | ||||||||||
| hue, col, row = columns_to_groupby[0], None, None | ||||||||||
| else: | ||||||||||
| hue, col, row = None, None, None | ||||||||||
| else: | ||||||||||
| # Possible accepted values: {"metric"}, {"metric", "label"}, | ||||||||||
| # {"metric", "output"} | ||||||||||
| columns_to_groupby = self._get_columns_to_groupby(frame=frame) | ||||||||||
| if isinstance(subplot_by, str): | ||||||||||
| if subplot_by not in columns_to_groupby: | ||||||||||
| raise ValueError( | ||||||||||
| f"The column {subplot_by} is not available. You can use the " | ||||||||||
| "following values to create subplots: " | ||||||||||
| f"{', '.join(columns_to_groupby)}" | ||||||||||
|
|
||||||||||
| remaining = set(columns_to_groupby) - set(subplot_cols) | ||||||||||
| if "split" in remaining: | ||||||||||
| frame = self._aggregate_over_split(frame=frame) | ||||||||||
| remaining.remove("split") | ||||||||||
| aggregate_title = "averaged over splits" | ||||||||||
|
|
||||||||||
| match len(subplot_cols): | ||||||||||
| case 1: | ||||||||||
| col, row, hue = ( | ||||||||||
| subplot_cols[0], | ||||||||||
| None, | ||||||||||
| (next(iter(remaining)) if remaining else None), | ||||||||||
| ) | ||||||||||
| col, row = subplot_by, None | ||||||||||
| if remaining_column := set(columns_to_groupby) - {subplot_by}: | ||||||||||
| hue = next(iter(remaining_column)) | ||||||||||
| else: | ||||||||||
| hue = None | ||||||||||
| else: | ||||||||||
| if not all(item in columns_to_groupby for item in subplot_by): | ||||||||||
| case 2: | ||||||||||
| row, col, hue = ( | ||||||||||
| subplot_cols[0], | ||||||||||
| subplot_cols[1], | ||||||||||
| (next(iter(remaining)) if remaining else None), | ||||||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit
Suggested change
Comment on lines
+315
to
+325
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Both |
||||||||||
| ) | ||||||||||
| case _: | ||||||||||
| raise ValueError( | ||||||||||
| f"The columns {subplot_by} are not available. You can use the " | ||||||||||
| "following values to create subplots: " | ||||||||||
| "Expected 1 to 2 columns for subplot_by, got " | ||||||||||
| f"{len(subplot_cols)}. You can use the following values: " | ||||||||||
| f"{', '.join(columns_to_groupby)}" | ||||||||||
| ) | ||||||||||
| (row, col), hue = subplot_by, None | ||||||||||
|
|
||||||||||
| if hue is None: | ||||||||||
| # we don't need the palette and we are at risk of raising an error or | ||||||||||
|
|
@@ -316,32 +313,23 @@ def _plot_single_estimator( | |||||||||
| ) | ||||||||||
| add_background_features = hue is not None | ||||||||||
|
|
||||||||||
| metrics = frame["metric"].unique() | ||||||||||
| metric_name = frame["metric"].unique()[0] | ||||||||||
| self.figure_, self.ax_ = self.facet_.figure, self.facet_.axes.squeeze() | ||||||||||
| for row_index, row_axes in enumerate(self.facet_.axes): | ||||||||||
| for col_index, ax in enumerate(row_axes): | ||||||||||
| if len(metrics) > 1: | ||||||||||
| if row == "metric": | ||||||||||
| xlabel = f"Decrease in {metrics[row_index]}" | ||||||||||
| elif col == "metric": | ||||||||||
| xlabel = f"Decrease in {metrics[col_index]}" | ||||||||||
| else: | ||||||||||
| xlabel = "Decrease in metric" | ||||||||||
| else: | ||||||||||
| xlabel = f"Decrease in {metrics[0]}" | ||||||||||
|
|
||||||||||
| for row_axes in self.facet_.axes: | ||||||||||
| for ax in row_axes: | ||||||||||
| _decorate_matplotlib_axis( | ||||||||||
| ax=ax, | ||||||||||
| add_background_features=add_background_features, | ||||||||||
| n_features=frame["feature"].nunique(), | ||||||||||
| xlabel=xlabel, | ||||||||||
| xlabel=f"Decrease in {metric_name}", | ||||||||||
| ylabel="", | ||||||||||
| ) | ||||||||||
| if len(self.ax_.flatten()) == 1: | ||||||||||
| self.ax_ = self.ax_.flatten()[0] | ||||||||||
| data_source = frame["data_source"].unique()[0] | ||||||||||
| self.figure_.suptitle( | ||||||||||
| f"Permutation importance of {estimator_name} on {data_source} set" | ||||||||||
| f"Permutation importance {aggregate_title} \nof {estimator_name} " | ||||||||||
| f"on {data_source} set" | ||||||||||
|
Comment on lines
+377
to
+378
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
| ) | ||||||||||
|
|
||||||||||
| def frame( | ||||||||||
|
|
@@ -367,8 +355,11 @@ def frame( | |||||||||
| Dataframe containing the importances. | ||||||||||
| """ | ||||||||||
| if self.report_type == "estimator": | ||||||||||
| columns_to_drop = ["estimator"] | ||||||||||
| columns_to_drop = ["estimator", "split"] | ||||||||||
| group_by = ["data_source", "metric", "feature"] | ||||||||||
| elif self.report_type == "cross-validation": | ||||||||||
| columns_to_drop = ["estimator"] | ||||||||||
| group_by = ["data_source", "metric", "split", "feature"] | ||||||||||
| else: | ||||||||||
| raise TypeError(f"Unexpected report type: {self.report_type!r}") | ||||||||||
|
|
||||||||||
|
|
@@ -393,7 +384,7 @@ def frame( | |||||||||
| frame = ( | ||||||||||
| frame.drop(columns=["repetition"]) | ||||||||||
| # avoid sorting the features by name and do not drop NA from | ||||||||||
| # output or labels in case of mixed metrics (i.e. averaged vs\ | ||||||||||
| # output or labels in case of mixed metrics (i.e. averaged vs. | ||||||||||
| # non-averaged) | ||||||||||
| .groupby(group_by, sort=False, dropna=False) | ||||||||||
| .aggregate(aggregate) | ||||||||||
|
|
||||||||||
Uh oh!
There was an error while loading. Please reload this page.