
Commit 1f97f5b

lionelkusch, bthirion, and jpaillard authored
[API 2]: Desparsified Lasso (#381)
* update Desparsified Lasso
* fix desparsified lasso and the example
* fix example
* add tests
* update docstring
* change signal noise ratio
* change name for model
* small modification
* fix name variable
* fix docstring
* fix example and test
* fix commit
* rename variable
* fix format
* fix example
* update docstring
* add warning (×2)
* fix docstring (×2)
* add options alphas
* update comments
* remove shuffle in cv
* Fix comment
* change cv in EnClDL
* improve coverage
* Update src/hidimstat/noise_std.py (Co-authored-by: bthirion <[email protected]>)
* Update test/test_desparsified_lasso.py (Co-authored-by: bthirion <[email protected]>)
* replace n_time by n_task (×2)
* Move reid in desparsified lasso
* fix import (×2)
* fix definition of the covariance
* add an exception
* fix import order
* format
* change default value of n_jobs
* fix bug
* Update src/hidimstat/ensemble_clustered_inference.py (×11, Co-authored-by: bthirion <[email protected]>)
* Update src/hidimstat/desparsified_lasso.py (×3, Co-authored-by: bthirion <[email protected]>)
* Update test/test_desparsified_lasso.py (×7, Co-authored-by: bthirion <[email protected]>)
* fix MultiTaskLassoCV name change
* Update src/hidimstat/desparsified_lasso.py (Co-authored-by: bthirion <[email protected]>)
* Add shape in docstring
* Add new workflow for maintenance (#501)
* Pr ci maint (#502)
* fix ci
* fix condition for ci
* add comment
* Remove parallel generation of example (#497)
* remove the generation of example in parallel
* remove option memory [skip tests]
* add tracking of memory
* fix issue of memory load from example fmri
* fix example
* remove error modification
* fix codespell
* fix number of jobs
* fix example
* remove parallelization of short example
* [skip tests]
* update examples
* remove modification [skip tests]
* improve plot fmri
* fix example?
* remove warnings
* remove memory issue
* pass Desparsified Lasso
* fix tests (×2)
* add example
* fix example
* text in example
* small rendering fix
* fix 1 - pval, place imports along the example
* finish merge
* remove digits example
* add docstring
* fix FDR in test, add randomness tests for DL

---------
Co-authored-by: bthirion <[email protected]>
Co-authored-by: Joseph Paillard <[email protected]>
1 parent f1cf295 commit 1f97f5b


12 files changed (+1468 / -1233 lines)


docs/src/api.rst

Lines changed: 1 addition & 3 deletions
@@ -29,6 +29,7 @@ Feature Importance Classes
    PFI
    D0CRT
    ModelXKnockoff
+   DesparsifiedLasso

 Feature Importance functions
 ============================
@@ -39,9 +40,6 @@ Feature Importance functions

    clustered_inference
    clustered_inference_pvalue
-   desparsified_lasso
-   desparsified_lasso_pvalue
-   desparsified_group_lasso_pvalue
    ensemble_clustered_inference
    ensemble_clustered_inference_pvalue
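Note for reviewers: the three function-style entries removed above are superseded by the single DesparsifiedLasso class. A minimal sketch of the new call pattern, using only names visible in this commit's diffs; the synthetic data and parameter values are illustrative, not taken from the commit:

import numpy as np

from hidimstat import DesparsifiedLasso

# toy data, for illustration only
rng = np.random.default_rng(0)
X = rng.standard_normal((100, 50))
y = X[:, :5].sum(axis=1) + rng.standard_normal(100)

# one estimator object replaces desparsified_lasso() + desparsified_lasso_pvalue()
dl = DesparsifiedLasso(n_jobs=1, random_state=0)
dl.fit_importance(X, y)

# p-values are now exposed as fitted attributes
pval = dl.pvalues_
one_minus_pval = dl.one_minus_pvalues_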

examples/plot_2D_simulation_example.py

Lines changed: 53 additions & 53 deletions
@@ -46,22 +46,6 @@
 .. footbibliography::

 """
-import matplotlib.pyplot as plt
-import numpy as np
-from sklearn.cluster import FeatureAgglomeration
-from sklearn.feature_extraction import image
-from sklearn.preprocessing import StandardScaler
-
-from hidimstat._utils.scenario import multivariate_simulation_spatial
-from hidimstat.desparsified_lasso import desparsified_lasso, desparsified_lasso_pvalue
-from hidimstat.ensemble_clustered_inference import (
-    clustered_inference,
-    clustered_inference_pvalue,
-    ensemble_clustered_inference,
-    ensemble_clustered_inference_pvalue,
-)
-from hidimstat.statistical_tools.p_values import zscore_from_pval
-
 # %%
 # Generating the data
 # -------------------
@@ -71,6 +55,9 @@
 # example.

 # simulation parameters
+
+from hidimstat._utils.scenario import multivariate_simulation_spatial
+
 n_samples = 100
 shape = (40, 40)
 n_features = shape[1] * shape[0]
@@ -83,6 +70,7 @@
     n_samples, shape, roi_size, signal_noise_ratio, smooth_X, seed=0
 )

+
 # %%
 # Choosing inference parameters
 # -----------------------------
@@ -110,7 +98,8 @@
 delta = 6

 # number of worker
-n_jobs = 3
+n_jobs = 4
+

 # %%
 # Computing z-score thresholds for support estimation
@@ -125,12 +114,15 @@
 # consists in dividing by the number of clusters.


+from hidimstat.statistical_tools.p_values import zscore_from_pval
+
 # computing the z-score thresholds for feature selection
 correction_no_cluster = 1.0 / n_features
 correction_cluster = 1.0 / n_clusters
 thr_c = zscore_from_pval((fwer_target / 2) * correction_cluster)
 thr_nc = zscore_from_pval((fwer_target / 2) * correction_no_cluster)

+
 # %%
 # Inference with several algorithms
 # ---------------------------------
@@ -139,6 +131,9 @@
 # the theoretical tolerance region.


+import numpy as np
+
+
 # The following function builds a 2D map with four active regions that are
 # enfolded by thin tolerance regions.
 def weight_map_2D_extended(shape, roi_size, delta):
@@ -174,6 +169,7 @@ def weight_map_2D_extended(shape, roi_size, delta):
 # compute true support with visible spatial tolerance
 beta_extended = weight_map_2D_extended(shape, roi_size, delta)

+
 # %%
 # Now, we compute the support estimated by a high-dimensional statistical
 # inference method that does not leverage the data structure.
@@ -183,56 +179,53 @@ def weight_map_2D_extended(shape, roi_size, delta):
 # and referred to as Desparsified Lasso.


+from hidimstat import DesparsifiedLasso
+
 # compute desparsified lasso
-beta_hat, sigma_hat, precision_diagonal = desparsified_lasso(
-    X_init,
-    y,
-    n_jobs=n_jobs,
-    random_state=0,
-)
-pval, pval_corr, one_minus_pval, one_minus_pval_corr, cb_min, cb_max = (
-    desparsified_lasso_pvalue(
-        X_init.shape[0],
-        beta_hat,
-        sigma_hat,
-        precision_diagonal,
-    )
-)
+desparsified_lasso = DesparsifiedLasso(n_jobs=n_jobs, random_state=0)
+desparsified_lasso.fit_importance(X_init, y)

 # compute estimated support (first method)
-zscore = zscore_from_pval(pval, one_minus_pval)
+zscore = zscore_from_pval(
+    desparsified_lasso.pvalues_, desparsified_lasso.one_minus_pvalues_
+)
 selected_dl = zscore > thr_nc  # use the "no clustering threshold"

 # compute estimated support (second method)
 selected_dl = np.logical_or(
-    pval_corr < fwer_target / 2, one_minus_pval_corr < fwer_target / 2
+    desparsified_lasso.pvalues_corr_ < fwer_target / 2,
+    desparsified_lasso.one_minus_pvalues_corr_ < fwer_target / 2,
 )

+
 # %%
 # Now, we compute the support estimated using a clustered inference algorithm
 # (c.f. :footcite:t:`chevalier2022spatially`) called Clustered Desparsified Lasso
 # (CluDL) since it uses the Desparsified Lasso technique after clustering the data.

 # Define the FeatureAgglomeration object that performs the clustering.
 # This object is necessary to run the current algorithm and the following one.
+
+from sklearn.cluster import FeatureAgglomeration
+from sklearn.feature_extraction import image
+from sklearn.preprocessing import StandardScaler
+
+from hidimstat.ensemble_clustered_inference import (
+    clustered_inference,
+    clustered_inference_pvalue,
+)
+
 connectivity = image.grid_to_graph(n_x=shape[0], n_y=shape[1])
 ward = FeatureAgglomeration(
     n_clusters=n_clusters, connectivity=connectivity, linkage="ward"
 )

 # clustered desparsified lasso (CluDL)
-ward_, beta_hat, theta_hat, omega_diag = clustered_inference(
+ward_, desparsified_lasso_ = clustered_inference(
     X_init, y, ward, scaler_sampling=StandardScaler(), random_state=0
 )
 beta_hat, pval, pval_corr, one_minus_pval, one_minus_pval_corr = (
-    clustered_inference_pvalue(
-        n_samples,
-        False,
-        ward_,
-        beta_hat,
-        theta_hat,
-        omega_diag,
-    )
+    clustered_inference_pvalue(n_samples, False, ward_, desparsified_lasso_)
 )

 # compute estimated support (first method)
@@ -244,40 +237,46 @@ def weight_map_2D_extended(shape, roi_size, delta):
     pval_corr < fwer_target / 2, one_minus_pval_corr < fwer_target / 2
 )

+
 # %%
 # Finally, we compute the support estimated by an ensembled clustered
 # inference algorithm (c.f. :footcite:t:`chevalier2022spatially`). This algorithm is called
 # Ensemble of Clustered Desparsified Lasso (EnCluDL) since it runs several
 # CluDL algorithms with different clustering choices. The different CluDL
 # solutions are then aggregated into one.

+from hidimstat.ensemble_clustered_inference import (
+    ensemble_clustered_inference,
+    ensemble_clustered_inference_pvalue,
+)
+
 # ensemble of clustered desparsified lasso (EnCluDL)
-list_ward, list_beta_hat, list_theta_hat, list_omega_diag = (
-    ensemble_clustered_inference(
-        X_init,
-        y,
-        ward,
-        scaler_sampling=StandardScaler(),
-        random_state=0,
-    )
+list_ward, list_desparsified_lasso = ensemble_clustered_inference(
+    X_init,
+    y,
+    ward,
+    scaler_sampling=StandardScaler(),
+    random_state=0,
+    n_jobs=n_jobs,
 )
 beta_hat, selected_ecdl = ensemble_clustered_inference_pvalue(
     n_samples,
     False,
     list_ward,
-    list_beta_hat,
-    list_theta_hat,
-    list_omega_diag,
+    list_desparsified_lasso,
     fdr=fwer_target,
 )

+
 # %%
 # Results
 # -------
 #
 # Now we plot the true support, the theoretical tolerance regions and
 # the estimated supports for every method.

+import matplotlib.pyplot as plt
+

 # To generate a plot that exhibits
 # the true support and the estimated supports for every method,
@@ -342,6 +341,7 @@ def plot(maps, titles):

 plot(maps, titles)

+
 # %%
 # Analysis of the results
 # -----------------------
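Note for reviewers: as the diff above shows, clustered_inference now returns a fitted clustering object plus a single estimator object, which clustered_inference_pvalue consumes in place of the old (beta_hat, theta_hat, omega_diag) triplet. A self-contained sketch of that contract on toy data; shapes, cluster counts, and values are illustrative only:

import numpy as np
from sklearn.cluster import FeatureAgglomeration
from sklearn.feature_extraction import image
from sklearn.preprocessing import StandardScaler

from hidimstat.ensemble_clustered_inference import (
    clustered_inference,
    clustered_inference_pvalue,
)

# toy 2D problem, for illustration only
rng = np.random.default_rng(0)
shape = (10, 10)
n_samples = 60
X = rng.standard_normal((n_samples, shape[0] * shape[1]))
y = X[:, :4].sum(axis=1) + rng.standard_normal(n_samples)

# spatially constrained Ward clustering, as in the example above
connectivity = image.grid_to_graph(n_x=shape[0], n_y=shape[1])
ward = FeatureAgglomeration(n_clusters=20, connectivity=connectivity, linkage="ward")

# new contract: (fitted ward, fitted estimator) instead of raw arrays
ward_, dl_ = clustered_inference(
    X, y, ward, scaler_sampling=StandardScaler(), random_state=0
)
beta_hat, pval, pval_corr, one_minus_pval, one_minus_pval_corr = (
    clustered_inference_pvalue(n_samples, False, ward_, dl_)
)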

examples/plot_fmri_data_example.py

Lines changed: 41 additions & 26 deletions
@@ -38,12 +38,15 @@
 from nilearn.image import mean_img
 from nilearn.maskers import NiftiMasker
 from nilearn.plotting import plot_stat_map, show
+from sklearn.base import clone
 from sklearn.cluster import FeatureAgglomeration
 from sklearn.feature_extraction import image
+from sklearn.linear_model import LassoCV
+from sklearn.model_selection import KFold
 from sklearn.preprocessing import StandardScaler
 from sklearn.utils import Bunch

-from hidimstat.desparsified_lasso import desparsified_lasso, desparsified_lasso_pvalue
+from hidimstat.desparsified_lasso import DesparsifiedLasso
 from hidimstat.ensemble_clustered_inference import (
     clustered_inference,
     clustered_inference_pvalue,
@@ -144,19 +147,34 @@ def preprocess_haxby(subject=2, memory=None):
     new_hard_limit = limit_5G if hard < 0 else min(limit_5G, hard)
     resource.setrlimit(resource.RLIMIT_AS, (new_soft_limit, new_hard_limit))

+# Default estimator
+estimator = LassoCV(
+    eps=1e-2,
+    fit_intercept=False,
+    cv=KFold(n_splits=5, shuffle=True, random_state=0),
+    tol=1e-2,
+    max_iter=6000,
+    random_state=1,
+    n_jobs=1,
+)
+
+
 # First, we try to recover the discriminative pattern by computing
 # p-values from desparsified lasso.
 # Due to the size of the X, it's not possible to use this method with a limit
 # of 5 G for memory. To handle this problem, the following methods use some
 # feature aggregation methods.
 #
 try:
-    beta_hat, sigma_hat, precision_diagonal = desparsified_lasso(
-        X, y, noise_method="median", max_iteration=1000, random_state=0, n_jobs=n_jobs
-    )
-    pval_dl, _, one_minus_pval_dl, _, cb_min, cb_max = desparsified_lasso_pvalue(
-        X.shape[0], beta_hat, sigma_hat, precision_diagonal
+    desparsified_lasso = DesparsifiedLasso(
+        noise_method="median",
+        estimator=clone(estimator),
+        random_state=0,
+        n_jobs=n_jobs,
     )
+    desparsified_lasso.fit_importance(X, y)
+    pval_dl = desparsified_lasso.pvalues_
+    one_minus_pval_dl = desparsified_lasso.one_minus_pvalues_
 except MemoryError as err:
     pval_dl = None
     one_minus_pval_dl = None
@@ -165,17 +183,18 @@ def preprocess_haxby(subject=2, memory=None):
 # %%
 # Now, the clustered inference algorithm which combines parcellation
 # and high-dimensional inference (c.f. References).
-ward_, beta_hat, theta_hat, omega_diag = clustered_inference(
+ward_, cl_desparsified_lasso = clustered_inference(
     X,
     y,
     ward,
     scaler_sampling=StandardScaler(),
-    tolerance=1e-2,
+    estimator=clone(estimator),
+    tolerance_reid=1e-2,
     random_state=1,
     n_jobs=n_jobs,
 )
 beta_hat, pval_cdl, _, one_minus_pval_cdl, _ = clustered_inference_pvalue(
-    X.shape[0], None, ward_, beta_hat, theta_hat, omega_diag
+    X.shape[0], None, ward_, cl_desparsified_lasso
 )

 # %%
@@ -185,28 +204,24 @@ def preprocess_haxby(subject=2, memory=None):
 # which means that 5 different parcellations are considered and
 # then 5 statistical maps are produced and aggregated into one.
 # However you might benefit from clustering randomization taking
-# `n_bootstraps=25` or `n_bootstraps=100`, also we set `n_jobs`.
-list_ward, list_beta_hat, list_theta_hat, list_omega_diag = (
-    ensemble_clustered_inference(
-        X,
-        y,
-        ward,
-        groups=groups,
-        scaler_sampling=StandardScaler(),
-        n_bootstraps=5,
-        max_iteration=6000,
-        tolerance=1e-2,
-        random_state=2,
-        n_jobs=n_jobs,
-    )
+# `n_bootstraps=25` or `n_bootstraps=100`, also we set `n_jobs=n_jobs`.
+list_ward, list_cl_desparsified_lasso = ensemble_clustered_inference(
+    X,
+    y,
+    ward,
+    groups=groups,
+    scaler_sampling=StandardScaler(),
+    n_bootstraps=5,
+    estimator=clone(estimator),
+    tolerance_reid=1e-2,
+    random_state=2,
+    n_jobs=n_jobs,
 )
 beta_hat, selected = ensemble_clustered_inference_pvalue(
     X.shape[0],
     False,
     list_ward,
-    list_beta_hat,
-    list_theta_hat,
-    list_omega_diag,
+    list_cl_desparsified_lasso,
     fdr=0.1,
 )

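Note for reviewers: a recurring pattern in this file is injecting a user-configured base solver via estimator=clone(estimator). A sketch of that pattern in isolation, reusing the LassoCV settings from the diff above; the toy data is illustrative, and clone keeps the template estimator unfitted so each consumer gets a fresh copy:

import numpy as np
from sklearn.base import clone
from sklearn.linear_model import LassoCV
from sklearn.model_selection import KFold

from hidimstat.desparsified_lasso import DesparsifiedLasso

# base solver configured once, mirroring the example's settings
estimator = LassoCV(
    eps=1e-2,
    fit_intercept=False,
    cv=KFold(n_splits=5, shuffle=True, random_state=0),
    tol=1e-2,
    max_iter=6000,
    random_state=1,
    n_jobs=1,
)

# toy data, for illustration only
rng = np.random.default_rng(0)
X = rng.standard_normal((80, 40))
y = X[:, :3].sum(axis=1) + rng.standard_normal(80)

dl = DesparsifiedLasso(
    noise_method="median",
    estimator=clone(estimator),  # a fresh, unfitted copy of the template
    random_state=0,
    n_jobs=1,
)
dl.fit_importance(X, y)
pval = dl.pvalues_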