|
3 | 3 | #-----------------------------------------------------------------------------
|
4 | 4 | # Imports
|
5 | 5 | #-----------------------------------------------------------------------------
|
6 |
| - |
| 6 | +import unittest |
7 | 7 |
|
8 | 8 | # Third party
|
9 | 9 | import nose.tools as nt
|
|
17 | 17 | # Functions
|
18 | 18 | #-----------------------------------------------------------------------------
|
19 | 19 |
|
| 20 | +def test_slice_data(): |
| 21 | + subcond, blocks, subjects, nodes = 5, 10, 20, 4 |
| 22 | + data_4d = np.ones((blocks, subjects, nodes, nodes)) |
| 23 | + data_5d = np.ones((subcond, blocks, subjects, nodes, nodes)) |
| 24 | + sym_4d = util.slice_data(data_4d, subjects - 1 , blocks - 1 ) |
| 25 | + sym_5d = util.slice_data(data_5d, subjects -1 , blocks-1, subcond-1) |
| 26 | + npt.assert_equal(sym_4d.shape, (nodes, nodes)) |
| 27 | + npt.assert_equal(sym_5d.shape, (nodes, nodes)) |
| 28 | + npt.assert_raises(IndexError, util.slice_data, data_5d, subjects, blocks) |
| 29 | + |
| 30 | + |
| 31 | +def test_all_positive(): |
| 32 | + jnk = np.random.random(40) |
| 33 | + npt.assert_equal(util.all_positive(jnk), True) |
| 34 | + # zeros counted as positive |
| 35 | + jnk[0] = 0 |
| 36 | + npt.assert_equal(util.all_positive(jnk), True) |
| 37 | + # find real negative |
| 38 | + jnk = jnk - 0.5 |
| 39 | + npt.assert_equal(util.all_positive(jnk), False) |
| 40 | + |
| 41 | +def test_make_cost_thresh_lookup(): |
| 42 | + adj_mat = np.zeros((10,10)) |
| 43 | + ind = np.triu_indices(10,1) |
| 44 | + thresholds = np.linspace(.1, .8, 45) |
| 45 | + adj_mat[ind] = thresholds |
| 46 | + lookup = util.make_cost_thresh_lookup(adj_mat) |
| 47 | + |
| 48 | + npt.assert_equal(sorted(thresholds, reverse=True), lookup.weight) |
| 49 | + npt.assert_equal(lookup[0].cost < lookup[-1].cost, True) |
| 50 | + # costs in ascending order |
| 51 | + ## last vector is same as second vector rounded to 2 decimals |
| 52 | + npt.assert_almost_equal(lookup.actual_cost, lookup.cost, decimal=2) |
| 53 | + |
20 | 54 | def test_cost_size():
|
21 | 55 | n_nodes = 5
|
22 |
| - npt.assert_warns(DeprecationWarning, util.cost_size, n_nodes) |
23 |
| - |
| 56 | + ## NOTE DeprecationWarnings are ignored by default in 2.7 |
| 57 | + #npt.assert_warns(UserWarning, util.cost_size, n_nodes) |
| 58 | + |
| 59 | + |
| 60 | + |
| 61 | +class TestCost2Thresh(unittest.TestCase): |
| 62 | + def setUp(self): |
| 63 | + nnodes, nsub, nblocks, nsubblocks = 45, 20, 6, 2 |
| 64 | + prng = np.random.RandomState(42) |
| 65 | + self.data_5d = prng.random_sample((nsubblocks, nblocks, |
| 66 | + nsub, nnodes, nnodes)) |
| 67 | + ind = np.triu_indices(nnodes, k=1) |
| 68 | + nedges = (np.empty((nnodes, nnodes))[ind]).shape[0] |
| 69 | + costs, _, _ = util.cost_size(nnodes) |
| 70 | + self.nedges = nedges |
| 71 | + self.costs = costs |
| 72 | + self.lookup = np.zeros((nsubblocks, nblocks, nsub,2, nedges)) |
| 73 | + bigcost =np.tile(costs[1:], nblocks*nsubblocks*nsub) |
| 74 | + bigcost.shape = (nsubblocks, nblocks, nsub, nedges) |
| 75 | + self.lookup[:,:,:,1,:] = bigcost |
| 76 | + for sblock in range(nsubblocks): |
| 77 | + for block in range(nblocks): |
| 78 | + for sid in range(nsub): |
| 79 | + tmp = self.data_5d[sblock, block, sid] |
| 80 | + self.lookup[sblock,block,sid,0,:] = sorted(tmp[ind], |
| 81 | + reverse=True) |
| 82 | + |
| 83 | + def test_cost2thresh2(self): |
| 84 | + thr = util.cost2thresh2(self.costs[100], 0,0,0,self.lookup) |
| 85 | + real_thr = self.lookup[0,0,0,0,100-1] |
| 86 | + npt.assert_almost_equal(thr, real_thr, decimal=7) |
| 87 | + |
| 88 | + def test_cost2thresh(self): |
| 89 | + lookup = self.lookup[0].squeeze() |
| 90 | + thr = util.cost2thresh(self.costs[100],0,0,lookup) |
| 91 | + real_thr = lookup[0,0,0,100-1]# costs padded by zero |
| 92 | + npt.assert_almost_equal(thr, real_thr, decimal=7) |
| 93 | + |
| 94 | + def test_format_matrix(self): |
| 95 | + bool_matrix = util.format_matrix2(self.data_5d, 0,0,0, |
| 96 | + self.lookup, self.costs[100]) |
| 97 | + npt.assert_equal(bool_matrix.sum(), 100 -1) |
| 98 | + thresh_matrix = util.format_matrix2(self.data_5d, 0,0,0, |
| 99 | + self.lookup, self.costs[100],asbool = False) |
| 100 | + npt.assert_equal(bool_matrix.sum()== thresh_matrix.sum(), False) |
| 101 | + npt.assert_almost_equal(thresh_matrix.sum(), |
| 102 | + 94.183321784530804, decimal=7) |
| 103 | + ## test format_matrix call on format_matrix2 |
| 104 | + bool_matrix_sm = util.format_matrix(self.data_5d[0].squeeze(), |
| 105 | + 0,0, self.lookup[0].squeeze(), self.costs[100]) |
| 106 | + npt.assert_equal(bool_matrix.sum(), bool_matrix_sm.sum()) |
| 107 | + |
| 108 | + |
| 109 | + def test_threshold_adjacency_matrix(self): |
| 110 | + adj_matrix = self.data_5d[0,0,0].squeeze() |
| 111 | + mask, real_cost = util.threshold_adjacency_matrix(adj_matrix, 0) |
| 112 | + npt.assert_equal(mask.sum(), 0) |
| 113 | + npt.assert_equal(real_cost, 0) |
| 114 | + mask, real_cost = util.threshold_adjacency_matrix(adj_matrix, .9) |
| 115 | + npt.assert_equal(mask.sum(), 1840) |
| 116 | + npt.assert_equal(real_cost, 0.9) |
| 117 | + |
| 118 | + def test_find_true_cost(self): |
| 119 | + adj_matrix = self.data_5d[0,0,0].squeeze() |
| 120 | + mask, real_cost = util.threshold_adjacency_matrix(adj_matrix, 0.2) |
| 121 | + true_cost = util.find_true_cost(mask) |
| 122 | + npt.assert_equal(real_cost, true_cost) |
| 123 | + ## test on rounded array |
| 124 | + adj_matrix = self.data_5d[0,0,0].squeeze().round(decimals = 1) |
| 125 | + mask, expected_cost = util.threshold_adjacency_matrix(adj_matrix, 0.2) |
| 126 | + true_cost = util.find_true_cost(mask) |
| 127 | + ## the cost of the thresholded matrix will be less than expected |
| 128 | + npt.assert_equal(real_cost > true_cost, True) |
| 129 | + |
| 130 | + |
| 131 | + |
24 | 132 |
|
25 | 133 | def test_apply_cost():
|
26 | 134 | corr_mat = np.array([[0.0, 0.5, 0.3, 0.2, 0.1],
|
@@ -105,4 +213,20 @@ def test_no_empty_modules():
|
105 | 213 | b[2] = []
|
106 | 214 | util.assert_no_empty_modules(a)
|
107 | 215 | nt.assert_raises(ValueError, util.assert_no_empty_modules, b)
|
108 |
| - |
| 216 | + |
| 217 | +def test_rescale_arr(): |
| 218 | + array = np.arange(5) |
| 219 | + scaled = util.rescale_arr(array, 3, 6) |
| 220 | + npt.assert_equal(scaled.min(), 3) |
| 221 | + scaled = util.rescale_arr(array, -10, 10) |
| 222 | + npt.assert_equal(scaled.min(), -10) |
| 223 | + npt.assert_equal(scaled.max(), 10) |
| 224 | + |
| 225 | +def test_normalize(): |
| 226 | + array = np.arange(5) |
| 227 | + result = util.normalize(array) |
| 228 | + npt.assert_equal(result.min(), 0) |
| 229 | + npt.assert_equal(result.max(), 1) |
| 230 | + npt.assert_raises(ValueError, util.normalize, array, 'blueberry', (0,2)) |
| 231 | + |
| 232 | + |
0 commit comments