Skip to content

Commit 313eeac

Browse files
author
CindeeM
committed
RF: finish cleaning up cost to thresh, caught corner case in slice data
1 parent 0bccbb8 commit 313eeac

File tree

2 files changed

+25
-41
lines changed

2 files changed

+25
-41
lines changed

brainx/tests/test_util.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -72,17 +72,24 @@ def setUp(self):
7272
for sblock in range(nsubblocks):
7373
for block in range(nblocks):
7474
for sid in range(nsub):
75-
tmp = data_5d[sblock, block, sid]
76-
self.lookup[sblock,block,sid,0,:] = tmp[ind]
75+
tmp = self.data_5d[sblock, block, sid]
76+
self.lookup[sblock,block,sid,0,:] = sorted(tmp[ind],
77+
reverse=True)
7778

78-
def test_cost2thresh2(self):
79-
thr = util.cost2thresh2(self.costs[100], 0,0,0,self.lookup)
80-
npt.assert_almost_equal(thr, 0.24929222914887494, decimal=7)
81-
82-
def test_cost2thresh(self):
83-
lookup = self.lookup[0].squeeze()
84-
thr = util.cost2thresh(self.costs[100],0,0,0,lookup)
85-
npt.assert_almost_equal(thr, 0.24929222914887494, decimal=7)
79+
def test_cost2thresh2(self):
80+
thr = util.cost2thresh2(self.costs[100], 0,0,0,self.lookup)
81+
npt.assert_almost_equal(thr, 0.24929222914887494, decimal=7)
82+
83+
def test_cost2thresh(self):
84+
lookup = self.lookup[0].squeeze()
85+
thr = util.cost2thresh(self.costs[100],0,0,0,lookup)
86+
npt.assert_almost_equal(thr, 0.24929222914887494, decimal=7)
87+
88+
def test_format_matrix(self):
89+
thresh_matrix = util.format_matrix2(self.data_5d, 0,0,0,
90+
self.lookup, self.costs[100])
91+
print thresh_matrix.sum()
92+
npt.assert_equal(thresh_matrix.sum(), 22)
8693

8794
def test_apply_cost():
8895
corr_mat = np.array([[0.0, 0.5, 0.3, 0.2, 0.1],

brainx/util.py

Lines changed: 8 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def slice_data(data, sub, block, subcond=None):
3232
adjacency_matrix : numpy array
3333
symmetric numpy array (innode, nnode)
3434
"""
35-
if subcond:
35+
if not subcond is None:
3636
return data[subcond, block, sub]
3737
return data[block, sub]
3838

@@ -91,16 +91,17 @@ def format_matrix2(data,s,sc,c,lk,co,idc = [],costlist=[],nouptri = False):
9191
nouptri : bool
9292
False zeros out diag and below, True returns symmetric matrix
9393
"""
94-
95-
cmat = data[sc,c,s]
94+
cmat = slice_data(data, s, c, sc)
95+
#cmat = data[sc,c,s]
9696
th = cost2thresh2(co,s,sc,c,lk,[],idc,costlist) #get the right threshold
9797

9898
#cmat = replace_diag(cmat) #replace diagonals with zero
9999
cmat = thresholded_arr(cmat,th,fill_val=0)
100100

101101
if not nouptri:
102102
cmat = np.triu(cmat,1)
103-
103+
104+
# return boolean mask
104105
return cmat
105106

106107
def threshold_adjacency_matrix(adj_matrix, cost):
@@ -606,32 +607,8 @@ def cost2thresh(cost, sub, bl, lk, idc=[], costlist=[]):
606607
be registered.
607608
608609
"""
609-
return cost2thresh2(cost, sub, bl, axis0=None, lk=lk, last = None, idc=idc,costlist = costlist)
610-
# For this subject and block, find the indices corresponding to this cost.
611-
# Note there may be more than one such index. There will be no such
612-
# indices if cost is not a value in the array.
613-
ind = np.where(lk[bl][sub][1] == cost)
614-
# The possibility of multiple (or no) indices implies multiple (or no)
615-
# thresholds may be acquired here.
616-
th = lk[bl][sub][0][ind]
617-
n_thresholds = len(th)
618-
if n_thresholds > 1:
619-
th=th[0]
620-
print(''.join(['Subject %s has multiple thresholds in block %d ',
621-
'corresponding to a cost of %f. The smallest is being',
622-
' used.']) % (sub, bl, cost))
623-
elif n_thresholds < 1:
624-
idc = idc - 1
625-
newcost = costlist[idc]
626-
th = cost2thresh(newcost, sub, bl, lk, idc, costlist)
627-
print(''.join(['Subject %s does not have a threshold in block %d ',
628-
'corresponding to a cost of %f. The threshold ',
629-
'matching the nearest previous cost in costlist is ',
630-
'being used.']) % (sub, block, cost))
631-
else:
632-
th=th[0]
633-
return th
634-
610+
return cost2thresh2(cost, sub, bl, axis0=None,
611+
lk=lk, last = None, idc=idc,costlist = costlist)
635612

636613
def cost2thresh2(cost, sub, axis1, axis0, lk,
637614
last = None, idc = [], costlist=[]):
@@ -667,7 +644,7 @@ def cost2thresh2(cost, sub, axis1, axis0, lk,
667644

668645
subject_lookup = slice_data(lk, sub, axis0, subcond=axis1)
669646
index = np.where(subject_lookup[1] == cost)
670-
threshold = subject_lookup[0][ind]
647+
threshold = subject_lookup[0][index]
671648

672649
if len(threshold) > 1:
673650
threshold = threshold[0]

0 commit comments

Comments
 (0)