Skip to content

Commit 856bfea

Browse files
committed
[FIX] Fixes ordering of spins for spintest
1 parent 8e02a5c commit 856bfea

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

netneurotools/stats.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -646,6 +646,12 @@ def gen_spinsamples(coords, hemiid, n_rotate=1000, check_duplicates=True,
646646
cost = np.zeros((len(coords), n_rotate))
647647
inds = np.arange(len(coords), dtype='int32')
648648

649+
# precompute kd trees
650+
kdtrees = (
651+
spatial.cKDTree(coords[hemiid == 0]),
652+
spatial.cKDTree(coords[hemiid == 1])
653+
)
654+
649655
# generate rotations and resampling array!
650656
msg, warned = '', False
651657
for n in range(n_rotate):
@@ -697,7 +703,7 @@ def gen_spinsamples(coords, hemiid, n_rotate=1000, check_duplicates=True,
697703
# huge thanks to https://stackoverflow.com/a/47779290 for this
698704
# memory-efficient method
699705
else:
700-
dist, col = spatial.cKDTree(coor @ rot).query(coor, 1)
706+
dist, col = kdtrees[h].query(coor @ rot, 1)
701707
cost[hinds, n] = dist
702708

703709
resampled[hinds] = inds[hinds][col]
@@ -722,4 +728,7 @@ def gen_spinsamples(coords, hemiid, n_rotate=1000, check_duplicates=True,
722728
if verbose:
723729
print(' ' * len(msg) + '\b' * len(msg), end='', flush=True)
724730

725-
return spinsamples, cost
731+
if return_cost:
732+
return spinsamples, cost
733+
734+
return spinsamples

0 commit comments

Comments
 (0)