Commit 56e107a

lionelkusch, jpaillard, and bthirion authored
Reformat the gaussian conditional sampler (#366)
* gaussian distribution reformat
* fix knockoff
* move conditional sampler in statistical tools
* move test from knockoff
* move test in right folder
* remove gaussian
* remove unsed import
* fix center to knockoff
* Change name of GaussianDistribution
* remove parameter assumed_centerd of tests
* change the way of using the randomgenerator
* Apply suggestions from code review
  Co-authored-by: Joseph Paillard <[email protected]>
* reformat docstring
* Update test/statistical_tools/test_gaussian_knockoffs.py
  Co-authored-by: Joseph Paillard <[email protected]>
* improve tests
* Apply suggestions from code review
  Co-authored-by: bthirion <[email protected]>
* add assert
* fix tests
* Improve test
* remove not necessary test
* Update the conditioal sampler for new fit signature
* fix test
* fix isort
* fix examples
* fix documentation
* fix documentation
* update the user guide.
* try to fix the end of file
* Apply suggestion from @bthirion
  Co-authored-by: bthirion <[email protected]>
* update makefile
* Add comment [skip tests]
* update command
* [skip tests]
* fix makefile
* [skip tests]
* change signature of sample
* remove fix
* remove documentation
* Update src/hidimstat/statistical_tools/gaussian_knockoffs.py
  Co-authored-by: Joseph Paillard <[email protected]>
* Update src/hidimstat/statistical_tools/gaussian_knockoffs.py
  Co-authored-by: Joseph Paillard <[email protected]>
* Update src/hidimstat/statistical_tools/gaussian_knockoffs.py
  Co-authored-by: Joseph Paillard <[email protected]>
* Update src/hidimstat/statistical_tools/gaussian_knockoffs.py
  Co-authored-by: Joseph Paillard <[email protected]>
* format file
* add test with repeat

---------

Co-authored-by: Joseph Paillard <[email protected]>
Co-authored-by: bthirion <[email protected]>
1 parent 319b433 commit 56e107a

11 files changed: +351 -293 lines changed

docs/src/api.rst

Lines changed: 2 additions & 1 deletion
@@ -52,7 +52,8 @@ Samplers
    :toctree: ./generated/api/class/
    :template: class.rst
 
-   conditional_sampling.ConditionalSampler
+   statistical_tools.conditional_sampling.ConditionalSampler
+   statistical_tools.gaussian_knockoffs.GaussianKnockoffs
 
 Helper Functions
 ================

docs/src/user_guide.rst

Lines changed: 0 additions & 1 deletion
@@ -21,4 +21,3 @@ Table of contents
    visualization.rst
    grouping.rst
    high_dimension.rst
-

examples/plot_pitfalls_permutation_importance.py

Lines changed: 1 addition & 1 deletion
@@ -267,7 +267,7 @@
 
 from matplotlib.lines import Line2D
 
-from hidimstat.conditional_sampling import ConditionalSampler
+from hidimstat.statistical_tools.conditional_sampling import ConditionalSampler
 
 X_train, X_test = train_test_split(
     X,

src/hidimstat/conditional_feature_importance.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from sklearn.metrics import root_mean_squared_error
55

66
from hidimstat.base_perturbation import BasePerturbation
7-
from hidimstat.conditional_sampling import ConditionalSampler
7+
from hidimstat.statistical_tools.conditional_sampling import ConditionalSampler
88

99

1010
class CFI(BasePerturbation):

src/hidimstat/gaussian_knockoff.py

Lines changed: 0 additions & 187 deletions
This file was deleted.

src/hidimstat/knockoffs.py

Lines changed: 6 additions & 31 deletions
@@ -7,12 +7,8 @@
 from sklearn.preprocessing import StandardScaler
 from sklearn.utils.validation import check_memory
 
-from hidimstat._utils.utils import check_random_state
-from hidimstat.gaussian_knockoff import (
-    gaussian_knockoff_generation,
-    repeat_gaussian_knockoff_generation,
-)
 from hidimstat.statistical_tools.aggregation import quantile_aggregation
+from hidimstat.statistical_tools.gaussian_knockoffs import GaussianKnockoffs
 from hidimstat.statistical_tools.multiple_testing import fdr_threshold
 
 

@@ -188,36 +184,15 @@ def model_x_knockoff(
     n_jobs = min(n_bootstraps, n_jobs)
     parallel = Parallel(n_jobs, verbose=joblib_verbose)
 
-    # get the seed for the different run
-    rng = check_random_state(random_state)
-    children_rng = rng.spawn(n_bootstraps)
-
     if centered:
         X = StandardScaler().fit_transform(X)
 
-    # estimation of X distribution
-    # original implementation:
-    # https://github.com/msesia/knockoff-filter/blob/master/R/knockoff/R/create_second_order.R
-    mu = X.mean(axis=0)
-    sigma = cov_estimator.fit(X).covariance_
-
     # Create knockoff variables
-    X_tilde, mu_tilde, sigma_tilde_decompose = memory.cache(
-        gaussian_knockoff_generation
-    )(X, mu, sigma, random_state=children_rng[0], tol=tol_gauss)
-
-    if n_bootstraps == 1:
-        X_tildes = [X_tilde]
-    else:
-        X_tildes = parallel(
-            delayed(repeat_gaussian_knockoff_generation)(
-                mu_tilde,
-                sigma_tilde_decompose,
-                random_state=seed,
-            )
-            for seed in children_rng[1:]
-        )
-        X_tildes.insert(0, X_tilde)
+    conditionnal_sampler = GaussianKnockoffs(cov_estimator, tol=tol_gauss)
+    conditionnal_sampler.fit(X)
+    X_tildes = conditionnal_sampler.sample(
+        n_repeats=n_bootstraps, random_state=random_state
+    )
 
     results = parallel(
         delayed(memory.cache(_stat_coefficient_diff))(
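
Note on the hunk above: the change replaces a two-step procedure (estimate `mu` and `sigma`, then draw each knockoff copy with its own child seed) with a single object that is fitted once and sampled `n_repeats` times. The sketch below illustrates that fit-once, sample-many pattern with NumPy only. It is not the code in `gaussian_knockoffs.py`: the class name `SketchGaussianKnockoffs`, the `shrink` parameter, the use of the empirical covariance, the equicorrelated choice of `s`, and the stacked return shape are all assumptions made for illustration.

```python
import numpy as np


class SketchGaussianKnockoffs:
    """Hypothetical stand-in for illustration; NOT hidimstat's GaussianKnockoffs."""

    def __init__(self, shrink=0.9):
        # Assumption: shrink < 1 keeps the conditional covariance
        # strictly positive definite so that Cholesky succeeds.
        self.shrink = shrink

    def fit(self, X):
        X = np.asarray(X, dtype=float)
        self.shape_ = X.shape
        mu = X.mean(axis=0)
        sigma = np.cov(X, rowvar=False)  # assumption: plain empirical covariance

        # Equicorrelated choice of s, computed on the correlation scale.
        scale = np.sqrt(np.diag(sigma))
        corr = sigma / np.outer(scale, scale)
        lambda_min = np.linalg.eigvalsh(corr)[0]
        s = self.shrink * min(2.0 * lambda_min, 1.0) * np.diag(sigma)

        # Conditional law of the knockoffs given X:
        #   mean  = X - (X - mu) Sigma^{-1} diag(s)
        #   covar = 2 diag(s) - diag(s) Sigma^{-1} diag(s)
        sigma_inv_s = np.linalg.solve(sigma, np.diag(s))
        self.mu_tilde_ = X - (X - mu) @ sigma_inv_s
        sigma_tilde = 2.0 * np.diag(s) - np.diag(s) @ sigma_inv_s
        self.chol_ = np.linalg.cholesky(sigma_tilde)
        return self

    def sample(self, n_repeats=1, random_state=None):
        # One independent child seed per repeat, derived from a single seed
        # (random_state is an int or None in this sketch).
        seeds = np.random.SeedSequence(random_state).spawn(n_repeats)
        draws = []
        for seed in seeds:
            rng = np.random.default_rng(seed)
            noise = rng.standard_normal(self.shape_)
            draws.append(self.mu_tilde_ + noise @ self.chol_.T)
        # Assumption: repeats are stacked along a new leading axis.
        return np.stack(draws)


X = np.random.default_rng(0).standard_normal((200, 5))
sampler = SketchGaussianKnockoffs().fit(X)
X_tildes = sampler.sample(n_repeats=3, random_state=42)
```

The expensive part (covariance estimate, solve, Cholesky factor) runs once in `fit`, and every call to `sample` reuses it, which is the same separation the removed `gaussian_knockoff_generation` / `repeat_gaussian_knockoff_generation` pair achieved by passing `mu_tilde` and `sigma_tilde_decompose` around explicitly.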
