Skip to content

Commit 08565c0

Browse files
committed
fix tests
1 parent 508ce53 commit 08565c0

File tree

6 files changed

+37
-30
lines changed

6 files changed

+37
-30
lines changed

src/hidimstat/conditional_feature_importance.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -134,9 +134,9 @@ def fit(self, X, y=None, features_groups=None, features_type="auto"):
134134
X_ = np.asarray(X)
135135
self._list_imputation_models = Parallel(n_jobs=self.n_jobs)(
136136
delayed(self._joblib_fit_one_features_group)(
137-
estimator, X_, features_groups_ids
137+
imputation_model, X_, features_groups_ids
138138
)
139-
for features_groups_ids, estimator in zip(
139+
for features_groups_ids, imputation_model in zip(
140140
self._features_groups_ids, self._list_imputation_models
141141
)
142142
)

src/hidimstat/leave_one_covariate_out.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,9 @@ def _joblib_fit_one_features_group(self, estimator, X, y, key_features_groups):
9999
estimator.fit(X_minus_j, y)
100100
return estimator
101101

102-
def _joblib_predict_one_group(self, X, features_group_id, key_features_groups):
102+
def _joblib_predict_one_features_group(
103+
self, X, features_group_id, key_features_groups
104+
):
103105
"""Predict the target feature after removing a group of covariates.
104106
Used in parallel."""
105107
X_minus_j = np.delete(X, self._features_groups_ids[features_group_id], axis=1)

test/test_base_perturbation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,4 @@ def test_no_implemented_methods():
1111
estimator.fit(X[:, 0], X[:, 1])
1212
basic_class = BasePerturbation(estimator=estimator)
1313
with pytest.raises(NotImplementedError):
14-
basic_class._permutation(X, group_id=None)
14+
basic_class._permutation(X, features_group_id=None)

test/test_conditional_feature_importance.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ def run_cfi(X, y, n_permutation, seed):
6262
# fit the model using the training set
6363
cfi.fit(
6464
X_train,
65-
groups=None,
66-
var_type="auto",
65+
features_groups=None,
66+
features_type="auto",
6767
)
6868
# calculate feature importance using the test set
6969
vim = cfi.importance(X_test, y_test)
@@ -194,8 +194,8 @@ def test_group(data_generator):
194194
)
195195
cfi.fit(
196196
X_train_df,
197-
groups=groups,
198-
var_type="continuous",
197+
features_groups=groups,
198+
features_type="continuous",
199199
)
200200
# Warning expected since column names in pandas are not considered
201201
with pytest.warns(UserWarning, match="X does not have valid feature names, but"):
@@ -245,8 +245,8 @@ def test_classication(data_generator):
245245
)
246246
cfi.fit(
247247
X_train,
248-
groups=None,
249-
var_type=["continuous"] * X.shape[1],
248+
features_groups=None,
249+
features_type=["continuous"] * X.shape[1],
250250
)
251251
vim = cfi.importance(X_test, y_test_clf)
252252
importance = vim["importance"]
@@ -297,13 +297,13 @@ def test_fit(self, data_generator):
297297
# Test fit with auto var_type
298298
cfi.fit(X)
299299
assert len(cfi._list_imputation_models) == X.shape[1]
300-
assert cfi.n_groups == X.shape[1]
300+
assert cfi.n_features_groups == X.shape[1]
301301

302302
# Test fit with specified groups
303303
groups = {"g1": [0, 1], "g2": [2, 3, 4]}
304-
cfi.fit(X, groups=groups)
304+
cfi.fit(X, features_groups=groups)
305305
assert len(cfi._list_imputation_models) == 2
306-
assert cfi.n_groups == 2
306+
assert cfi.n_features_groups == 2
307307

308308
def test_categorical(
309309
self,
@@ -331,8 +331,8 @@ def test_categorical(
331331
random_state=seed + 1,
332332
)
333333

334-
var_type = ["continuous", "continuous", "categorical"]
335-
cfi.fit(X, y, var_type=var_type)
334+
features_type = ["continuous", "continuous", "categorical"]
335+
cfi.fit(X, y, features_type=features_type)
336336

337337
importances = cfi.importance(X, y)["importance"]
338338
assert len(importances) == 3
@@ -415,7 +415,7 @@ def test_invalid_type(self, data_generator):
415415

416416
# Test error when passing invalid var_type
417417
with pytest.raises(ValueError, match="type of data 'invalid' unknow."):
418-
cfi.fit(X, var_type="invalid")
418+
cfi.fit(X, features_type="invalid")
419419

420420
def test_invalid_n_permutations(self, data_generator):
421421
"""Test when invalid number of permutations is provided"""
@@ -434,7 +434,7 @@ def test_not_good_type_X(self, data_generator):
434434
imputation_model_continuous=LinearRegression(),
435435
method="predict",
436436
)
437-
cfi.fit(X, groups=None, var_type="auto")
437+
cfi.fit(X, features_groups=None, features_type="auto")
438438

439439
with pytest.raises(
440440
ValueError, match="X should be a pandas dataframe or a numpy array."
@@ -450,7 +450,7 @@ def test_mismatched_features(self, data_generator):
450450
imputation_model_continuous=LinearRegression(),
451451
method="predict",
452452
)
453-
cfi.fit(X, groups=None, var_type="auto")
453+
cfi.fit(X, features_groups=None, features_type="auto")
454454

455455
with pytest.raises(
456456
AssertionError, match="X does not correspond to the fitting data."
@@ -473,7 +473,7 @@ def test_mismatched_features_string(self, data_generator):
473473
"col_" + str(i) for i in range(int(X.shape[1] / 2), X.shape[1] - 3)
474474
],
475475
}
476-
cfi.fit(X, groups=subgroups, var_type="auto")
476+
cfi.fit(X, features_groups=subgroups, features_type="auto")
477477

478478
with pytest.raises(
479479
AssertionError,
@@ -499,8 +499,8 @@ def test_internal_error(self, data_generator):
499499
"col_" + str(i) for i in range(int(X.shape[1] / 2), X.shape[1] - 3)
500500
],
501501
}
502-
cfi.fit(X, groups=subgroups, var_type="auto")
503-
cfi.groups["group1"] = [None for i in range(100)]
502+
cfi.fit(X, features_groups=subgroups, features_type="auto")
503+
cfi.features_groups["group1"] = [None for i in range(100)]
504504

505505
X = X.to_records(index=False)
506506
X = np.array(X, dtype=X.dtype.descr)
@@ -517,7 +517,9 @@ def test_invalid_var_type(self, data_generator):
517517
cfi = CFI(estimator=fitted_model, method="predict")
518518

519519
with pytest.raises(ValueError, match="type of data 'invalid_type' unknow."):
520-
cfi.fit(X, groups=None, var_type=["invalid_type"] * X.shape[1])
520+
cfi.fit(
521+
X, features_groups=None, features_type=["invalid_type"] * X.shape[1]
522+
)
521523

522524
def test_incompatible_imputer(self, data_generator):
523525
"""Test when incompatible imputer is provided"""
@@ -548,7 +550,7 @@ def test_invalid_groups_format(self, data_generator):
548550

549551
invalid_groups = ["group1", "group2"] # Should be dictionary
550552
with pytest.raises(ValueError, match="groups needs to be a dictionnary"):
551-
cfi.fit(X, groups=invalid_groups, var_type="auto")
553+
cfi.fit(X, features_groups=invalid_groups, features_type="auto")
552554

553555
def test_groups_warning(self, data_generator):
554556
"""Test if a subgroup raise a warning"""
@@ -560,7 +562,7 @@ def test_groups_warning(self, data_generator):
560562
method="predict",
561563
)
562564
subgroups = {"group1": [0, 1], "group2": [2, 3]}
563-
cfi.fit(X, y, groups=subgroups, var_type="auto")
565+
cfi.fit(X, y, features_groups=subgroups, features_type="auto")
564566

565567
with pytest.warns(
566568
UserWarning,

test/test_leave_one_covariate_out.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def test_loco():
3636
loco.fit(
3737
X_train,
3838
y_train,
39-
groups=None,
39+
features_groups=None,
4040
)
4141
vim = loco.importance(X_test, y_test)
4242

@@ -63,7 +63,7 @@ def test_loco():
6363
loco.fit(
6464
X_train_df,
6565
y_train,
66-
groups=groups,
66+
features_groups=groups,
6767
)
6868
# warnings because we doesn't considere the name of columns of pandas
6969
with pytest.warns(UserWarning, match="X does not have valid feature names, but"):
@@ -87,7 +87,10 @@ def test_loco():
8787
loco_clf.fit(
8888
X_train,
8989
y_train_clf,
90-
groups={"group_0": important_features, "the_group_1": non_important_features},
90+
features_groups={
91+
"group_0": important_features,
92+
"the_group_1": non_important_features,
93+
},
9194
)
9295
vim_clf = loco_clf.importance(X_test, y_test_clf)
9396

test/test_permutation_feature_importance.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def test_permutation_importance():
3737
pfi.fit(
3838
X_train,
3939
y_train,
40-
groups=None,
40+
features_groups=None,
4141
)
4242
vim = pfi.importance(X_test, y_test)
4343

@@ -66,7 +66,7 @@ def test_permutation_importance():
6666
pfi.fit(
6767
X_train_df,
6868
y_train,
69-
groups=groups,
69+
features_groups=groups,
7070
)
7171
# warnings because we doesn't considere the name of columns of pandas
7272
with pytest.warns(UserWarning, match="X does not have valid feature names, but"):
@@ -93,7 +93,7 @@ def test_permutation_importance():
9393
pfi_clf.fit(
9494
X_train,
9595
y_train_clf,
96-
groups=None,
96+
features_groups=None,
9797
)
9898
vim_clf = pfi_clf.importance(X_test, y_test_clf)
9999

0 commit comments

Comments
 (0)