Skip to content

Commit b405d79

Browse files
committed
Merge branch 'PR_gaussian_reformat' into PR_CRT
2 parents a2c830c + 9429418 commit b405d79

File tree

9 files changed

+288
-293
lines changed

9 files changed

+288
-293
lines changed

examples/plot_pitfalls_permutation_importance.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from sklearn.preprocessing import StandardScaler
2727

2828
from hidimstat import CFI, PFI
29-
from hidimstat.conditional_sampling import ConditionalSampler
29+
from hidimstat.statistical_tools.conditional_sampling import ConditionalSampler
3030

3131
rng = np.random.RandomState(0)
3232

src/hidimstat/conditional_feature_importance.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from sklearn.utils.validation import check_random_state
66

77
from hidimstat.base_perturbation import BasePerturbation
8-
from hidimstat.conditional_sampling import ConditionalSampler
8+
from hidimstat.statistical_tools.conditional_sampling import ConditionalSampler
99

1010

1111
class CFI(BasePerturbation):

src/hidimstat/gaussian_knockoff.py

Lines changed: 0 additions & 191 deletions
This file was deleted.

src/hidimstat/knockoffs.py

Lines changed: 6 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,7 @@
77
from sklearn.utils import check_random_state
88
from sklearn.utils.validation import check_memory
99

10-
from hidimstat.gaussian_knockoff import (
11-
gaussian_knockoff_generation,
12-
repeat_gaussian_knockoff_generation,
13-
)
10+
from hidimstat.statistical_tools.gaussian_distribution import GaussianDistribution
1411
from hidimstat.statistical_tools.multiple_testing import fdr_threshold
1512
from hidimstat.statistical_tools.aggregation import quantile_aggregation
1613

@@ -196,32 +193,12 @@ def model_x_knockoff(
196193
raise TypeError("Wrong type for random_state")
197194
seed_list = rng.randint(1, np.iinfo(np.int32).max, n_bootstraps)
198195

199-
if centered:
200-
X = StandardScaler().fit_transform(X)
201-
202-
# estimation of X distribution
203-
# original implementation:
204-
# https://github.com/msesia/knockoff-filter/blob/master/R/knockoff/R/create_second_order.R
205-
mu = X.mean(axis=0)
206-
sigma = cov_estimator.fit(X).covariance_
207-
208196
# Create knockoff variables
209-
X_tilde, mu_tilde, sigma_tilde_decompose = memory.cache(
210-
gaussian_knockoff_generation
211-
)(X, mu, sigma, seed=seed_list[0], tol=tol_gauss)
212-
213-
if n_bootstraps == 1:
214-
X_tildes = [X_tilde]
215-
else:
216-
X_tildes = parallel(
217-
delayed(repeat_gaussian_knockoff_generation)(
218-
mu_tilde,
219-
sigma_tilde_decompose,
220-
seed=seed,
221-
)
222-
for seed in seed_list[1:]
223-
)
224-
X_tildes.insert(0, X_tilde)
197+
conditionnal_sampler = GaussianDistribution(
198+
cov_estimator, random_state=seed_list[0], centered=centered, tol=tol_gauss
199+
)
200+
conditionnal_sampler.fit(X)
201+
X_tildes = [conditionnal_sampler.sample() for i in range(n_bootstraps)]
225202

226203
results = parallel(
227204
delayed(memory.cache(_stat_coefficient_diff))(

0 commit comments

Comments
 (0)