Skip to content

Commit e57d9a0

Browse files
author
cindeem
committed
Merge pull request #11 from cindeem/util_edits
Utils updated to generate boolean arrays, code cleanup
2 parents d40473c + dd08691 commit e57d9a0

File tree

2 files changed

+396
-167
lines changed

2 files changed

+396
-167
lines changed

brainx/tests/test_util.py

Lines changed: 128 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
#-----------------------------------------------------------------------------
44
# Imports
55
#-----------------------------------------------------------------------------
6-
6+
import unittest
77

88
# Third party
99
import nose.tools as nt
@@ -17,10 +17,118 @@
1717
# Functions
1818
#-----------------------------------------------------------------------------
1919

20+
def test_slice_data():
21+
subcond, blocks, subjects, nodes = 5, 10, 20, 4
22+
data_4d = np.ones((blocks, subjects, nodes, nodes))
23+
data_5d = np.ones((subcond, blocks, subjects, nodes, nodes))
24+
sym_4d = util.slice_data(data_4d, subjects - 1 , blocks - 1 )
25+
sym_5d = util.slice_data(data_5d, subjects -1 , blocks-1, subcond-1)
26+
npt.assert_equal(sym_4d.shape, (nodes, nodes))
27+
npt.assert_equal(sym_5d.shape, (nodes, nodes))
28+
npt.assert_raises(IndexError, util.slice_data, data_5d, subjects, blocks)
29+
30+
31+
def test_all_positive():
32+
jnk = np.random.random(40)
33+
npt.assert_equal(util.all_positive(jnk), True)
34+
# zeros counted as positive
35+
jnk[0] = 0
36+
npt.assert_equal(util.all_positive(jnk), True)
37+
# find real negative
38+
jnk = jnk - 0.5
39+
npt.assert_equal(util.all_positive(jnk), False)
40+
41+
def test_make_cost_thresh_lookup():
42+
adj_mat = np.zeros((10,10))
43+
ind = np.triu_indices(10,1)
44+
thresholds = np.linspace(.1, .8, 45)
45+
adj_mat[ind] = thresholds
46+
lookup = util.make_cost_thresh_lookup(adj_mat)
47+
48+
npt.assert_equal(sorted(thresholds, reverse=True), lookup.weight)
49+
npt.assert_equal(lookup[0].cost < lookup[-1].cost, True)
50+
# costs in ascending order
51+
## last vector is same as second vector rounded to 2 decimals
52+
npt.assert_almost_equal(lookup.actual_cost, lookup.cost, decimal=2)
53+
2054
def test_cost_size():
2155
n_nodes = 5
22-
npt.assert_warns(DeprecationWarning, util.cost_size, n_nodes)
23-
56+
## NOTE DeprecationWarnings are ignored by default in 2.7
57+
#npt.assert_warns(UserWarning, util.cost_size, n_nodes)
58+
59+
60+
61+
class TestCost2Thresh(unittest.TestCase):
62+
def setUp(self):
63+
nnodes, nsub, nblocks, nsubblocks = 45, 20, 6, 2
64+
prng = np.random.RandomState(42)
65+
self.data_5d = prng.random_sample((nsubblocks, nblocks,
66+
nsub, nnodes, nnodes))
67+
ind = np.triu_indices(nnodes, k=1)
68+
nedges = (np.empty((nnodes, nnodes))[ind]).shape[0]
69+
costs, _, _ = util.cost_size(nnodes)
70+
self.nedges = nedges
71+
self.costs = costs
72+
self.lookup = np.zeros((nsubblocks, nblocks, nsub,2, nedges))
73+
bigcost =np.tile(costs[1:], nblocks*nsubblocks*nsub)
74+
bigcost.shape = (nsubblocks, nblocks, nsub, nedges)
75+
self.lookup[:,:,:,1,:] = bigcost
76+
for sblock in range(nsubblocks):
77+
for block in range(nblocks):
78+
for sid in range(nsub):
79+
tmp = self.data_5d[sblock, block, sid]
80+
self.lookup[sblock,block,sid,0,:] = sorted(tmp[ind],
81+
reverse=True)
82+
83+
def test_cost2thresh2(self):
84+
thr = util.cost2thresh2(self.costs[100], 0,0,0,self.lookup)
85+
real_thr = self.lookup[0,0,0,0,100-1]
86+
npt.assert_almost_equal(thr, real_thr, decimal=7)
87+
88+
def test_cost2thresh(self):
89+
lookup = self.lookup[0].squeeze()
90+
thr = util.cost2thresh(self.costs[100],0,0,lookup)
91+
real_thr = lookup[0,0,0,100-1]# costs padded by zero
92+
npt.assert_almost_equal(thr, real_thr, decimal=7)
93+
94+
def test_format_matrix(self):
95+
bool_matrix = util.format_matrix2(self.data_5d, 0,0,0,
96+
self.lookup, self.costs[100])
97+
npt.assert_equal(bool_matrix.sum(), 100 -1)
98+
thresh_matrix = util.format_matrix2(self.data_5d, 0,0,0,
99+
self.lookup, self.costs[100],asbool = False)
100+
npt.assert_equal(bool_matrix.sum()== thresh_matrix.sum(), False)
101+
npt.assert_almost_equal(thresh_matrix.sum(),
102+
94.183321784530804, decimal=7)
103+
## test format_matrix call on format_matrix2
104+
bool_matrix_sm = util.format_matrix(self.data_5d[0].squeeze(),
105+
0,0, self.lookup[0].squeeze(), self.costs[100])
106+
npt.assert_equal(bool_matrix.sum(), bool_matrix_sm.sum())
107+
108+
109+
def test_threshold_adjacency_matrix(self):
110+
adj_matrix = self.data_5d[0,0,0].squeeze()
111+
mask, real_cost = util.threshold_adjacency_matrix(adj_matrix, 0)
112+
npt.assert_equal(mask.sum(), 0)
113+
npt.assert_equal(real_cost, 0)
114+
mask, real_cost = util.threshold_adjacency_matrix(adj_matrix, .9)
115+
npt.assert_equal(mask.sum(), 1840)
116+
npt.assert_equal(real_cost, 0.9)
117+
118+
def test_find_true_cost(self):
119+
adj_matrix = self.data_5d[0,0,0].squeeze()
120+
mask, real_cost = util.threshold_adjacency_matrix(adj_matrix, 0.2)
121+
true_cost = util.find_true_cost(mask)
122+
npt.assert_equal(real_cost, true_cost)
123+
## test on rounded array
124+
adj_matrix = self.data_5d[0,0,0].squeeze().round(decimals = 1)
125+
mask, expected_cost = util.threshold_adjacency_matrix(adj_matrix, 0.2)
126+
true_cost = util.find_true_cost(mask)
127+
## the cost of the thresholded matrix will be less than expected
128+
npt.assert_equal(real_cost > true_cost, True)
129+
130+
131+
24132

25133
def test_apply_cost():
26134
corr_mat = np.array([[0.0, 0.5, 0.3, 0.2, 0.1],
@@ -105,4 +213,20 @@ def test_no_empty_modules():
105213
b[2] = []
106214
util.assert_no_empty_modules(a)
107215
nt.assert_raises(ValueError, util.assert_no_empty_modules, b)
108-
216+
217+
def test_rescale_arr():
218+
array = np.arange(5)
219+
scaled = util.rescale_arr(array, 3, 6)
220+
npt.assert_equal(scaled.min(), 3)
221+
scaled = util.rescale_arr(array, -10, 10)
222+
npt.assert_equal(scaled.min(), -10)
223+
npt.assert_equal(scaled.max(), 10)
224+
225+
def test_normalize():
226+
array = np.arange(5)
227+
result = util.normalize(array)
228+
npt.assert_equal(result.min(), 0)
229+
npt.assert_equal(result.max(), 1)
230+
npt.assert_raises(ValueError, util.normalize, array, 'blueberry', (0,2))
231+
232+

0 commit comments

Comments
 (0)