1515def get_optimal_cov_estimator (time_series ):
1616 from sklearn .covariance import GraphicalLassoCV
1717
18- estimator_shrunk = None
1918 estimator = GraphicalLassoCV (cv = 5 , assume_centered = True )
2019 print ("\n Searching for best Lasso estimator...\n " )
2120 try :
2221 estimator .fit (time_series )
22+ return estimator
2323 except BaseException :
2424 ix = 0
2525 print ("\n Model did not converge on first attempt. "
@@ -32,6 +32,7 @@ def get_optimal_cov_estimator(time_series):
3232 assume_centered = True )
3333 try :
3434 estimator .fit (time_series )
35+ return estimator
3536 except BaseException :
3637 ix += 1
3738 continue
@@ -42,7 +43,6 @@ def get_optimal_cov_estimator(time_series):
4243 "Unstable Lasso estimation. Applying shrinkage to empirical "
4344 "covariance..."
4445 )
45- estimator = None
4646 from sklearn .covariance import (
4747 GraphicalLasso ,
4848 empirical_covariance ,
@@ -60,15 +60,13 @@ def get_optimal_cov_estimator(time_series):
6060 assume_centered = True )
6161 try :
6262 estimator_shrunk .fit (shrunk_cov )
63+ return estimator_shrunk
6364 except BaseException :
6465 continue
6566 except BaseException :
66- return estimator
67-
68- if estimator is None and estimator_shrunk is not None :
69- estimator = estimator_shrunk
70-
71- return estimator
67+ return None
68+ else :
69+ return estimator
7270
7371
7472def get_conn_matrix (
@@ -278,15 +276,15 @@ def fallback_covariance(time_series):
278276 " valid estimator using the -mod flag." )
279277
280278 # Try with the best-fitting Lasso estimator
281- if estimator is not None :
279+ if estimator :
282280 conn_measure = ConnectivityMeasure (cov_estimator = estimator ,
283281 kind = kind )
284282 try :
285283 conn_matrix = conn_measure .fit_transform ([time_series ])[0 ]
286284 except (np .linalg .linalg .LinAlgError , FloatingPointError ):
287- fallback_covariance (time_series )
285+ conn_matrix = fallback_covariance (time_series )
288286 else :
289- fallback_covariance (time_series )
287+ conn_matrix = fallback_covariance (time_series )
290288 else :
291289 if conn_model == "QuicGraphicalLasso" :
292290 try :
0 commit comments