Skip to content

Commit b576522

Browse files
committed
fix knockoff
1 parent 190300d commit b576522

File tree

1 file changed

+6
-29
lines changed

1 file changed

+6
-29
lines changed

src/hidimstat/knockoffs.py

Lines changed: 6 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,7 @@
77
from sklearn.utils import check_random_state
88
from sklearn.utils.validation import check_memory
99

10-
from hidimstat.gaussian_knockoff import (
11-
gaussian_knockoff_generation,
12-
repeat_gaussian_knockoff_generation,
13-
)
10+
from hidimstat.statistical_tools.gaussian_distribution import GaussianDistribution
1411
from hidimstat.statistical_tools.multiple_testing import fdr_threshold
1512
from hidimstat.statistical_tools.aggregation import quantile_aggregation
1613

@@ -196,32 +193,12 @@ def model_x_knockoff(
196193
raise TypeError("Wrong type for random_state")
197194
seed_list = rng.randint(1, np.iinfo(np.int32).max, n_bootstraps)
198195

199-
if centered:
200-
X = StandardScaler().fit_transform(X)
201-
202-
# estimation of X distribution
203-
# original implementation:
204-
# https://github.com/msesia/knockoff-filter/blob/master/R/knockoff/R/create_second_order.R
205-
mu = X.mean(axis=0)
206-
sigma = cov_estimator.fit(X).covariance_
207-
208196
# Create knockoff variables
209-
X_tilde, mu_tilde, sigma_tilde_decompose = memory.cache(
210-
gaussian_knockoff_generation
211-
)(X, mu, sigma, seed=seed_list[0], tol=tol_gauss)
212-
213-
if n_bootstraps == 1:
214-
X_tildes = [X_tilde]
215-
else:
216-
X_tildes = parallel(
217-
delayed(repeat_gaussian_knockoff_generation)(
218-
mu_tilde,
219-
sigma_tilde_decompose,
220-
seed=seed,
221-
)
222-
for seed in seed_list[1:]
223-
)
224-
X_tildes.insert(0, X_tilde)
197+
conditionnal_sampler = GaussianDistribution(
198+
cov_estimator, random_state=seed_list[0], centered=centered, tol=tol_gauss
199+
)
200+
conditionnal_sampler.fit(X)
201+
X_tildes = [conditionnal_sampler.sample() for i in range(n_bootstraps)]
225202

226203
results = parallel(
227204
delayed(memory.cache(_stat_coefficient_diff))(

0 commit comments

Comments
 (0)