Commit 1fc88dd

Merge pull request #66 from rmarkello/fix/freesurfer
[REF] Updates FreeSurfer-related spin test functionality
2 parents: 368cff1 + 510444d

2 files changed: +74 additions, -47 deletions

netneurotools/freesurfer.py

Lines changed: 67 additions & 47 deletions
@@ -139,7 +139,7 @@ def find_parcel_centroids(*, lhannot, rhannot, version='fsaverage',
         defined in `lhannot` and `rhannot`
     hemiid : (N,) numpy.ndarray
         Array denoting hemisphere designation of coordinates in `centroids`,
-        where `hemiid=0` denotes the right and `hemiid=1` the left hemisphere
+        where `hemiid=0` denotes the left and `hemiid=1` the right hemisphere
     """

     if drop is None:
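
For reference, a sketch of how the corrected convention reads in use. This is illustrative only, not code from the commit, and the annotation paths are placeholders:

    from netneurotools import freesurfer

    # hypothetical .annot paths; hemiid == 0 now means left hemisphere
    centroids, hemiid = freesurfer.find_parcel_centroids(
        lhannot='lh.aparc.annot', rhannot='rh.aparc.annot'
    )
    lh_centroids = centroids[hemiid == 0]   # left-hemisphere parcel centroids
    rh_centroids = centroids[hemiid == 1]   # right-hemisphere parcel centroids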
@@ -203,13 +203,14 @@ def parcels_to_vertices(data, *, lhannot, rhannot, drop=None):
     ]
     drop = _decode_list(drop)

-    start = end = 0
-    projected = []
+    data = np.vstack(data)

     # check this so we're not unduly surprised by anything...
-    expected = 0
+    n_vert = expected = 0
     for a in [lhannot, rhannot]:
-        names = _decode_list(read_annot(a)[-1])
+        vn, _, names = read_annot(a)
+        n_vert += len(vn)
+        names = _decode_list(names)
         expected += len(names) - len(set(drop) & set(names))
     if expected != len(data):
         raise ValueError('Number of parcels in provided annotation files '
@@ -218,6 +219,8 @@ def parcels_to_vertices(data, *, lhannot, rhannot, drop=None):
                          '    RECEIVED: {} parcels'
                          .format(expected, len(data)))

+    projected = np.zeros((n_vert, data.shape[-1]), dtype=data.dtype)
+    start = end = n_vert = 0
     for annot in [lhannot, rhannot]:
         # read files and update end index for `data`
         labels, ctab, names = read_annot(annot)
@@ -228,13 +231,14 @@ def parcels_to_vertices(data, *, lhannot, rhannot, drop=None):
         # get indices of unknown and corpuscallosum and insert NaN values
         inds = sorted([names.index(f) for f in todrop])
         inds = [f - n for n, f in enumerate(inds)]
-        currdata = np.insert(data[start:end], inds, np.nan)
+        currdata = np.insert(data[start:end], inds, np.nan, axis=0)

         # project to vertices and store
-        projected.append(currdata[labels])
+        projected[n_vert:n_vert + len(labels), :] = currdata[labels]
         start = end
+        n_vert += len(labels)

-    return np.hstack(projected)
+    return np.squeeze(projected)


 def vertices_to_parcels(data, *, lhannot, rhannot, drop=None):
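
The preallocated `projected` array is what lets `parcels_to_vertices` accept two-dimensional parcel-by-feature input rather than only vectors. A hedged sketch of the new behavior, with placeholder paths and sizes:

    import numpy as np
    from netneurotools import freesurfer

    lh, rh = 'lh.aparc.annot', 'rh.aparc.annot'   # placeholder .annot files
    data = np.random.rand(68, 2)                  # hypothetical: 68 parcels x 2 features

    projected = freesurfer.parcels_to_vertices(data, lhannot=lh, rhannot=rh)
    # projected has shape (n_vertices, 2); 1-D input still comes back 1-D
    # because of the final np.squeeze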
@@ -270,18 +274,23 @@ def vertices_to_parcels(data, *, lhannot, rhannot, drop=None):
     ]
     drop = _decode_list(drop)

-    start = end = 0
-    reduced = []
+    data = np.vstack(data)

-    # check this so we're not unduly surprised by anything...
-    expected = sum([len(read_annot(a)[0]) for a in [lhannot, rhannot]])
+    n_parc = expected = 0
+    for a in [lhannot, rhannot]:
+        vn, _, names = read_annot(a)
+        expected += len(vn)
+        names = _decode_list(names)
+        n_parc += len(names) - len(set(drop) & set(names))
     if expected != len(data):
         raise ValueError('Number of vertices in provided annotation files '
                          'differs from size of vertex-level data array.\n'
                          '    EXPECTED: {} vertices\n'
                          '    RECEIVED: {} vertices'
                          .format(expected, len(data)))

+    reduced = np.zeros((n_parc, data.shape[-1]), dtype=data.dtype)
+    start = end = n_parc = 0
     for annot in [lhannot, rhannot]:
         # read files and update end index for `data`
         labels, ctab, names = read_annot(annot)
@@ -290,33 +299,36 @@ def vertices_to_parcels(data, *, lhannot, rhannot, drop=None):
         indices = np.unique(labels)
         end += len(labels)

-        # get average of vertex-level data within parcels
-        # set all NaN values to 0 before calling `_stats` because we are
-        # returning sums, so the 0 values won't impact the sums (if we left
-        # the NaNs then all parcels with even one NaN entry would be NaN)
-        currdata = np.squeeze(data[start:end])
-        isna = np.isnan(currdata)
-        counts, sums = _stats(np.nan_to_num(currdata), labels, indices)
-
-        # however, we do need to account for the NaN values in the counts
-        # so that our means are similar to what we'd get from e.g., np.nanmean
-        # here, our "sums" are the counts of NaN values in our parcels
-        _, nacounts = _stats(isna, labels, indices)
-        counts = (np.asanyarray(counts, dtype=float)
-                  - np.asanyarray(nacounts, dtype=float))
-
-        with np.errstate(divide='ignore', invalid='ignore'):
-            currdata = sums / counts
-
-        # get indices of unknown and corpuscallosum and delete from parcels
-        inds = sorted([names.index(f) for f in set(drop) & set(names)])
-        currdata = np.delete(currdata, inds)
-
-        # store parcellated data
-        reduced.append(currdata)
+        for idx in range(data.shape[-1]):
+            # get average of vertex-level data within parcels
+            # set all NaN values to 0 before calling `_stats` because we are
+            # returning sums, so the 0 values won't impact the sums (if we left
+            # the NaNs then all parcels with even one NaN entry would be NaN)
+            currdata = np.squeeze(data[start:end, idx])
+            isna = np.isnan(currdata)
+            counts, sums = _stats(np.nan_to_num(currdata), labels, indices)
+
+            # however, we do need to account for the NaN values in the counts
+            # so that our means are similar to what we'd get from e.g.,
+            # np.nanmean here, our "sums" are the counts of NaN values in our
+            # parcels
+            _, nacounts = _stats(isna, labels, indices)
+            counts = (np.asanyarray(counts, dtype=float)
+                      - np.asanyarray(nacounts, dtype=float))
+
+            with np.errstate(divide='ignore', invalid='ignore'):
+                currdata = sums / counts
+
+            # get indices of unknown and corpuscallosum and delete from parcels
+            inds = sorted([names.index(f) for f in set(drop) & set(names)])
+            currdata = np.delete(currdata, inds)
+
+            # store parcellated data
+            reduced[n_parc:n_parc + len(names) - len(inds), idx] = currdata
         start = end
+        n_parc += len(names) - len(inds)

-    return np.hstack(reduced)
+    return np.squeeze(reduced)


 def _get_fsaverage_coords(version='fsaverage', surface='sphere'):
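
The sums-and-counts bookkeeping above is what makes the parcel means behave like np.nanmean. A self-contained sketch of the same arithmetic, with np.bincount standing in for scipy's private `_stats` helper (toy labels and values, not package data):

    import numpy as np

    labels = np.array([1, 1, 2, 2, 2])                  # toy vertex -> parcel labels
    verts = np.array([1.0, np.nan, 2.0, 4.0, np.nan])   # toy vertex values

    # NaNs contribute 0 to the sums...
    sums = np.bincount(labels, weights=np.nan_to_num(verts))
    # ...and are subtracted from the counts before dividing
    counts = np.bincount(labels) - np.bincount(labels, weights=np.isnan(verts))

    with np.errstate(divide='ignore', invalid='ignore'):
        means = sums / counts   # [nan, 1.0, 3.0], matching np.nanmean per parcel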
@@ -351,7 +363,8 @@ def _get_fsaverage_coords(version='fsaverage', surface='sphere'):
 

 def spin_data(data, *, lhannot, rhannot, version='fsaverage', n_rotate=1000,
-              drop=None, seed=None, verbose=False, return_cost=False):
+              spins=None, drop=None, seed=None, verbose=False,
+              return_cost=False):
     """
     Projects parcellated `data` to surface, rotates, and re-parcellates

@@ -417,8 +430,18 @@ def spin_data(data, *, lhannot, rhannot, version='fsaverage', n_rotate=1000,
                          '    FSAVERAGE: {} vertices'
                          .format(len(vertices), len(coords)))

-    spins, cost = gen_spinsamples(coords, hemiid, n_rotate=n_rotate,
-                                  seed=seed, verbose=verbose)
+    if spins is None:
+        spins, cost = gen_spinsamples(coords, hemiid, n_rotate=n_rotate,
+                                      seed=seed, verbose=verbose)
+    else:
+        spins = np.asarray(spins, dtype='int32')
+        if spins.shape[-1] != n_rotate:
+            raise ValueError('Provided `spins` does not match number of '
+                             'requested rotations with `n_rotate`. Please '
+                             'check inputs and try again.')
+        if return_cost:
+            raise ValueError('Cannot `return_cost` when `spins` are provided.')
+
     spun = np.zeros((len(data), n_rotate))
     for n in range(n_rotate):
         spun[:, n] = vertices_to_parcels(vertices[spins[:, n]],
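
The new `spins` argument lets the expensive rotations be generated once and reused across calls. A hedged usage sketch with toy coordinates; real usage would pass fsaverage sphere coordinates and actual .annot files:

    import numpy as np
    from netneurotools import freesurfer, stats

    # toy stand-ins for fsaverage sphere coordinates and hemisphere labels
    coords = np.random.default_rng(0).standard_normal((100, 3))
    hemiid = np.repeat([0, 1], 50)

    # generate the rotations once...
    spins, cost = stats.gen_spinsamples(coords, hemiid, n_rotate=10, seed=1234)

    # ...then reuse them; `return_cost` must stay False and `n_rotate` must
    # match spins.shape[-1] (the annot paths below are placeholders)
    # rotated = freesurfer.spin_data(data, lhannot='lh.aparc.annot',
    #                                rhannot='rh.aparc.annot',
    #                                spins=spins, n_rotate=10)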
@@ -432,7 +455,7 @@ def spin_data(data, *, lhannot, rhannot, version='fsaverage', n_rotate=1000,
 

 def spin_parcels(*, lhannot, rhannot, version='fsaverage', n_rotate=1000,
-                 drop=None, seed=None, verbose=False, return_cost=False):
+                 drop=None, seed=None, return_cost=False, **kwargs):
     """
     Rotates parcels in `{lh,rh}annot` and re-assigns based on maximum overlap

@@ -456,13 +479,11 @@ def spin_parcels(*, lhannot, rhannot, version='fsaverage', n_rotate=1000,
         will be inserted in place of these regions in the returned data. If
         not specified, 'unknown' and 'corpuscallosum' are assumed to not be
         present. Default: None
-    seed : {int, np.random.RandomState instance, None}, optional
-        Seed for random number generation. Default: None
-    verbose : bool, optional
-        Whether to print occasional status messages. Default: False
     return_cost : bool, optional
         Whether to return cost array (specified as Euclidean distance) for each
         coordinate for each rotation. Default: True
+    kwargs : key-value, optional
+        Key-value pairs passed to :func:`netneurotools.stats.gen_spinsamples`

     Returns
     -------
@@ -519,8 +540,7 @@ def overlap(vals):
                          .format(len(vertices), len(coords)))

     # spin and assign regions based on max overlap
-    spins, cost = gen_spinsamples(coords, hemiid, n_rotate=n_rotate,
-                                  seed=seed, verbose=verbose)
+    spins, cost = gen_spinsamples(coords, hemiid, n_rotate=n_rotate, **kwargs)
     regions = np.zeros((len(labels[mask]), n_rotate), dtype='int32')
     for n in range(n_rotate):
         regions[:, n] = labeled_comprehension(vertices[spins[:, n]], vertices,
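
With `**kwargs` now forwarded to `gen_spinsamples`, options such as `verbose` pass straight through `spin_parcels` instead of being explicit parameters. An illustrative call, with placeholder annotation paths:

    from netneurotools import freesurfer

    # `verbose` is forwarded to netneurotools.stats.gen_spinsamples
    # along with any other keyword arguments
    spun = freesurfer.spin_parcels(lhannot='lh.aparc.annot',
                                   rhannot='rh.aparc.annot',
                                   n_rotate=100, seed=1234, verbose=True)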

netneurotools/tests/test_freesurfer.py

Lines changed: 7 additions & 0 deletions
@@ -39,6 +39,13 @@ def test_project_reduce_vertices(cammoun_surf, scale, parcels):
     reduced = freesurfer.vertices_to_parcels(projected, rhannot=rh, lhannot=lh)
     assert np.allclose(data, reduced)

+    # can we do this with multi-dimensional data, too?
+    data = np.random.rand(parcels, 2)
+    projected = freesurfer.parcels_to_vertices(data, rhannot=rh, lhannot=lh)
+    assert projected.shape == (327684, 2)
+    reduced = freesurfer.vertices_to_parcels(projected, rhannot=rh, lhannot=lh)
+    assert np.allclose(data, reduced)
+
     # number of parcels != annotation spec
     with pytest.raises(ValueError):
         freesurfer.parcels_to_vertices(np.random.rand(parcels + 1),
