@@ -16,8 +16,8 @@ def get_optimal_cov_estimator(time_series):
1616 from sklearn .covariance import GraphicalLassoCV
1717
1818 estimator_shrunk = None
19- estimator = GraphicalLassoCV (cv = 5 )
20- print ("\n Finding best estimator...\n " )
19+ estimator = GraphicalLassoCV (cv = 5 , assume_centered = True )
20+ print ("\n Searching for best Lasso estimator...\n " )
2121 try :
2222 estimator .fit (time_series )
2323 except BaseException :
@@ -27,8 +27,9 @@ def get_optimal_cov_estimator(time_series):
2727 while not hasattr (estimator , 'covariance_' ) and \
2828 not hasattr (estimator , 'precision_' ) and ix < 3 :
2929 for tol in [0.1 , 0.01 , 0.001 , 0.0001 ]:
30- print (tol )
31- estimator = GraphicalLassoCV (cv = 5 , max_iter = 200 , tol = tol )
30+ print (f"Auto-tuning Tolerance={ tol } " )
31+ estimator = GraphicalLassoCV (cv = 5 , max_iter = 200 , tol = tol ,
32+ assume_centered = True )
3233 try :
3334 estimator .fit (time_series )
3435 except BaseException :
@@ -38,49 +39,33 @@ def get_optimal_cov_estimator(time_series):
3839 if not hasattr (estimator , 'covariance_' ) and not hasattr (estimator ,
3940 'precision_' ):
4041 print (
41- "Unstable Lasso estimation. Applying shrinkage..."
42+ "Unstable Lasso estimation. Applying shrinkage to empirical "
43+ "covariance..."
44+ )
45+ estimator = None
46+ from sklearn .covariance import (
47+ GraphicalLasso ,
48+ empirical_covariance ,
49+ shrunk_covariance ,
4250 )
4351 try :
44- estimator = None
45- from sklearn .covariance import (
46- GraphicalLasso ,
47- empirical_covariance ,
48- shrunk_covariance ,
49- )
50-
51- emp_cov = empirical_covariance (time_series )
52- # Iterate across different levels of alpha
52+ emp_cov = empirical_covariance (time_series , assume_centered = True )
5353 for i in np .arange (0.8 , 0.99 , 0.01 ):
54+ print (f"Shrinkage={ i } :" )
5455 shrunk_cov = shrunk_covariance (emp_cov , shrinkage = i )
5556 alphaRange = 10.0 ** np .arange (- 8 , 0 )
5657 for alpha in alphaRange :
58+ print (f"Auto-tuning alpha={ alpha } ..." )
59+ estimator_shrunk = GraphicalLasso (alpha ,
60+ assume_centered = True )
5761 try :
58- estimator_shrunk = GraphicalLasso (alpha )
5962 estimator_shrunk .fit (shrunk_cov )
60- print (
61- f"Retrying covariance matrix estimate with"
62- f" alpha={ alpha } "
63- )
64- if estimator_shrunk is None :
65- pass
66- else :
67- break
6863 except BaseException :
69- print (
70- f"Covariance estimation failed with shrinkage"
71- f" at alpha={ alpha } "
72- )
7364 continue
74- except ValueError :
75- estimator = None
76- print (
77- "Covariance estimation failed. Check time-series data "
78- "for errors."
79- )
80- if estimator is None and estimator_shrunk is None :
81- raise RuntimeError ("\n ERROR: Covariance estimation failed." )
65+ except BaseException :
66+ return estimator
8267
83- if estimator is None :
68+ if estimator is None and estimator_shrunk is not None :
8469 estimator = estimator_shrunk
8570
8671 return estimator
@@ -229,7 +214,6 @@ def get_conn_matrix(
229214 for Gaussian and related Graphical Models. doi:10.5281/zenodo.830033
230215
231216 """
232- import sys
233217 from pynets .fmri .estimation import get_optimal_cov_estimator
234218 from nilearn .connectome import ConnectivityMeasure
235219
@@ -241,6 +225,40 @@ def get_conn_matrix(
241225 conn_matrix = None
242226 estimator = get_optimal_cov_estimator (time_series )
243227
228+ def fallback_covariance (time_series ):
229+ from sklearn .ensemble import IsolationForest
230+ from sklearn import covariance
231+
232+ # Remove gross outliers
233+ model = IsolationForest (contamination = 0.02 )
234+ model .fit (time_series )
235+ outlier_mask = model .predict (time_series )
236+ outlier_mask [outlier_mask == - 1 ] = 0
237+ time_series = time_series [outlier_mask .astype ('bool' )]
238+
239+ # Fall back to LedoitWolf
240+ print ('Matrix estimation failed with Lasso and shrinkage due to '
241+ 'ill conditions. Removing potential anomalies from the '
242+ 'time-series using IsolationForest...' )
243+ try :
244+ print ("Trying Ledoit-Wolf Estimator..." )
245+ conn_measure = ConnectivityMeasure (
246+ cov_estimator = covariance .LedoitWolf (store_precision = True ,
247+ assume_centered = True ),
248+ kind = kind )
249+ conn_matrix = conn_measure .fit_transform ([time_series ])[0 ]
250+ except (np .linalg .linalg .LinAlgError , FloatingPointError ):
251+ print ("Trying Oracle Approximating Shrinkage Estimator..." )
252+ conn_measure = ConnectivityMeasure (
253+ cov_estimator = covariance .OAS (assume_centered = True ),
254+ kind = kind )
255+ try :
256+ conn_matrix = conn_measure .fit_transform ([time_series ])[0 ]
257+ except (np .linalg .linalg .LinAlgError , FloatingPointError ):
258+ raise ValueError ('All covariance estimators failed to '
259+ 'converge...' )
260+ return conn_matrix
261+
244262 if conn_model in nilearn_kinds :
245263 if conn_model == "corr" or conn_model == "cor" or conn_model == "correlation" :
246264 print ("\n Computing correlation matrix...\n " )
@@ -259,32 +277,16 @@ def get_conn_matrix(
259277 "\n ERROR! No connectivity model specified at runtime. Select a"
260278 " valid estimator using the -mod flag." )
261279
262- try :
263- # Try with the best-fitting Lasso estimator
280+ # Try with the best-fitting Lasso estimator
281+ if estimator is not None :
264282 conn_measure = ConnectivityMeasure (cov_estimator = estimator ,
265283 kind = kind )
266- conn_matrix = conn_measure .fit_transform ([time_series ])[0 ]
267- except BaseException :
268- from sklearn .ensemble import IsolationForest
269-
270- # Remove gross outliers
271- model = IsolationForest (contamination = 0.02 )
272- model .fit (time_series )
273- outlier_mask = model .predict (time_series )
274- outlier_mask [outlier_mask == - 1 ] = 0
275- time_series = time_series [outlier_mask .astype ('bool' )]
276-
277- # Fall back to LedoitWolf
278- print ('Matrix estimation failed with Lasso and shrinkage due to '
279- 'ill conditions. Removing potential anomalies from the '
280- 'time-series using IsolationForest and falling back to '
281- 'LedoitWolf...' )
282284 try :
283- conn_measure = ConnectivityMeasure (kind = kind )
284285 conn_matrix = conn_measure .fit_transform ([time_series ])[0 ]
285- except RuntimeError :
286- print ('Matrix estimation failed.' )
287- sys .exit (1 )
286+ except (np .linalg .linalg .LinAlgError , FloatingPointError ):
287+ fallback_covariance (time_series )
288+ else :
289+ fallback_covariance (time_series )
288290 else :
289291 if conn_model == "QuicGraphicalLasso" :
290292 try :
0 commit comments