|
7 | 7 | from sklearn.utils import check_random_state |
8 | 8 | from sklearn.utils.validation import check_memory |
9 | 9 |
|
10 | | -from hidimstat.gaussian_knockoff import ( |
11 | | - gaussian_knockoff_generation, |
12 | | - repeat_gaussian_knockoff_generation, |
13 | | -) |
| 10 | +from hidimstat.statistical_tools.gaussian_distribution import GaussianDistribution |
14 | 11 | from hidimstat.statistical_tools.multiple_testing import fdr_threshold |
15 | 12 | from hidimstat.statistical_tools.aggregation import quantile_aggregation |
16 | 13 |
|
@@ -196,32 +193,12 @@ def model_x_knockoff( |
196 | 193 | raise TypeError("Wrong type for random_state") |
197 | 194 | seed_list = rng.randint(1, np.iinfo(np.int32).max, n_bootstraps) |
198 | 195 |
|
199 | | - if centered: |
200 | | - X = StandardScaler().fit_transform(X) |
201 | | - |
202 | | - # estimation of X distribution |
203 | | - # original implementation: |
204 | | - # https://github.com/msesia/knockoff-filter/blob/master/R/knockoff/R/create_second_order.R |
205 | | - mu = X.mean(axis=0) |
206 | | - sigma = cov_estimator.fit(X).covariance_ |
207 | | - |
208 | 196 | # Create knockoff variables |
209 | | - X_tilde, mu_tilde, sigma_tilde_decompose = memory.cache( |
210 | | - gaussian_knockoff_generation |
211 | | - )(X, mu, sigma, seed=seed_list[0], tol=tol_gauss) |
212 | | - |
213 | | - if n_bootstraps == 1: |
214 | | - X_tildes = [X_tilde] |
215 | | - else: |
216 | | - X_tildes = parallel( |
217 | | - delayed(repeat_gaussian_knockoff_generation)( |
218 | | - mu_tilde, |
219 | | - sigma_tilde_decompose, |
220 | | - seed=seed, |
221 | | - ) |
222 | | - for seed in seed_list[1:] |
223 | | - ) |
224 | | - X_tildes.insert(0, X_tilde) |
| 197 | + conditionnal_sampler = GaussianDistribution( |
| 198 | + cov_estimator, random_state=seed_list[0], centered=centered, tol=tol_gauss |
| 199 | + ) |
| 200 | + conditionnal_sampler.fit(X) |
| 201 | + X_tildes = [conditionnal_sampler.sample() for i in range(n_bootstraps)] |
225 | 202 |
|
226 | 203 | results = parallel( |
227 | 204 | delayed(memory.cache(_stat_coefficient_diff))( |
|
0 commit comments