Skip to content

Commit 9f3fb74

Browse files
author
dPys
committed
[FIX] strict provenance of label intensities in parcellation image
1 parent 5b62ac0 commit 9f3fb74

File tree

1 file changed

+9
-11
lines changed

1 file changed

+9
-11
lines changed

pynets/fmri/estimation.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515
def get_optimal_cov_estimator(time_series):
1616
from sklearn.covariance import GraphicalLassoCV
1717

18-
estimator_shrunk = None
1918
estimator = GraphicalLassoCV(cv=5, assume_centered=True)
2019
print("\nSearching for best Lasso estimator...\n")
2120
try:
2221
estimator.fit(time_series)
22+
return estimator
2323
except BaseException:
2424
ix = 0
2525
print("\nModel did not converge on first attempt. "
@@ -32,6 +32,7 @@ def get_optimal_cov_estimator(time_series):
3232
assume_centered=True)
3333
try:
3434
estimator.fit(time_series)
35+
return estimator
3536
except BaseException:
3637
ix += 1
3738
continue
@@ -42,7 +43,6 @@ def get_optimal_cov_estimator(time_series):
4243
"Unstable Lasso estimation. Applying shrinkage to empirical "
4344
"covariance..."
4445
)
45-
estimator = None
4646
from sklearn.covariance import (
4747
GraphicalLasso,
4848
empirical_covariance,
@@ -60,15 +60,13 @@ def get_optimal_cov_estimator(time_series):
6060
assume_centered=True)
6161
try:
6262
estimator_shrunk.fit(shrunk_cov)
63+
return estimator_shrunk
6364
except BaseException:
6465
continue
6566
except BaseException:
66-
return estimator
67-
68-
if estimator is None and estimator_shrunk is not None:
69-
estimator = estimator_shrunk
70-
71-
return estimator
67+
return None
68+
else:
69+
return estimator
7270

7371

7472
def get_conn_matrix(
@@ -278,15 +276,15 @@ def fallback_covariance(time_series):
278276
" valid estimator using the -mod flag.")
279277

280278
# Try with the best-fitting Lasso estimator
281-
if estimator is not None:
279+
if estimator:
282280
conn_measure = ConnectivityMeasure(cov_estimator=estimator,
283281
kind=kind)
284282
try:
285283
conn_matrix = conn_measure.fit_transform([time_series])[0]
286284
except (np.linalg.linalg.LinAlgError, FloatingPointError):
287-
fallback_covariance(time_series)
285+
conn_matrix = fallback_covariance(time_series)
288286
else:
289-
fallback_covariance(time_series)
287+
conn_matrix = fallback_covariance(time_series)
290288
else:
291289
if conn_model == "QuicGraphicalLasso":
292290
try:

0 commit comments

Comments
 (0)