Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/plot_conditional_vs_marginal_xor_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@
random_state=0,
)
vim.fit(X_train, y_train)
importances.append(vim.importance(X_test, y_test)["importance"])
importances.append(vim.importance(X_test, y_test))

importances = np.array(importances).T

Expand Down
12 changes: 6 additions & 6 deletions examples/plot_diabetes_variable_importance_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,14 +184,14 @@ def compute_pval(vim):
# -------------------


cfi_vim_arr = np.array([x["importance"] for x in cfi_importance_list]) / 2
cfi_vim_arr = np.array(cfi_importance_list) / 2
cfi_pval = compute_pval(cfi_vim_arr)

vim = [
pd.DataFrame(
{
"var": np.arange(cfi_vim_arr.shape[1]),
"importance": x["importance"],
"importance": x,
"fold": i,
"pval": cfi_pval,
"method": "CFI",
Expand All @@ -200,14 +200,14 @@ def compute_pval(vim):
for x in cfi_importance_list
]

loco_vim_arr = np.array([x["importance"] for x in loco_importance_list])
loco_vim_arr = np.array(loco_importance_list)
loco_pval = compute_pval(loco_vim_arr)

vim += [
pd.DataFrame(
{
"var": np.arange(loco_vim_arr.shape[1]),
"importance": x["importance"],
"importance": x,
"fold": i,
"pval": loco_pval,
"method": "LOCO",
Expand All @@ -216,14 +216,14 @@ def compute_pval(vim):
for x in loco_importance_list
]

pfi_vim_arr = np.array([x["importance"] for x in pfi_importance_list])
pfi_vim_arr = np.array(pfi_importance_list)
pfi_pval = compute_pval(pfi_vim_arr)

vim += [
pd.DataFrame(
{
"var": np.arange(pfi_vim_arr.shape[1]),
"importance": x["importance"],
"importance": x,
"fold": i,
"pval": pfi_pval,
"method": "PFI",
Expand Down
2 changes: 1 addition & 1 deletion examples/plot_importance_classification_iris.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def run_one_fold(X, y, model, train_index, test_index, vim_name="CFI", groups=No
)

vim.fit(X[train_index], y[train_index], groups=groups)
importance = vim.importance(X[test_index], y[test_index])["importance"]
importance = vim.importance(X[test_index], y[test_index])

return pd.DataFrame(
{
Expand Down
6 changes: 2 additions & 4 deletions examples/plot_model_agnostic_importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,8 @@
vim_linear.fit(X[train], y[train])
vim_non_linear.fit(X[train], y[train])

importances_linear.append(vim_linear.importance(X[test], y[test])["importance"])
importances_non_linear.append(
vim_non_linear.importance(X[test], y[test])["importance"]
)
importances_linear.append(vim_linear.importance(X[test], y[test]))
importances_non_linear.append(vim_non_linear.importance(X[test], y[test]))


################################################################################
Expand Down
4 changes: 2 additions & 2 deletions examples/plot_pitfalls_permutation_importance.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@
)
pfi.fit(X_test, y_test)

permutation_importances.append(pfi.importance(X_test, y_test)["importance"])
permutation_importances.append(pfi.importance(X_test, y_test))
permutation_importances = np.stack(permutation_importances)
pval_pfi = ttest_1samp(
permutation_importances, 0.0, axis=0, alternative="greater"
Expand Down Expand Up @@ -200,7 +200,7 @@
)
cfi.fit(X_test, y_test)

conditional_importances.append(cfi.importance(X_test, y_test)["importance"])
conditional_importances.append(cfi.importance(X_test, y_test))


cfi_pval = ttest_1samp(
Expand Down
9 changes: 6 additions & 3 deletions src/hidimstat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@
desparsified_group_lasso_pvalue,
)
from .distilled_conditional_randomization_test import d0crt, D0CRT
from .conditional_feature_importance import CFI
from .conditional_feature_importance import cfi, CFI
from .knockoffs import (
model_x_knockoff,
model_x_knockoff_pvalue,
model_x_knockoff_bootstrap_quantile,
model_x_knockoff_bootstrap_e_value,
)
from .leave_one_covariate_out import LOCO
from .leave_one_covariate_out import loco, LOCO
from .noise_std import reid
from .permutation_feature_importance import PFI
from .permutation_feature_importance import pfi, PFI

from .statistical_tools.aggregation import quantile_aggregation

Expand All @@ -47,6 +47,9 @@
"model_x_knockoff_bootstrap_quantile",
"model_x_knockoff_bootstrap_e_value",
"CFI",
"cfi",
"LOCO",
"loco",
"PFI",
"pfi",
]
31 changes: 31 additions & 0 deletions src/hidimstat/_utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,34 @@ def _check_vim_predict_method(method):
"The method {} is not a valid method "
"for variable importance measure prediction".format(method)
)


def get_generated_attributes(cls):
"""
Get all attributes from a class that end with a single underscore
and doesn't start with one underscore.

Parameters
----------
cls : class
The class to inspect for attributes.

Returns
-------
list
A list of attribute names that end with a single underscore but not double underscore.
"""
# Get all attributes and methods of the class
all_attributes = dir(cls)

# Filter out attributes that start with an underscore
filtered_attributes = [attr for attr in all_attributes if not attr.startswith("_")]

# Filter out attributes that do not end with a single underscore
result = [
attr
for attr in filtered_attributes
if attr.endswith("_") and not attr.endswith("__")
]

return result
Loading
Loading