Skip to content

Commit cfed352

Browse files
committed
[TEST] Fixes tests
1 parent 6675571 commit cfed352

File tree

2 files changed

+14
-16
lines changed

2 files changed

+14
-16
lines changed

netneurotools/stats.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -580,9 +580,9 @@ def gen_spinsamples(coords, hemiid, n_rotate=1000, check_duplicates=True,
580580
>>> nnstats.gen_spinsamples(coords, hemi, n_rotate=1, seed=1,
581581
... method='original', check_duplicates=False)
582582
array([[0],
583-
[1],
583+
[0],
584584
[2],
585-
[2]], dtype=int32)
585+
[3]], dtype=int32)
586586
587587
While this is reasonable in most circumstances, if you feel incredibly
588588
strongly about having a perfect "permutation" (i.e., all indices appear
@@ -675,12 +675,6 @@ def gen_spinsamples(coords, hemiid, n_rotate=1000, check_duplicates=True,
675675
cost = np.zeros((len(coords), n_rotate))
676676
inds = np.arange(len(coords), dtype='int32')
677677

678-
# precompute kd trees
679-
kdtrees = (
680-
spatial.cKDTree(coords[hemiid == 0]),
681-
spatial.cKDTree(coords[hemiid == 1])
682-
)
683-
684678
# generate rotations and resampling array!
685679
msg, warned = '', False
686680
for n in range(n_rotate):
@@ -731,7 +725,7 @@ def gen_spinsamples(coords, hemiid, n_rotate=1000, check_duplicates=True,
731725
# huge thanks to https://stackoverflow.com/a/47779290 for this
732726
# memory-efficient method
733727
elif method == 'original':
734-
dist, col = kdtrees[h].query(coor @ rot, 1)
728+
dist, col = spatial.cKDTree(coor @ rot).query(coor, 1)
735729
cost[hinds, n] = dist
736730

737731
resampled[hinds] = inds[hinds][col]

netneurotools/tests/test_stats.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -135,17 +135,21 @@ def test_gen_spinsamples():
135135
hemi = np.hstack([np.zeros(len(coords) // 2), np.ones(len(coords) // 2)])
136136

137137
# generate "normal" test spins
138-
spins, cost = stats.gen_spinsamples(coords, hemi, n_rotate=10, seed=1234)
138+
spins, cost = stats.gen_spinsamples(coords, hemi, n_rotate=10, seed=1234,
139+
return_cost=True)
139140
assert spins.shape == (len(coords), 10)
140141
assert cost.shape == (len(coords), 10)
141142

142143
# confirm that `exact` parameter functions as desired
143-
spin_exact, cost_exact = stats.gen_spinsamples(coords, hemi, n_rotate=10,
144-
exact=True, seed=1234)
145-
assert spin_exact.shape == (len(coords), 10)
146-
assert cost.shape == (len(coords), 10)
147-
for s in spin_exact.T:
148-
assert len(np.unique(s)) == len(s)
144+
for method in ['vasa', 'hungarian']:
145+
spin_exact, cost_exact = stats.gen_spinsamples(coords, hemi,
146+
n_rotate=10, seed=1234,
147+
method=method,
148+
return_cost=True)
149+
assert spin_exact.shape == (len(coords), 10)
150+
assert cost.shape == (len(coords), 10)
151+
for s in spin_exact.T:
152+
assert len(np.unique(s)) == len(s)
149153

150154
# confirm that check_duplicates will raise warnings
151155
# since spins aren't exact permutations we need to use 4C4 with repeats

0 commit comments

Comments
 (0)