Skip to content

Commit 5b62ac0

Browse files
author
dPys
committed
[FIX] strict provenance of label intensities in parcellation image
1 parent 75c07bc commit 5b62ac0

File tree

5 files changed

+114
-73
lines changed

5 files changed

+114
-73
lines changed

pynets/core/interfaces.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -179,9 +179,10 @@ def _run_interface(self, runtime):
179179
"installed?")
180180
try:
181181
if self.inputs.clustering is False:
182-
[uatlas,
183-
_] = nodemaker.enforce_hem_distinct_consecutive_labels(
184-
uatlas)
182+
[uatlas, _] = \
183+
nodemaker.enforce_hem_distinct_consecutive_labels(
184+
uatlas)
185+
185186
# Fetch user-specified atlas coords
186187
[coords, _, par_max, label_intensities] = \
187188
nodemaker.get_names_and_coords_of_parcels(uatlas)

pynets/core/nodemaker.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def get_sphere(coords, r, vox_dims, dims):
5858
return neighbors
5959

6060

61-
def create_parcel_atlas(parcel_list):
61+
def create_parcel_atlas(parcel_list, label_intensities=None):
6262
"""
6363
Create a 3D Nifti1Image atlas parcellation of consecutive integer
6464
intensities from an input list of ROI's.
@@ -92,7 +92,11 @@ def create_parcel_atlas(parcel_list):
9292
parcel_list[0].shape,
9393
dtype=bool))] + parcel_list
9494
concatted_parcels = concat_imgs(parcel_list_exp, dtype=np.float32)
95-
parcel_list_exp = np.array(range(len(parcel_list_exp))).astype("float32")
95+
if label_intensities is not None:
96+
parcel_list_exp = np.array([0] + label_intensities).astype("float32")
97+
else:
98+
parcel_list_exp = np.array(range(len(parcel_list_exp))
99+
).astype("float32")
96100
parcel_sum = np.sum(
97101
parcel_list_exp *
98102
np.asarray(
@@ -693,11 +697,18 @@ def parcel_masker(
693697
" brain mask/roi..."
694698
)
695699

700+
if any(isinstance(sub, tuple) for sub in labels_adj):
701+
label_intensities = [i[1] for i in labels_adj]
702+
elif any(isinstance(sub, dict) for sub in labels_adj):
703+
label_intensities = None
704+
else:
705+
label_intensities = labels_adj
706+
696707
# Create a resampled 3D atlas that can be viewed alongside mask img for QA
697708
resampled_parcels_nii_path = f"{dir_path}/{ID}_parcels_resampled2roimask" \
698709
f"_{op.basename(roi).split('.')[0]}.nii.gz"
699710
resampled_parcels_map_nifti = resample_img(
700-
nodemaker.create_parcel_atlas(parcel_list_adj)[0],
711+
nodemaker.create_parcel_atlas(parcel_list_adj, label_intensities)[0],
701712
target_affine=mask_aff,
702713
target_shape=mask_data.shape,
703714
interpolation="nearest",
@@ -1479,8 +1490,16 @@ def node_gen_masking(
14791490
[coords, labels, parcel_list_masked] = nodemaker.parcel_masker(
14801491
roi, coords, parcel_list, labels, dir_path, ID, perc_overlap
14811492
)
1493+
1494+
if any(isinstance(sub, tuple) for sub in labels):
1495+
label_intensities = [i[1] for i in labels]
1496+
elif any(isinstance(sub, dict) for sub in labels):
1497+
label_intensities = None
1498+
else:
1499+
label_intensities = labels
1500+
14821501
[net_parcels_map_nifti, _] = nodemaker.create_parcel_atlas(
1483-
parcel_list_masked)
1502+
parcel_list_masked, label_intensities)
14841503

14851504
assert (
14861505
len(coords)
@@ -1549,7 +1568,15 @@ def node_gen(coords, parcel_list, labels, dir_path, ID, parc, atlas, uatlas):
15491568
parcel_list = [index_img(parcel_list_img, i) for i in
15501569
range(parcel_list_img.shape[-1])]
15511570

1552-
[net_parcels_map_nifti, _] = nodemaker.create_parcel_atlas(parcel_list)
1571+
if any(isinstance(sub, tuple) for sub in labels):
1572+
label_intensities = [i[1] for i in labels]
1573+
elif any(isinstance(sub, dict) for sub in labels):
1574+
label_intensities = None
1575+
else:
1576+
label_intensities = labels
1577+
1578+
[net_parcels_map_nifti, _] = nodemaker.create_parcel_atlas(parcel_list,
1579+
label_intensities)
15531580

15541581
coords = list(tuple(x) for x in coords)
15551582

pynets/dmri/track.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,7 @@ def track_ensemble(
437437
ix = 0
438438
while float(stream_counter) < float(target_samples) and float(ix) < \
439439
len(all_combs):
440-
with Parallel(n_jobs=nthreads, backend='loky', max_nbytes='8000M',
440+
with Parallel(n_jobs=nthreads, backend='loky',
441441
mmap_mode='r+', temp_folder=cache_dir,
442442
verbose=10) as parallel:
443443
out_streams = parallel(
@@ -641,6 +641,8 @@ def run_tracking(step_curv_combinations, atlas_data_wm_gm_int, recon_path,
641641

642642
del atlas_data
643643

644+
parcel_vec = list(np.ones(len(parcels)).astype("bool"))
645+
644646
with h5py.File(recon_path_tmp_path, 'r+') as hf:
645647
mod_fit = hf['reconstruction'][:]
646648
hf.close()
@@ -745,7 +747,7 @@ def run_tracking(step_curv_combinations, atlas_data_wm_gm_int, recon_path,
745747
roi_proximal_streamlines,
746748
affine=np.eye(4),
747749
rois=parcels,
748-
include=list(np.ones(len(parcels)).astype("bool")),
750+
include=parcel_vec,
749751
mode="%s" % ("any" if waymask is not None else
750752
"either_end"),
751753
tol=roi_neighborhood_tol,

pynets/fmri/estimation.py

Lines changed: 61 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ def get_optimal_cov_estimator(time_series):
1616
from sklearn.covariance import GraphicalLassoCV
1717

1818
estimator_shrunk = None
19-
estimator = GraphicalLassoCV(cv=5)
20-
print("\nFinding best estimator...\n")
19+
estimator = GraphicalLassoCV(cv=5, assume_centered=True)
20+
print("\nSearching for best Lasso estimator...\n")
2121
try:
2222
estimator.fit(time_series)
2323
except BaseException:
@@ -27,8 +27,9 @@ def get_optimal_cov_estimator(time_series):
2727
while not hasattr(estimator, 'covariance_') and \
2828
not hasattr(estimator, 'precision_') and ix < 3:
2929
for tol in [0.1, 0.01, 0.001, 0.0001]:
30-
print(tol)
31-
estimator = GraphicalLassoCV(cv=5, max_iter=200, tol=tol)
30+
print(f"Auto-tuning Tolerance={tol}")
31+
estimator = GraphicalLassoCV(cv=5, max_iter=200, tol=tol,
32+
assume_centered=True)
3233
try:
3334
estimator.fit(time_series)
3435
except BaseException:
@@ -38,49 +39,33 @@ def get_optimal_cov_estimator(time_series):
3839
if not hasattr(estimator, 'covariance_') and not hasattr(estimator,
3940
'precision_'):
4041
print(
41-
"Unstable Lasso estimation. Applying shrinkage..."
42+
"Unstable Lasso estimation. Applying shrinkage to empirical "
43+
"covariance..."
44+
)
45+
estimator = None
46+
from sklearn.covariance import (
47+
GraphicalLasso,
48+
empirical_covariance,
49+
shrunk_covariance,
4250
)
4351
try:
44-
estimator = None
45-
from sklearn.covariance import (
46-
GraphicalLasso,
47-
empirical_covariance,
48-
shrunk_covariance,
49-
)
50-
51-
emp_cov = empirical_covariance(time_series)
52-
# Iterate across different levels of alpha
52+
emp_cov = empirical_covariance(time_series, assume_centered=True)
5353
for i in np.arange(0.8, 0.99, 0.01):
54+
print(f"Shrinkage={i}:")
5455
shrunk_cov = shrunk_covariance(emp_cov, shrinkage=i)
5556
alphaRange = 10.0 ** np.arange(-8, 0)
5657
for alpha in alphaRange:
58+
print(f"Auto-tuning alpha={alpha}...")
59+
estimator_shrunk = GraphicalLasso(alpha,
60+
assume_centered=True)
5761
try:
58-
estimator_shrunk = GraphicalLasso(alpha)
5962
estimator_shrunk.fit(shrunk_cov)
60-
print(
61-
f"Retrying covariance matrix estimate with"
62-
f" alpha={alpha}"
63-
)
64-
if estimator_shrunk is None:
65-
pass
66-
else:
67-
break
6863
except BaseException:
69-
print(
70-
f"Covariance estimation failed with shrinkage"
71-
f" at alpha={alpha}"
72-
)
7364
continue
74-
except ValueError:
75-
estimator = None
76-
print(
77-
"Covariance estimation failed. Check time-series data "
78-
"for errors."
79-
)
80-
if estimator is None and estimator_shrunk is None:
81-
raise RuntimeError("\nERROR: Covariance estimation failed.")
65+
except BaseException:
66+
return estimator
8267

83-
if estimator is None:
68+
if estimator is None and estimator_shrunk is not None:
8469
estimator = estimator_shrunk
8570

8671
return estimator
@@ -229,7 +214,6 @@ def get_conn_matrix(
229214
for Gaussian and related Graphical Models. doi:10.5281/zenodo.830033
230215
231216
"""
232-
import sys
233217
from pynets.fmri.estimation import get_optimal_cov_estimator
234218
from nilearn.connectome import ConnectivityMeasure
235219

@@ -241,6 +225,40 @@ def get_conn_matrix(
241225
conn_matrix = None
242226
estimator = get_optimal_cov_estimator(time_series)
243227

228+
def fallback_covariance(time_series):
229+
from sklearn.ensemble import IsolationForest
230+
from sklearn import covariance
231+
232+
# Remove gross outliers
233+
model = IsolationForest(contamination=0.02)
234+
model.fit(time_series)
235+
outlier_mask = model.predict(time_series)
236+
outlier_mask[outlier_mask == -1] = 0
237+
time_series = time_series[outlier_mask.astype('bool')]
238+
239+
# Fall back to LedoitWolf
240+
print('Matrix estimation failed with Lasso and shrinkage due to '
241+
'ill conditions. Removing potential anomalies from the '
242+
'time-series using IsolationForest...')
243+
try:
244+
print("Trying Ledoit-Wolf Estimator...")
245+
conn_measure = ConnectivityMeasure(
246+
cov_estimator=covariance.LedoitWolf(store_precision=True,
247+
assume_centered=True),
248+
kind=kind)
249+
conn_matrix = conn_measure.fit_transform([time_series])[0]
250+
except (np.linalg.linalg.LinAlgError, FloatingPointError):
251+
print("Trying Oracle Approximating Shrinkage Estimator...")
252+
conn_measure = ConnectivityMeasure(
253+
cov_estimator=covariance.OAS(assume_centered=True),
254+
kind=kind)
255+
try:
256+
conn_matrix = conn_measure.fit_transform([time_series])[0]
257+
except (np.linalg.linalg.LinAlgError, FloatingPointError):
258+
raise ValueError('All covariance estimators failed to '
259+
'converge...')
260+
return conn_matrix
261+
244262
if conn_model in nilearn_kinds:
245263
if conn_model == "corr" or conn_model == "cor" or conn_model == "correlation":
246264
print("\nComputing correlation matrix...\n")
@@ -259,32 +277,16 @@ def get_conn_matrix(
259277
"\nERROR! No connectivity model specified at runtime. Select a"
260278
" valid estimator using the -mod flag.")
261279

262-
try:
263-
# Try with the best-fitting Lasso estimator
280+
# Try with the best-fitting Lasso estimator
281+
if estimator is not None:
264282
conn_measure = ConnectivityMeasure(cov_estimator=estimator,
265283
kind=kind)
266-
conn_matrix = conn_measure.fit_transform([time_series])[0]
267-
except BaseException:
268-
from sklearn.ensemble import IsolationForest
269-
270-
# Remove gross outliers
271-
model = IsolationForest(contamination=0.02)
272-
model.fit(time_series)
273-
outlier_mask = model.predict(time_series)
274-
outlier_mask[outlier_mask == -1] = 0
275-
time_series = time_series[outlier_mask.astype('bool')]
276-
277-
# Fall back to LedoitWolf
278-
print('Matrix estimation failed with Lasso and shrinkage due to '
279-
'ill conditions. Removing potential anomalies from the '
280-
'time-series using IsolationForest and falling back to '
281-
'LedoitWolf...')
282284
try:
283-
conn_measure = ConnectivityMeasure(kind=kind)
284285
conn_matrix = conn_measure.fit_transform([time_series])[0]
285-
except RuntimeError:
286-
print('Matrix estimation failed.')
287-
sys.exit(1)
286+
except (np.linalg.linalg.LinAlgError, FloatingPointError):
287+
fallback_covariance(time_series)
288+
else:
289+
fallback_covariance(time_series)
288290
else:
289291
if conn_model == "QuicGraphicalLasso":
290292
try:

tests/test_track.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import indexed_gzip
1616
import numpy as np
1717
import logging
18+
import h5py
1819

1920
logger = logging.getLogger(__name__)
2021
logger.setLevel(50)
@@ -148,9 +149,13 @@ def test_track_ensemble(directget, target_samples):
148149
dwi_data = dwi_img.get_fdata()
149150

150151
temp_dir = tempfile.TemporaryDirectory()
151-
recon_path = temp_dir.name + '/model_file.npy'
152+
recon_path = temp_dir.name + '/model_file.hdf5'
152153
model, _ = track.reconstruction(conn_model, gtab, dwi_data, wm_in_dwi)
153-
np.save(recon_path, model)
154+
155+
with h5py.File(recon_path, 'w') as hf:
156+
hf.create_dataset("reconstruction",
157+
data=model.astype('float32'))
158+
hf.close()
154159

155160
streamlines = track.track_ensemble(target_samples, atlas_data_wm_gm_int,
156161
labels_im_file,
@@ -203,8 +208,12 @@ def test_track_ensemble_particle():
203208

204209
model, _ = track.reconstruction(conn_model, gtab, dwi_data, wm_in_dwi)
205210
temp_dir = tempfile.TemporaryDirectory()
206-
recon_path = temp_dir.name + '/model_file.npy'
207-
np.save(recon_path, model)
211+
recon_path = temp_dir.name + '/model_file.hdf5'
212+
213+
with h5py.File(recon_path, 'w') as hf:
214+
hf.create_dataset("reconstruction",
215+
data=model.astype('float32'))
216+
hf.close()
208217

209218
streamlines = track.track_ensemble(target_samples, atlas_data_wm_gm_int,
210219
labels_im_file, recon_path, sphere,

0 commit comments

Comments
 (0)