Skip to content

Commit e5ea623

Browse files
author
cindeem
committed
Merge pull request #16 from cgallen/fix_diag
added uptri flag and diagonal zero-ing, updated tests
2 parents d47460a + cde1ed8 commit e5ea623

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

brainx/tests/test_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def test_threshold_adjacency_matrix(self):
112112
npt.assert_equal(mask.sum(), 0)
113113
npt.assert_equal(real_cost, 0)
114114
mask, real_cost = util.threshold_adjacency_matrix(adj_matrix, .9)
115-
npt.assert_equal(mask.sum(), 1840)
115+
npt.assert_equal(mask.sum(), 1800)
116116
npt.assert_equal(real_cost, 0.9)
117117

118118
def test_find_true_cost(self):

brainx/util.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def format_matrix2(data, s, sc, c, lk, co, idc=[],
116116
return ~(cmat == 0)
117117
return cmat
118118

119-
def threshold_adjacency_matrix(adj_matrix, cost):
119+
def threshold_adjacency_matrix(adj_matrix, cost, uptri=False):
120120
"""threshold adj_matrix at cost
121121
122122
Parameters
@@ -125,6 +125,8 @@ def threshold_adjacency_matrix(adj_matrix, cost):
125125
graph adjacency matrix
126126
cost : float
127127
user specified cost
128+
uptri : bool
129+
False returns symmetric matrix, True zeros out diagonal and below
128130
Returns
129131
-------
130132
thresholded : array of bools
@@ -138,7 +140,11 @@ def threshold_adjacency_matrix(adj_matrix, cost):
138140
lookup = make_cost_thresh_lookup(adj_matrix)
139141
cost_index = np.round(cost * float(nedges))
140142
thresh, expected_cost, round_cost = lookup[cost_index]
141-
return adj_matrix > thresh, expected_cost
143+
adj_matrix = adj_matrix > thresh #threshold matrix
144+
np.fill_diagonal(adj_matrix, 0) #zero out diagonal
145+
if uptri: #also zero out below diagonal
146+
adj_matrix = np.triu(adj_matrix)
147+
return adj_matrix, expected_cost
142148

143149
def find_true_cost(boolean_matrix):
144150
""" when passed a boolean matrix, presumably from thresholding to

0 commit comments

Comments
 (0)