Skip to content

Commit 0bccbb8

Browse files
author
CindeeM
committed
RF: cost2thresh calls cost2thresh2 to minimize code duplication, added tests
1 parent 9fee68f commit 0bccbb8

File tree

2 files changed

+34
-4
lines changed

2 files changed

+34
-4
lines changed

brainx/tests/test_util.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
#-----------------------------------------------------------------------------
44
# Imports
55
#-----------------------------------------------------------------------------
6-
6+
import unittest
77

88
# Third party
99
import nose.tools as nt
@@ -55,6 +55,34 @@ def test_cost_size():
5555
n_nodes = 5
5656
npt.assert_warns(DeprecationWarning, util.cost_size, n_nodes)
5757

58+
class TestCost2Thresh(unittest.TestCase):
59+
def setUp(self):
60+
nnodes, nsub, nblocks, nsubblocks = 45, 20, 6, 2
61+
prng = np.random.RandomState(42)
62+
self.data_5d = prng.random_sample((nsubblocks, nblocks,
63+
nsub, nnodes, nnodes))
64+
ind = np.triu_indices(nnodes, k=1)
65+
nedges = (np.empty((nnodes, nnodes))[ind]).shape[0]
66+
costs, _, _ = util.cost_size(nnodes)
67+
self.costs = costs
68+
self.lookup = np.zeros((nsubblocks, nblocks, nsub,2, nedges))
69+
bigcost =np.tile(costs[1:], nblocks*nsubblocks*nsub)
70+
bigcost.shape = (nsubblocks, nblocks, nsub, nedges)
71+
self.lookup[:,:,:,1,:] = bigcost
72+
for sblock in range(nsubblocks):
73+
for block in range(nblocks):
74+
for sid in range(nsub):
75+
tmp = data_5d[sblock, block, sid]
76+
self.lookup[sblock,block,sid,0,:] = tmp[ind]
77+
78+
def test_cost2thresh2(self):
79+
thr = util.cost2thresh2(self.costs[100], 0,0,0,self.lookup)
80+
npt.assert_almost_equal(thr, 0.24929222914887494, decimal=7)
81+
82+
def test_cost2thresh(self):
83+
lookup = self.lookup[0].squeeze()
84+
thr = util.cost2thresh(self.costs[100],0,0,0,lookup)
85+
npt.assert_almost_equal(thr, 0.24929222914887494, decimal=7)
5886

5987
def test_apply_cost():
6088
corr_mat = np.array([[0.0, 0.5, 0.3, 0.2, 0.1],

brainx/util.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,7 @@ def cost2thresh(cost, sub, bl, lk, idc=[], costlist=[]):
606606
be registered.
607607
608608
"""
609+
return cost2thresh2(cost, sub, bl, axis0=None, lk=lk, last = None, idc=idc,costlist = costlist)
609610
# For this subject and block, find the indices corresponding to this cost.
610611
# Note there may be more than one such index. There will be no such
611612
# indices if cost is not a value in the array.
@@ -632,7 +633,8 @@ def cost2thresh(cost, sub, bl, lk, idc=[], costlist=[]):
632633
return th
633634

634635

635-
def cost2thresh2(cost, sub, sc, c, lk, last = None, idc = [], costlist=[]):
636+
def cost2thresh2(cost, sub, axis1, axis0, lk,
637+
last = None, idc = [], costlist=[]):
636638
"""A definition for loading the lookup table and finding the threshold
637639
associated with a particular cost for a particular subject in a
638640
particular block of data
@@ -663,7 +665,7 @@ def cost2thresh2(cost, sub, sc, c, lk, last = None, idc = [], costlist=[]):
663665
threshold : float
664666
threshold value for this cost"""
665667

666-
subject_lookup = slice_data(lk, sub, c, subcond=sc)
668+
subject_lookup = slice_data(lk, sub, axis0, subcond=axis1)
667669
index = np.where(subject_lookup[1] == cost)
668670
threshold = subject_lookup[0][ind]
669671

@@ -676,7 +678,7 @@ def cost2thresh2(cost, sub, sc, c, lk, last = None, idc = [], costlist=[]):
676678
elif len(threshold) < 1:
677679
idc = idc-1
678680
newcost = costlist[idc]
679-
threshold = cost2thresh2(newcost, sub, sc, c, lk,
681+
threshold = cost2thresh2(newcost, sub, axis1, axis0, lk,
680682
idc=idc, costlist = costlist)
681683
print(' '.join(['Subject %s does not have cost at %s'%(sub, cost),
682684
'index 1: %s, index 2: %s'%(c, sc),

0 commit comments

Comments
 (0)