@@ -646,6 +646,12 @@ def gen_spinsamples(coords, hemiid, n_rotate=1000, check_duplicates=True,
646646 cost = np .zeros ((len (coords ), n_rotate ))
647647 inds = np .arange (len (coords ), dtype = 'int32' )
648648
649+ # precompute kd trees
650+ kdtrees = (
651+ spatial .cKDTree (coords [hemiid == 0 ]),
652+ spatial .cKDTree (coords [hemiid == 1 ])
653+ )
654+
649655 # generate rotations and resampling array!
650656 msg , warned = '' , False
651657 for n in range (n_rotate ):
@@ -697,7 +703,7 @@ def gen_spinsamples(coords, hemiid, n_rotate=1000, check_duplicates=True,
697703 # huge thanks to https://stackoverflow.com/a/47779290 for this
698704 # memory-efficient method
699705 else :
700- dist , col = spatial . cKDTree (coor @ rot ). query ( coor , 1 )
706+ dist , col = kdtrees [ h ]. query (coor @ rot , 1 )
701707 cost [hinds , n ] = dist
702708
703709 resampled [hinds ] = inds [hinds ][col ]
@@ -722,4 +728,7 @@ def gen_spinsamples(coords, hemiid, n_rotate=1000, check_duplicates=True,
722728 if verbose :
723729 print (' ' * len (msg ) + '\b ' * len (msg ), end = '' , flush = True )
724730
725- return spinsamples , cost
731+ if return_cost :
732+ return spinsamples , cost
733+
734+ return spinsamples
0 commit comments