Skip to content

Commit 9fe476b

Browse files
author
CindeeM
committed
NF: add test_threshold_adjacency_matrix, small cleanup
1 parent fa9dc53 commit 9fe476b

File tree

2 files changed

+23
-15
lines changed

2 files changed

+23
-15
lines changed

brainx/tests/test_util.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def setUp(self):
6767
ind = np.triu_indices(nnodes, k=1)
6868
nedges = (np.empty((nnodes, nnodes))[ind]).shape[0]
6969
costs, _, _ = util.cost_size(nnodes)
70+
self.nedges = nedges
7071
self.costs = costs
7172
self.lookup = np.zeros((nsubblocks, nblocks, nsub,2, nedges))
7273
bigcost =np.tile(costs[1:], nblocks*nsubblocks*nsub)
@@ -96,6 +97,17 @@ def test_format_matrix(self):
9697
print thresh_matrix.sum()
9798
npt.assert_equal(thresh_matrix.sum(), 100 -1)
9899

100+
def test_threshold_adjacency_matrix(self):
101+
adj_matrix = self.data_5d[0,0,0].squeeze()
102+
mask, real_cost = util.threshold_adjacency_matrix(adj_matrix, 0)
103+
npt.assert_equal(mask.sum(), 0)
104+
npt.assert_equal(real_cost, 0)
105+
mask, real_cost = util.threshold_adjacency_matrix(adj_matrix, .9)
106+
npt.assert_equal(mask.sum(), 1840)
107+
npt.assert_equal(real_cost, 0.9)
108+
109+
110+
99111
def test_apply_cost():
100112
corr_mat = np.array([[0.0, 0.5, 0.3, 0.2, 0.1],
101113
[0.5, 0.0, 0.4, 0.1, 0.2],

brainx/util.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -54,19 +54,15 @@ def format_matrix(data,s,b,lk,co,idc = [],costlist=[],nouptri = False):
5454
costlist = list (size num_edges) with ordered values used to find
5555
threshold to control number of edges
5656
nouptri = if False only keeps upper tri, True yields symmetric matrix
57-
"""
58-
57+
"""
5958
cmat = data[b,s]
6059
th = cost2thresh(co,s,b,lk,idc,costlist) #get the right threshold
61-
62-
#cmat = replace_diag(cmat) #replace diagonals with zero
6360
cmat = thresholded_arr(cmat,th,fill_val=0)
64-
6561
if not nouptri:
6662
cmat = np.triu(cmat,1)
67-
6863
return cmat
6964

65+
7066
def format_matrix2(data,s,sc,c,lk,co,idc = [],costlist=[],nouptri = False):
7167
""" Function which formats matrix for a particular subject and
7268
particular block (thresholds, upper-tris it) so that we can
@@ -92,28 +88,28 @@ def format_matrix2(data,s,sc,c,lk,co,idc = [],costlist=[],nouptri = False):
9288
list of possible costs
9389
nouptri : bool
9490
False zeros out diag and below, True returns symmetric matrix
95-
"""
91+
"""
9692
cmat = slice_data(data, s, c, sc)
97-
#cmat = data[sc,c,s]
9893
th = cost2thresh2(co,s,sc,c,lk,[],idc,costlist) #get the right threshold
99-
100-
#cmat = replace_diag(cmat) #replace diagonals with zero
10194
cmat = thresholded_arr(cmat,th,fill_val=0)
102-
10395
if not nouptri:
10496
cmat = np.triu(cmat,1)
105-
10697
# return boolean mask
10798
return ~(cmat == 0)
10899

109100
def threshold_adjacency_matrix(adj_matrix, cost):
110101
"""docstring for threshold_adjacency_matrix(adj_matrix, cost"""
111-
112-
pass
102+
nnodes, _ = adj_matrix.shape
103+
ind = np.triu_indices(nnodes, 1)
104+
nedges = adj_matrix[ind].shape[0]
105+
lookup = make_cost_thresh_lookup(adj_matrix)
106+
cost_index = np.round(cost * float(nedges))
107+
thresh, actual_cost, round_cost = lookup[cost_index]
108+
return adj_matrix > thresh, actual_cost
113109

114110

115111
def all_positive(adjacency_matrix):
116-
""" checks if edge value sin adjacency matrix are all positive
112+
""" checks if edge values in adjacency matrix are all positive
117113
or positive and negative
118114
Returns
119115
-------

0 commit comments

Comments
 (0)