Skip to content

Commit e4ce721

Browse files
author
CindeeM
committed
BF: fix code in format_matrix, update tests
1 parent 9fe476b commit e4ce721

File tree

2 files changed

+43
-15
lines changed

2 files changed

+43
-15
lines changed

brainx/tests/test_util.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,10 +92,19 @@ def test_cost2thresh(self):
9292
npt.assert_almost_equal(thr, real_thr, decimal=7)
9393

9494
def test_format_matrix(self):
95-
thresh_matrix = util.format_matrix2(self.data_5d, 0,0,0,
95+
bool_matrix = util.format_matrix2(self.data_5d, 0,0,0,
9696
self.lookup, self.costs[100])
97-
print thresh_matrix.sum()
98-
npt.assert_equal(thresh_matrix.sum(), 100 -1)
97+
npt.assert_equal(bool_matrix.sum(), 100 -1)
98+
thresh_matrix = util.format_matrix2(self.data_5d, 0,0,0,
99+
self.lookup, self.costs[100],asbool = False)
100+
npt.assert_equal(bool_matrix.sum()== thresh_matrix.sum(), False)
101+
npt.assert_almost_equal(thresh_matrix.sum(),
102+
94.183321784530804, decimal=7)
103+
## test format_matrix call on format_matrix2
104+
bool_matrix_sm = util.format_matrix(self.data_5d[0].squeeze(),
105+
0,0, self.lookup[0].squeeze(), self.costs[100])
106+
npt.assert_equal(bool_matrix.sum(), bool_matrix_sm.sum())
107+
99108

100109
def test_threshold_adjacency_matrix(self):
101110
adj_matrix = self.data_5d[0,0,0].squeeze()

brainx/util.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ def slice_data(data, sub, block, subcond=None):
3737
return data[subcond, block, sub]
3838

3939

40-
def format_matrix(data,s,b,lk,co,idc = [],costlist=[],nouptri = False):
40+
def format_matrix(data, s, b, lk, co, idc=[], costlist=[],
41+
nouptri=False, asbool=True):
4142
""" Function which thresholds the adjacency matrix for a particular
4243
subject and particular block, using lookuptable to find thresholds,
4344
cost value to find threshold, costlist
@@ -46,24 +47,37 @@ def format_matrix(data,s,b,lk,co,idc = [],costlist=[],nouptri = False):
4647
Parameters
4748
-----------
4849
data : full data array 4D (block, sub, node, node)
49-
s = subject
50-
b = block
51-
lk = lookup table for study
52-
co = cost value to threshold at
53-
idc = index of ideal cost
54-
costlist = list (size num_edges) with ordered values used to find
50+
s : int
51+
subject
52+
b : int
53+
block
54+
lk : numpy array
55+
lookup table for study
56+
co : int
57+
cost value to threshold at
58+
idc : int
59+
index of ideal cost
60+
costlist : list
61+
list (size num_edges) with ordered values used to find
5562
threshold to control number of edges
56-
nouptri = if False only keeps upper tri, True yields symmetric matrix
63+
nouptri : bool
64+
if False only keeps upper tri, True yields symmetric matrix
65+
asbool : bool
66+
if True return boolean mask, otherwise returns thesholded
67+
weight matrix
5768
"""
58-
cmat = data[b,s]
69+
cmat = slice_data(data, b,s)
5970
th = cost2thresh(co,s,b,lk,idc,costlist) #get the right threshold
6071
cmat = thresholded_arr(cmat,th,fill_val=0)
6172
if not nouptri:
6273
cmat = np.triu(cmat,1)
74+
if asbool:
75+
return ~(cmat == 0)
6376
return cmat
6477

6578

66-
def format_matrix2(data,s,sc,c,lk,co,idc = [],costlist=[],nouptri = False):
79+
def format_matrix2(data, s, sc, c, lk, co, idc=[],
80+
costlist=[], nouptri=False, asbool=True):
6781
""" Function which formats matrix for a particular subject and
6882
particular block (thresholds, upper-tris it) so that we can
6983
make a graph object out of it
@@ -88,14 +102,19 @@ def format_matrix2(data,s,sc,c,lk,co,idc = [],costlist=[],nouptri = False):
88102
list of possible costs
89103
nouptri : bool
90104
False zeros out diag and below, True returns symmetric matrix
105+
asbool : bool
106+
If true returns boolean mask, otherwise returns thresholded w
107+
weighted matrix
91108
"""
92109
cmat = slice_data(data, s, c, sc)
93110
th = cost2thresh2(co,s,sc,c,lk,[],idc,costlist) #get the right threshold
94111
cmat = thresholded_arr(cmat,th,fill_val=0)
95112
if not nouptri:
96113
cmat = np.triu(cmat,1)
97-
# return boolean mask
98-
return ~(cmat == 0)
114+
if asbool:
115+
# return boolean mask
116+
return ~(cmat == 0)
117+
return cmat
99118

100119
def threshold_adjacency_matrix(adj_matrix, cost):
101120
"""docstring for threshold_adjacency_matrix(adj_matrix, cost"""

0 commit comments

Comments
 (0)