Skip to content

Commit 6675571

Browse files
committed
[REF] Minor spin test refactoring
1 parent 856bfea commit 6675571

File tree

1 file changed

+71
-42
lines changed

1 file changed

+71
-42
lines changed

netneurotools/stats.py

Lines changed: 71 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -511,7 +511,8 @@ def _gen_rotation(seed=None):
511511

512512

513513
def gen_spinsamples(coords, hemiid, n_rotate=1000, check_duplicates=True,
514-
exact=False, seed=None, verbose=False, return_cost=True):
514+
method='original', exact=False, seed=None, verbose=False,
515+
return_cost=False):
515516
"""
516517
Returns a resampling array for `coords` obtained from rotations / spins
517518
@@ -524,8 +525,8 @@ def gen_spinsamples(coords, hemiid, n_rotate=1000, check_duplicates=True,
524525
Due to irregular sampling of `coords` and the randomness of the rotations
525526
it is possible that some "rotations" may resample with replacement (i.e.,
526527
will not be a true permutation). The likelihood of this can be reduced by
527-
either increasing the sampling density of `coords` or setting the ``exact``
528-
parameter to True (though see Notes for more information on the latter).
528+
either increasing the sampling density of `coords` or changing the
529+
``method`` parameter (see Notes for more information on the latter).
529530
530531
Parameters
531532
----------
@@ -543,10 +544,12 @@ def gen_spinsamples(coords, hemiid, n_rotate=1000, check_duplicates=True,
543544
Whether to check for and attempt to avoid duplicate resamplings. A
544545
warnings will be raised if duplicates cannot be avoided. Setting to
545546
True may increase the runtime of this function! Default: True
546-
exact : bool, optional
547-
Whether each node/parcel/region should be uniquely re-assigned in every
548-
rotation. Setting to True will drastically increase the memory demands
549-
and runtime of this function! Default: False
547+
method : {'original', 'vasa', 'hungarian'}, optional
548+
Method by which to match non- and rotated coordinates. Specifying
549+
'original' will use the method described in [ST1]_. Specfying 'vasa'
550+
will use the method described in [ST4]_. Specfying 'hungarian' will use
551+
the Hungarian algorithm to minimize the global cost of reassignment
552+
(will dramatically increase runtime). Default: 'original'
550553
seed : {int, np.random.RandomState instance, None}, optional
551554
Seed for random number generation. Default: None
552555
verbose : bool, optional
@@ -574,29 +577,36 @@ def gen_spinsamples(coords, hemiid, n_rotate=1000, check_duplicates=True,
574577
>>> from netneurotools import stats as nnstats
575578
>>> coords = [[0, 0, 1], [1, 0, 0], [0, 0, 1], [1, 0, 0]]
576579
>>> hemi = [0, 0, 1, 1]
577-
>>> nnstats.gen_spinsamples(coords, hemi, n_rotate=1, seed=1)[0]
580+
>>> nnstats.gen_spinsamples(coords, hemi, n_rotate=1, seed=1,
581+
... method='original', check_duplicates=False)
578582
array([[0],
579-
[0],
583+
[1],
580584
[2],
581-
[3]], dtype=int32)
585+
[2]], dtype=int32)
582586
583587
While this is reasonable in most circumstances, if you feel incredibly
584588
strongly about having a perfect "permutation" (i.e., all indices appear
585-
once and exactly once in the resampling), you can set the ``exact``
586-
parameter to True:
589+
once and exactly once in the resampling), you can set the ``method``
590+
parameter to either 'vasa' or 'hungarian':
587591
588592
>>> nnstats.gen_spinsamples(coords, hemi, n_rotate=1, seed=1,
589-
... exact=True)[0]
593+
... method='vasa', check_duplicates=False)
590594
array([[1],
591595
[0],
592596
[2],
593597
[3]], dtype=int32)
598+
>>> nnstats.gen_spinsamples(coords, hemi, n_rotate=1, seed=1,
599+
... method='hungarian', check_duplicates=False)
600+
array([[0],
601+
[1],
602+
[2],
603+
[3]], dtype=int32)
594604
595-
Note that setting this parameter will *dramatically* increase the runtime
596-
of the function. Refer to [ST1]_ for information on why the default (i.e.,
597-
``exact`` set to False) suffices in most cases.
605+
Note that setting this parameter may increase the runtime of the function
606+
(especially for `method='hungarian'`). Refer to [ST1]_ for information on
607+
why the default (i.e., ``exact`` set to False) suffices in most cases.
598608
599-
For the original MATLAB implementation of this function refer to [ST4]_.
609+
For the original MATLAB implementation of this function refer to [ST5]_.
600610
601611
References
602612
----------
@@ -613,9 +623,28 @@ def gen_spinsamples(coords, hemiid, n_rotate=1000, check_duplicates=True,
613623
A Spectral Clustering Framework for Individual and Group Parcellation of
614624
Cortical Surfaces in Lobes. Frontiers in Neuroscience, 12, 354.
615625
616-
.. [ST4] https://github.com/spin-test/spin-test
626+
.. [ST4] Váša, F., Seidlitz, J., Romero-Garcia, R., Whitaker, K. J.,
627+
Rosenthal, G., Vértes, P. E., ... & Jones, P. B. (2018). Adolescent
628+
tuning of association cortex in human structural brain networks.
629+
Cerebral Cortex, 28(1), 281-294.
630+
631+
.. [ST5] https://github.com/spin-test/spin-test
617632
"""
618633

634+
methods = ['original', 'vasa', 'hungarian']
635+
if method not in methods:
636+
raise ValueError('Provided method "{}" invalid. Must be one of {}.'
637+
.format(method, methods))
638+
639+
if exact:
640+
warnings.warn('The `exact` parameter will no longer be supported in '
641+
'an upcoming release. Please use the `method` parameter '
642+
'instead.', DeprecationWarning, stacklevel=3)
643+
if exact == 'vasa' and method == 'original':
644+
method = 'vasa'
645+
elif exact and method == 'original':
646+
method = 'hungarian'
647+
619648
seed = check_random_state(seed)
620649

621650
coords = np.asanyarray(coords)
@@ -673,36 +702,35 @@ def gen_spinsamples(coords, hemiid, n_rotate=1000, check_duplicates=True,
673702
# if we need an "exact" mapping (i.e., each node needs to be
674703
# assigned EXACTLY once) then we have to calculate the full
675704
# distance matrix which is a nightmare with respect to memory
676-
# for anything that isn't parcellated data. that is, don't do
677-
# this with vertex coordinates!
678-
if exact:
705+
# for anything that isn't parcellated data.
706+
# that is, don't do this with vertex coordinates!
707+
if method == 'vasa':
708+
dist = spatial.distance_matrix(coor, coor @ rot)
709+
# min of max a la Vasa et al., 2018
710+
col = np.zeros(len(coor), dtype='int32')
711+
for r in range(len(dist)):
712+
# find parcel whose closest neighbor is farthest away
713+
# overall; assign to that
714+
row = dist.min(axis=1).argmax()
715+
col[row] = dist[row].argmin()
716+
cost[inds[hinds][row], n] = dist[row, col[row]]
717+
# set to -inf and inf so they can't be assigned again
718+
dist[row] = -np.inf
719+
dist[:, col[row]] = np.inf
720+
# optimization of total cost using Hungarian algorithm. this
721+
# may result in certain parcels having higher cost than with
722+
# `method='vasa'` but should always result in the total cost
723+
# being lower #tradeoffs
724+
elif method == 'hungarian':
679725
dist = spatial.distance_matrix(coor, coor @ rot)
680-
# min of max a la Vasa et al., 2017
681-
if exact == 'vasa':
682-
col = np.zeros(len(coor), dtype='int32')
683-
for r in range(len(dist)):
684-
# find parcel whose closest neighbor is farthest
685-
# away overall; assign to that
686-
row = dist.min(axis=1).argmax()
687-
col[row] = dist[row].argmin()
688-
cost[inds[hinds][row], n] = dist[row, col[row]]
689-
# set these to -inf and inf so they can't be
690-
# assigned again
691-
dist[row] = -np.inf
692-
dist[:, col[row]] = np.inf
693-
# optimization of total cost using Hungarian algorithm.
694-
# this may result in certain parcels having higher cost
695-
# than with `exact='vasa'` but should always result in the
696-
# total cost being lower #tradeoffs
697-
else:
698-
row, col = optimize.linear_sum_assignment(dist)
699-
cost[hinds, n] = dist[row, col]
726+
row, col = optimize.linear_sum_assignment(dist)
727+
cost[hinds, n] = dist[row, col]
700728
# if nodes can be assigned multiple targets, we can simply use
701729
# the absolute minimum of the distances (no optimization
702730
# required) which is _much_ lighter on memory
703731
# huge thanks to https://stackoverflow.com/a/47779290 for this
704732
# memory-efficient method
705-
else:
733+
elif method == 'original':
706734
dist, col = kdtrees[h].query(coor @ rot, 1)
707735
cost[hinds, n] = dist
708736

@@ -712,6 +740,7 @@ def gen_spinsamples(coords, hemiid, n_rotate=1000, check_duplicates=True,
712740
if check_duplicates:
713741
if np.any(np.all(resampled[:, None] == spinsamples[:, :n], 0)):
714742
duplicated = True
743+
# if our "spin" is identical to the input then that's no good
715744
elif np.all(resampled == inds):
716745
duplicated = True
717746

0 commit comments

Comments
 (0)