Skip to content

Commit 28497f1

Browse files
author
CindeeM
committed
BF: make_cost_thresh_lookup has correct order of costs and associated thresholds
1 parent 9b514f6 commit 28497f1

File tree

2 files changed

+17
-5
lines changed

2 files changed

+17
-5
lines changed

brainx/tests/test_util.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,19 @@ def test_all_positive():
3838
jnk = jnk - 0.5
3939
npt.assert_equal(util.all_positive(jnk), False)
4040

41-
41+
def test_make_cost_thresh_lookup():
42+
adj_mat = np.zeros((10,10))
43+
ind = np.triu_indices(10,1)
44+
thresholds = np.linspace(.1, .8, 45)
45+
adj_mat[ind] = thresholds
46+
lookup = util.make_cost_thresh_lookup(adj_mat)
47+
48+
npt.assert_equal(sorted(thresholds, reverse=True), lookup[0,:])
49+
npt.assert_equal(lookup[1,0] < lookup[1,-1], True)
50+
# costs in ascending order
51+
## last vecore is same as second vector rounded to 2 decimals
52+
npt.assert_almost_equal(lookup[1], lookup[2], decimal=2)
53+
4254
def test_cost_size():
4355
n_nodes = 5
4456
npt.assert_warns(DeprecationWarning, util.cost_size, n_nodes)

brainx/util.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -130,18 +130,18 @@ def make_cost_thresh_lookup(adjacency_matrix):
130130
-------
131131
lookup : numpy array
132132
3 X number_of_edges, numpy array
133-
row 0 is sorted thresholds
134-
row 1 is cost at each threshold
133+
row 0 is sorted thresholds (largest -> smallest)
134+
row 1 is cost at each threshold (smallest -> largest)
135135
row 2 is costs rounded to one decimal point
136136
"""
137137

138138
ind = np.triu_indices_from(adjacency_matrix, k = 1)
139139
edges = adjacency_matrix[ind]
140140
nedges = edges.shape[0]
141141
lookup = np.zeros((3, nedges))
142-
lookup[0,:] = sorted(edges)
142+
lookup[0,:] = sorted(edges, reverse = True)
143143
lookup[1,:] = np.arange(nedges) / float(nedges)
144-
lookup[2,:] = np.round(lookup[1,:], decimals = 1)
144+
lookup[2,:] = np.round(lookup[1,:], decimals = 2)
145145
return lookup
146146

147147
def cost_size(nnodes):

0 commit comments

Comments
 (0)