Skip to content

Commit 9f7cddf

Browse files
authored
Merge pull request #77 from rmarkello/fix/spins
[FIX] Minor change to spin sample generation
2 parents e93b726 + cfed352 commit 9f7cddf

File tree

2 files changed

+84
-48
lines changed

2 files changed

+84
-48
lines changed

netneurotools/stats.py

Lines changed: 73 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -511,7 +511,8 @@ def _gen_rotation(seed=None):
511511

512512

513513
def gen_spinsamples(coords, hemiid, n_rotate=1000, check_duplicates=True,
514-
exact=False, seed=None, verbose=False, return_cost=True):
514+
method='original', exact=False, seed=None, verbose=False,
515+
return_cost=False):
515516
"""
516517
Returns a resampling array for `coords` obtained from rotations / spins
517518
@@ -524,8 +525,8 @@ def gen_spinsamples(coords, hemiid, n_rotate=1000, check_duplicates=True,
524525
Due to irregular sampling of `coords` and the randomness of the rotations
525526
it is possible that some "rotations" may resample with replacement (i.e.,
526527
will not be a true permutation). The likelihood of this can be reduced by
527-
either increasing the sampling density of `coords` or setting the ``exact``
528-
parameter to True (though see Notes for more information on the latter).
528+
either increasing the sampling density of `coords` or changing the
529+
``method`` parameter (see Notes for more information on the latter).
529530
530531
Parameters
531532
----------
@@ -543,10 +544,12 @@ def gen_spinsamples(coords, hemiid, n_rotate=1000, check_duplicates=True,
543544
Whether to check for and attempt to avoid duplicate resamplings. A
544545
warning will be raised if duplicates cannot be avoided. Setting to
545546
True may increase the runtime of this function! Default: True
546-
exact : bool, optional
547-
Whether each node/parcel/region should be uniquely re-assigned in every
548-
rotation. Setting to True will drastically increase the memory demands
549-
and runtime of this function! Default: False
547+
method : {'original', 'vasa', 'hungarian'}, optional
548+
Method by which to match non- and rotated coordinates. Specifying
549+
'original' will use the method described in [ST1]_. Specifying 'vasa'
550+
will use the method described in [ST4]_. Specifying 'hungarian' will use
551+
the Hungarian algorithm to minimize the global cost of reassignment
552+
(will dramatically increase runtime). Default: 'original'
550553
seed : {int, np.random.RandomState instance, None}, optional
551554
Seed for random number generation. Default: None
552555
verbose : bool, optional
@@ -574,29 +577,36 @@ def gen_spinsamples(coords, hemiid, n_rotate=1000, check_duplicates=True,
574577
>>> from netneurotools import stats as nnstats
575578
>>> coords = [[0, 0, 1], [1, 0, 0], [0, 0, 1], [1, 0, 0]]
576579
>>> hemi = [0, 0, 1, 1]
577-
>>> nnstats.gen_spinsamples(coords, hemi, n_rotate=1, seed=1)[0]
580+
>>> nnstats.gen_spinsamples(coords, hemi, n_rotate=1, seed=1,
581+
... method='original', check_duplicates=False)
578582
array([[0],
579583
[0],
580584
[2],
581585
[3]], dtype=int32)
582586
583587
While this is reasonable in most circumstances, if you feel incredibly
584588
strongly about having a perfect "permutation" (i.e., all indices appear
585-
once and exactly once in the resampling), you can set the ``exact``
586-
parameter to True:
589+
once and exactly once in the resampling), you can set the ``method``
590+
parameter to either 'vasa' or 'hungarian':
587591
588592
>>> nnstats.gen_spinsamples(coords, hemi, n_rotate=1, seed=1,
589-
... exact=True)[0]
593+
... method='vasa', check_duplicates=False)
590594
array([[1],
591595
[0],
592596
[2],
593597
[3]], dtype=int32)
598+
>>> nnstats.gen_spinsamples(coords, hemi, n_rotate=1, seed=1,
599+
... method='hungarian', check_duplicates=False)
600+
array([[0],
601+
[1],
602+
[2],
603+
[3]], dtype=int32)
594604
595-
Note that setting this parameter will *dramatically* increase the runtime
596-
of the function. Refer to [ST1]_ for information on why the default (i.e.,
597-
``exact`` set to False) suffices in most cases.
605+
Note that setting this parameter may increase the runtime of the function
606+
(especially for `method='hungarian'`). Refer to [ST1]_ for information on
607+
why the default (i.e., ``method`` set to 'original') suffices in most cases.
598608
599-
For the original MATLAB implementation of this function refer to [ST4]_.
609+
For the original MATLAB implementation of this function refer to [ST5]_.
600610
601611
References
602612
----------
@@ -613,9 +623,28 @@ def gen_spinsamples(coords, hemiid, n_rotate=1000, check_duplicates=True,
613623
A Spectral Clustering Framework for Individual and Group Parcellation of
614624
Cortical Surfaces in Lobes. Frontiers in Neuroscience, 12, 354.
615625
616-
.. [ST4] https://github.com/spin-test/spin-test
626+
.. [ST4] Váša, F., Seidlitz, J., Romero-Garcia, R., Whitaker, K. J.,
627+
Rosenthal, G., Vértes, P. E., ... & Jones, P. B. (2018). Adolescent
628+
tuning of association cortex in human structural brain networks.
629+
Cerebral Cortex, 28(1), 281-294.
630+
631+
.. [ST5] https://github.com/spin-test/spin-test
617632
"""
618633

634+
methods = ['original', 'vasa', 'hungarian']
635+
if method not in methods:
636+
raise ValueError('Provided method "{}" invalid. Must be one of {}.'
637+
.format(method, methods))
638+
639+
if exact:
640+
warnings.warn('The `exact` parameter will no longer be supported in '
641+
'an upcoming release. Please use the `method` parameter '
642+
'instead.', DeprecationWarning, stacklevel=3)
643+
if exact == 'vasa' and method == 'original':
644+
method = 'vasa'
645+
elif exact and method == 'original':
646+
method = 'hungarian'
647+
619648
seed = check_random_state(seed)
620649

621650
coords = np.asanyarray(coords)
@@ -667,36 +696,35 @@ def gen_spinsamples(coords, hemiid, n_rotate=1000, check_duplicates=True,
667696
# if we need an "exact" mapping (i.e., each node needs to be
668697
# assigned EXACTLY once) then we have to calculate the full
669698
# distance matrix which is a nightmare with respect to memory
670-
# for anything that isn't parcellated data. that is, don't do
671-
# this with vertex coordinates!
672-
if exact:
699+
# for anything that isn't parcellated data.
700+
# that is, don't do this with vertex coordinates!
701+
if method == 'vasa':
673702
dist = spatial.distance_matrix(coor, coor @ rot)
674-
# min of max a la Vasa et al., 2017
675-
if exact == 'vasa':
676-
col = np.zeros(len(coor), dtype='int32')
677-
for r in range(len(dist)):
678-
# find parcel whose closest neighbor is farthest
679-
# away overall; assign to that
680-
row = dist.min(axis=1).argmax()
681-
col[row] = dist[row].argmin()
682-
cost[inds[hinds][row], n] = dist[row, col[row]]
683-
# set these to -inf and inf so they can't be
684-
# assigned again
685-
dist[row] = -np.inf
686-
dist[:, col[row]] = np.inf
687-
# optimization of total cost using Hungarian algorithm.
688-
# this may result in certain parcels having higher cost
689-
# than with `exact='vasa'` but should always result in the
690-
# total cost being lower #tradeoffs
691-
else:
692-
row, col = optimize.linear_sum_assignment(dist)
693-
cost[hinds, n] = dist[row, col]
703+
# min of max a la Vasa et al., 2018
704+
col = np.zeros(len(coor), dtype='int32')
705+
for r in range(len(dist)):
706+
# find parcel whose closest neighbor is farthest away
707+
# overall; assign to that
708+
row = dist.min(axis=1).argmax()
709+
col[row] = dist[row].argmin()
710+
cost[inds[hinds][row], n] = dist[row, col[row]]
711+
# set to -inf and inf so they can't be assigned again
712+
dist[row] = -np.inf
713+
dist[:, col[row]] = np.inf
714+
# optimization of total cost using Hungarian algorithm. this
715+
# may result in certain parcels having higher cost than with
716+
# `method='vasa'` but should always result in the total cost
717+
# being lower #tradeoffs
718+
elif method == 'hungarian':
719+
dist = spatial.distance_matrix(coor, coor @ rot)
720+
row, col = optimize.linear_sum_assignment(dist)
721+
cost[hinds, n] = dist[row, col]
694722
# if nodes can be assigned multiple targets, we can simply use
695723
# the absolute minimum of the distances (no optimization
696724
# required) which is _much_ lighter on memory
697725
# huge thanks to https://stackoverflow.com/a/47779290 for this
698726
# memory-efficient method
699-
else:
727+
elif method == 'original':
700728
dist, col = spatial.cKDTree(coor @ rot).query(coor, 1)
701729
cost[hinds, n] = dist
702730

@@ -706,6 +734,7 @@ def gen_spinsamples(coords, hemiid, n_rotate=1000, check_duplicates=True,
706734
if check_duplicates:
707735
if np.any(np.all(resampled[:, None] == spinsamples[:, :n], 0)):
708736
duplicated = True
737+
# if our "spin" is identical to the input then that's no good
709738
elif np.all(resampled == inds):
710739
duplicated = True
711740

@@ -722,4 +751,7 @@ def gen_spinsamples(coords, hemiid, n_rotate=1000, check_duplicates=True,
722751
if verbose:
723752
print(' ' * len(msg) + '\b' * len(msg), end='', flush=True)
724753

725-
return spinsamples, cost
754+
if return_cost:
755+
return spinsamples, cost
756+
757+
return spinsamples

netneurotools/tests/test_stats.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -135,17 +135,21 @@ def test_gen_spinsamples():
135135
hemi = np.hstack([np.zeros(len(coords) // 2), np.ones(len(coords) // 2)])
136136

137137
# generate "normal" test spins
138-
spins, cost = stats.gen_spinsamples(coords, hemi, n_rotate=10, seed=1234)
138+
spins, cost = stats.gen_spinsamples(coords, hemi, n_rotate=10, seed=1234,
139+
return_cost=True)
139140
assert spins.shape == (len(coords), 10)
140141
assert cost.shape == (len(coords), 10)
141142

142143
# confirm that `exact` parameter functions as desired
143-
spin_exact, cost_exact = stats.gen_spinsamples(coords, hemi, n_rotate=10,
144-
exact=True, seed=1234)
145-
assert spin_exact.shape == (len(coords), 10)
146-
assert cost.shape == (len(coords), 10)
147-
for s in spin_exact.T:
148-
assert len(np.unique(s)) == len(s)
144+
for method in ['vasa', 'hungarian']:
145+
spin_exact, cost_exact = stats.gen_spinsamples(coords, hemi,
146+
n_rotate=10, seed=1234,
147+
method=method,
148+
return_cost=True)
149+
assert spin_exact.shape == (len(coords), 10)
150+
assert cost.shape == (len(coords), 10)
151+
for s in spin_exact.T:
152+
assert len(np.unique(s)) == len(s)
149153

150154
# confirm that check_duplicates will raise warnings
151155
# since spins aren't exact permutations we need to use 4C4 with repeats

0 commit comments

Comments
 (0)