11 changes: 11 additions & 0 deletions docs/src/api.rst
@@ -36,7 +36,18 @@ Classes

BaseVariableImportance
BasePerturbation
VariableImportanceFeatureGroup
LOCO
CFI
PFI
D0CRT

Marginal Importance
===================

.. autosummary::
:toctree: ./generated/api/marginal
:template: class.rst

LOCI
LeaveOneCovariateIn
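The new entries document the marginal-importance estimator added in this PR. As a quick orientation for reviewers, here is a minimal usage sketch based on the call pattern adopted in plot_conditional_vs_marginal_xor_data.py below; the constructor arguments (estimator, method, loss) and the fit_importance signature are taken from that example, while the toy data is illustrative only.

import numpy as np
from sklearn.metrics import hinge_loss
from sklearn.model_selection import KFold
from sklearn.svm import SVC

from hidimstat import LOCI  # alias of LeaveOneCovariateIn

# Illustrative XOR-like data (not part of the PR)
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 4))
y = (X[:, 0] * X[:, 1] > 0).astype(int)

clf = SVC(kernel="rbf", random_state=0).fit(X, y)
cv = KFold(n_splits=5, shuffle=True, random_state=0)

loci = LOCI(estimator=clf, method="decision_function", loss=hinge_loss)
mean_importances = loci.fit_importance(X, y, cv=cv)  # one value per feature
per_fold_importances = np.array(loci.importances_)   # per-fold values, as used for plotting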
9 changes: 9 additions & 0 deletions docs/tools/references.bib
@@ -144,6 +144,15 @@ @article{eshel2003yule
year = {2003}
}

@inproceedings{ewald2024guide,
title = {A guide to feature importance methods for scientific inference},
author = {Ewald, Fiona Katharina and Bothmann, Ludwig and Wright, Marvin N and Bischl, Bernd and Casalicchio, Giuseppe and K{\"o}nig, Gunnar},
booktitle = {World Conference on Explainable Artificial Intelligence},
pages = {440--464},
year = {2024},
organization = {Springer}
}

@article{fan2012variance,
author = {Fan, Jianqing and Guo, Shaojun and Hao, Ning},
journal = {Journal of the Royal Statistical Society Series B: Statistical Methodology},
24 changes: 6 additions & 18 deletions examples/plot_conditional_vs_marginal_xor_data.py
@@ -12,11 +12,11 @@
import seaborn as sns
from sklearn.base import clone
from sklearn.linear_model import RidgeCV
from sklearn.metrics import hinge_loss
from sklearn.metrics import hinge_loss, accuracy_score
from sklearn.model_selection import KFold, train_test_split
from sklearn.svm import SVC

from hidimstat import CFI
from hidimstat import CFI, LOCI

#############################################################################
# To solve the XOR problem, we will use a Support Vector Classifier (SVC) with a Radial Basis Function (RBF) kernel. The decision function of
@@ -82,21 +82,9 @@
cv = KFold(n_splits=5, shuffle=True, random_state=0)
clf = SVC(kernel="rbf", random_state=0)
# Compute marginal importance using univariate models
marginal_scores = []
for i in range(X.shape[1]):
feat_scores = []
for train_index, test_index in cv.split(X):
X_train, X_test = X[train_index], X[test_index]
y_train, y_test = Y[train_index], Y[test_index]

X_train_univariate = X_train[:, i].reshape(-1, 1)
X_test_univariate = X_test[:, i].reshape(-1, 1)

univariate_model = clone(clf)
univariate_model.fit(X_train_univariate, y_train)

feat_scores.append(univariate_model.score(X_test_univariate, y_test))
marginal_scores.append(feat_scores)
loci = LOCI(estimator=clone(clf).fit(X, Y), method="decision_function", loss=hinge_loss)
mean_importances = loci.fit_importance(X, Y, cv=cv)
marginal_importances = np.array(loci.importances_)

###########################################################################

@@ -129,7 +117,7 @@
fig, axes = plt.subplots(1, 2, sharey=True, figsize=(6, 2.5))
# Marginal scores boxplot
sns.boxplot(
data=np.array(marginal_scores).T,
data=marginal_importances,
orient="h",
ax=axes[0],
fill=False,
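The per-feature cross-validation loop removed above is now handled by LOCI. For reviewers checking equivalence, this is roughly what the deleted code computed, reconstructed from the removed lines; note that LOCI scores features with the supplied loss (here hinge_loss) rather than the estimator's accuracy, so the returned values are on a different scale than the old marginal_scores.

import numpy as np
from sklearn.base import clone

def leave_one_covariate_in_scores(clf, X, Y, cv):
    # For each feature, refit a clone of the classifier on that single
    # covariate and score it on the held-out folds.
    scores = []
    for i in range(X.shape[1]):
        feat_scores = []
        for train_index, test_index in cv.split(X):
            X_train_uni = X[train_index][:, [i]]
            X_test_uni = X[test_index][:, [i]]
            model = clone(clf).fit(X_train_uni, Y[train_index])
            feat_scores.append(model.score(X_test_uni, Y[test_index]))
        scores.append(feat_scores)
    return np.array(scores)  # shape (n_features, n_splits)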
34 changes: 24 additions & 10 deletions examples/plot_importance_classification_iris.py
@@ -60,7 +60,9 @@
# require a K-fold cross-fitting. Computing the importance for each fold is
# embarrassingly parallel. For this reason, we encapsulate the main computations in a
# function and use joblib to parallelize the computation.
def run_one_fold(X, y, model, train_index, test_index, vim_name="CFI", groups=None):
def run_one_fold(
X, y, model, train_index, test_index, vim_name="CFI", features_groups=None
):
model_c = clone(model)
model_c.fit(X[train_index], y[train_index])
y_pred = model_c.predict(X[test_index])
@@ -92,12 +94,12 @@ def run_one_fold(X, y, model, train_index, test_index, vim_name="CFI", groups=No
loss=loss,
)

vim.fit(X[train_index], y[train_index], groups=groups)
vim.fit(X[train_index], y[train_index], features_groups=features_groups)
importance = vim.importance(X[test_index], y[test_index])["importance"]

return pd.DataFrame(
{
"feature": groups.keys(),
"feature": features_groups.keys(),
"importance": importance,
"vim": vim_name,
"model": model_name,
@@ -116,10 +118,16 @@ def run_one_fold(X, y, model, train_index, test_index, vim_name="CFI", groups=No
GridSearchCV(SVC(kernel="rbf"), {"C": np.logspace(-3, 3, 10)}),
]
cv = KFold(n_splits=5, shuffle=True, random_state=0)
groups = {ft: [i] for i, ft in enumerate(dataset.feature_names)}
features_groups = {ft: [i] for i, ft in enumerate(dataset.feature_names)}
out_list = Parallel(n_jobs=5)(
delayed(run_one_fold)(
X, y, model, train_index, test_index, vim_name=vim_name, groups=groups
X,
y,
model,
train_index,
test_index,
vim_name=vim_name,
features_groups=features_groups,
)
for train_index, test_index in cv.split(X)
for model in models
@@ -255,16 +263,22 @@ def plot_results(df_importance, df_pval):
# mitigate this issue, we can group correlated features together and measure the
# importance of these feature groups. For instance, we can group 'sepal width' with
# 'sepal length' and 'petal length' with 'petal width' and the spurious feature.
groups = {"sepal features": [0, 1], "petal features": [2, 3, 4]}
features_groups = {"sepal features": [0, 1], "petal features": [2, 3, 4]}
out_list = Parallel(n_jobs=5)(
delayed(run_one_fold)(
X, y, model, train_index, test_index, vim_name=vim_name, groups=groups
X,
y,
model,
train_index,
test_index,
vim_name=vim_name,
features_groups=features_groups,
)
for train_index, test_index in cv.split(X)
for model in models
for vim_name in ["CFI", "PFI"]
)

df_grouped = pd.concat(out_list)
df_pval = compute_pval(df_grouped, threshold=threshold)
plot_results(df_grouped, df_pval)
df_features_grouped = pd.concat(out_list)
df_pval = compute_pval(df_features_grouped, threshold=threshold)
plot_results(df_features_grouped, df_pval)
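Outside the joblib machinery, the grouped call boils down to the snippet below. This is a hypothetical standalone sketch: CFI constructor arguments other than estimator, loss and method are left at their defaults and assumed to follow the BasePerturbation signature, the loss and method choices are illustrative, and the spurious fifth feature used in the example is omitted.

import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import log_loss
from sklearn.model_selection import train_test_split

from hidimstat import CFI

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
model = LogisticRegression(max_iter=1000).fit(X_train, y_train)

# Group correlated features together (indices follow the iris column order).
features_groups = {"sepal features": [0, 1], "petal features": [2, 3]}

cfi = CFI(estimator=model, loss=log_loss, method="predict_proba")
cfi.fit(X_train, y_train, features_groups=features_groups)
importance = cfi.importance(X_test, y_test)["importance"]  # one value per group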
11 changes: 10 additions & 1 deletion src/hidimstat/__init__.py
@@ -1,4 +1,7 @@
from .base_variable_importance import BaseVariableImportance
from .base_variable_importance import (
BaseVariableImportance,
VariableImportanceFeatureGroup,
)
from .base_perturbation import BasePerturbation
from .ensemble_clustered_inference import (
clustered_inference,
@@ -25,6 +28,10 @@
from .noise_std import reid
from .permutation_feature_importance import PFI

# marginal methods
from .marginal import LeaveOneCovariateIn  # full name kept so it appears in the documentation
from .marginal import LeaveOneCovariateIn as LOCI

from .statistical_tools.aggregation import quantile_aggregation

try:
@@ -51,4 +58,6 @@
"CFI",
"LOCO",
"PFI",
# marginal methods
"LOCI",
]
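The double import above keeps the full class name available for the API docs while exposing the shorter alias to users; both names refer to the same object, for example:

from hidimstat import LOCI, LeaveOneCovariateIn

assert LOCI is LeaveOneCovariateIn  # same class, two names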
146 changes: 23 additions & 123 deletions src/hidimstat/base_perturbation.py
@@ -1,16 +1,16 @@
import numpy as np
import pandas as pd
from joblib import Parallel, delayed
from sklearn.base import check_is_fitted
from sklearn.metrics import root_mean_squared_error
import warnings

from hidimstat._utils.utils import _check_vim_predict_method
from hidimstat._utils.exception import InternalError
from hidimstat.base_variable_importance import BaseVariableImportance
from hidimstat.base_variable_importance import (
BaseVariableImportance,
VariableImportanceFeatureGroup,
)


class BasePerturbation(BaseVariableImportance):
class BasePerturbation(BaseVariableImportance, VariableImportanceFeatureGroup):
def __init__(
self,
estimator,
@@ -43,6 +43,7 @@ def __init__(
The number of parallel jobs to run. Parallelization is done over the
variables or groups of variables.
"""
super().__init__()
check_is_fitted(estimator)
assert n_permutations > 0, "n_permutations must be positive"
self.estimator = estimator
@@ -51,45 +52,6 @@ def __init__(
self.method = method
self.n_jobs = n_jobs
self.n_permutations = n_permutations
self.n_groups = None

def fit(self, X, y=None, groups=None):
"""Base fit method for perturbation-based methods. Identifies the groups.

Parameters
----------
X: array-like of shape (n_samples, n_features)
The input samples.
y: array-like of shape (n_samples,)
Not used, only present for consistency with the sklearn API.
groups: dict, optional
A dictionary where the keys are the group names and the values are the
list of column names corresponding to each group. If None, the groups are
identified based on the columns of X.
"""
if groups is None:
self.n_groups = X.shape[1]
self.groups = {j: [j] for j in range(self.n_groups)}
self._groups_ids = np.array(list(self.groups.values()), dtype=int)
elif isinstance(groups, dict):
self.n_groups = len(groups)
self.groups = groups
if isinstance(X, pd.DataFrame):
self._groups_ids = []
for group_key in self.groups.keys():
self._groups_ids.append(
[
i
for i, col in enumerate(X.columns)
if col in self.groups[group_key]
]
)
else:
self._groups_ids = [
np.array(ids, dtype=int) for ids in list(self.groups.values())
]
else:
raise ValueError("groups needs to be a dictionnary")

def predict(self, X):
"""
@@ -111,8 +73,12 @@ def predict(self, X):

# Parallelize the computation of the importance scores for each group
out_list = Parallel(n_jobs=self.n_jobs)(
delayed(self._joblib_predict_one_group)(X_, group_id, group_key)
for group_id, group_key in enumerate(self.groups.keys())
delayed(self._joblib_predict_one_features_group)(
X_, features_group_id, features_group_key
)
for features_group_id, features_group_key in enumerate(
self.features_groups.keys()
)
)
return np.stack(out_list, axis=0)

@@ -155,82 +121,14 @@ def importance(self, X, y):
out_dict["importance"] = np.array(
[
np.mean(out_dict["loss"][j]) - loss_reference
for j in range(self.n_groups)
for j in range(self.n_features_groups)
]
)
return out_dict

def _check_fit(self, X):
"""
Check if the perturbation method has been properly fitted.

This method verifies that the perturbation method has been fitted by checking
if required attributes are set and if the number of features matches
the grouped variables.

Parameters
----------
X : array-like of shape (n_samples, n_features)
Input data to validate against the fitted model.

Raises
------
ValueError
If the method has not been fitted (i.e., if n_groups, groups,
or _groups_ids attributes are missing).
AssertionError
If the number of features in X does not match the total number
of features in the grouped variables.
"""
if (
self.n_groups is None
or not hasattr(self, "groups")
or not hasattr(self, "_groups_ids")
):
raise ValueError(
"The class is not fitted. The fit method must be called"
" to set variable groups. If no grouping is needed,"
" call fit with groups=None"
)
if isinstance(X, pd.DataFrame):
names = list(X.columns)
elif isinstance(X, np.ndarray) and X.dtype.names is not None:
names = X.dtype.names
# transform Structured Array in pandas array for a better manipulation
X = pd.DataFrame(X)
elif isinstance(X, np.ndarray):
names = None
else:
raise ValueError("X should be a pandas dataframe or a numpy array.")
number_columns = X.shape[1]
for index_variables in self.groups.values():
if type(index_variables[0]) is int or np.issubdtype(
type(index_variables[0]), int
):
assert np.all(
np.array(index_variables, dtype=int) < number_columns
), "X does not correspond to the fitting data."
elif type(index_variables[0]) is str or np.issubdtype(
type(index_variables[0]), str
):
assert np.all(
[name in names for name in index_variables]
), f"The array is missing at least one of the following columns {index_variables}."
else:
raise InternalError(
"A problem with indexing has happened during the fit."
)
number_unique_feature_in_groups = np.unique(
np.concatenate([values for values in self.groups.values()])
).shape[0]
if X.shape[1] != number_unique_feature_in_groups:
warnings.warn(
f"The number of features in X: {X.shape[1]} differs from the"
" number of features for which importance is computed: "
f"{number_unique_feature_in_groups}"
)

def _joblib_predict_one_group(self, X, group_id, group_key):
def _joblib_predict_one_features_group(
self, X, features_group_id, features_group_key
):
"""
Compute the predictions after perturbation of the data for a given
group of variables. This function is parallelized.
Expand All @@ -244,13 +142,15 @@ def _joblib_predict_one_group(self, X, group_id, group_key):
group_key: str, int
The key of the group of variables. (parameter used for debugging)
"""
group_ids = self._groups_ids[group_id]
non_group_ids = np.delete(np.arange(X.shape[1]), group_ids)
features_group_ids = self._features_groups_ids[features_group_id]
non_features_group_ids = np.delete(np.arange(X.shape[1]), features_group_ids)
# Create an array X_perm_j of shape (n_permutations, n_samples, n_features)
# where the j-th group of covariates is permuted
X_perm = np.empty((self.n_permutations, X.shape[0], X.shape[1]))
X_perm[:, :, non_group_ids] = np.delete(X, group_ids, axis=1)
X_perm[:, :, group_ids] = self._permutation(X, group_id=group_id)
X_perm[:, :, non_features_group_ids] = np.delete(X, features_group_ids, axis=1)
X_perm[:, :, features_group_ids] = self._permutation(
X, features_group_id=features_group_id
)
# Reshape X_perm to allow for batch prediction
X_perm_batch = X_perm.reshape(-1, X.shape[1])
y_pred_perm = getattr(self.estimator, self.method)(X_perm_batch)
@@ -264,6 +164,6 @@ def _joblib_predict_one_group(self, X, group_id, group_key):
)
return y_pred_perm

def _permutation(self, X, group_id):
def _permutation(self, X, features_group_id):
"""Method for creating the permuted data for the j-th group of covariates."""
raise NotImplementedError
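With the group-handling code removed from BasePerturbation, resolving features_groups into column indices is now the job of the new VariableImportanceFeatureGroup mixin. Its implementation lives in base_variable_importance.py and is not part of this diff; the sketch below is only a reconstruction of the behaviour BasePerturbation relies on above (the features_groups, n_features_groups and _features_groups_ids attributes), mirroring the fit method deleted from this file.

import numpy as np

class _FeatureGroupMixinSketch:
    # Hypothetical reconstruction, not the actual hidimstat mixin.
    def fit(self, X, y=None, features_groups=None):
        if features_groups is None:
            # Default: one group per column, as the removed BasePerturbation.fit did.
            self.features_groups = {j: [j] for j in range(X.shape[1])}
        else:
            self.features_groups = dict(features_groups)
        self.n_features_groups = len(self.features_groups)
        self._features_groups_ids = [
            np.asarray(ids, dtype=int) for ids in self.features_groups.values()
        ]
        return self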