Skip to content

Commit 25e5345

Browse files
author
CindeeM
committed
NF: add tests and replace function with duplicate bevaiour with corresponding numpy functions
1 parent 313eeac commit 25e5345

File tree

2 files changed

+32
-42
lines changed

2 files changed

+32
-42
lines changed

brainx/tests/test_util.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,12 @@ def test_make_cost_thresh_lookup():
5353

5454
def test_cost_size():
5555
n_nodes = 5
56-
npt.assert_warns(DeprecationWarning, util.cost_size, n_nodes)
57-
56+
## NOTE DeprecationWarnings are ignored by default in 2.7
57+
npt.assert_warns(UserWarning, util.cost_size, n_nodes)
58+
def test_test_warning():
59+
npt.assert_warns(UserWarning, util.test_warning)
60+
61+
5862
class TestCost2Thresh(unittest.TestCase):
5963
def setUp(self):
6064
nnodes, nsub, nblocks, nsubblocks = 45, 20, 6, 2
@@ -78,18 +82,20 @@ def setUp(self):
7882

7983
def test_cost2thresh2(self):
8084
thr = util.cost2thresh2(self.costs[100], 0,0,0,self.lookup)
81-
npt.assert_almost_equal(thr, 0.24929222914887494, decimal=7)
85+
real_thr = self.lookup[0,0,0,0,100-1]
86+
npt.assert_almost_equal(thr, real_thr, decimal=7)
8287

8388
def test_cost2thresh(self):
8489
lookup = self.lookup[0].squeeze()
85-
thr = util.cost2thresh(self.costs[100],0,0,0,lookup)
86-
npt.assert_almost_equal(thr, 0.24929222914887494, decimal=7)
90+
thr = util.cost2thresh(self.costs[100],0,0,lookup)
91+
real_thr = lookup[0,0,0,100-1]# costs padded by zero
92+
npt.assert_almost_equal(thr, real_thr, decimal=7)
8793

8894
def test_format_matrix(self):
8995
thresh_matrix = util.format_matrix2(self.data_5d, 0,0,0,
9096
self.lookup, self.costs[100])
9197
print thresh_matrix.sum()
92-
npt.assert_equal(thresh_matrix.sum(), 22)
98+
npt.assert_equal(thresh_matrix.sum(), 100 -1)
9399

94100
def test_apply_cost():
95101
corr_mat = np.array([[0.0, 0.5, 0.3, 0.2, 0.1],

brainx/util.py

Lines changed: 20 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ def slice_data(data, sub, block, subcond=None):
3232
adjacency_matrix : numpy array
3333
symmetric numpy array (innode, nnode)
3434
"""
35-
if not subcond is None:
36-
return data[subcond, block, sub]
37-
return data[block, sub]
35+
if subcond is None:
36+
return data[block, sub]
37+
return data[subcond, block, sub]
3838

3939

4040
def format_matrix(data,s,b,lk,co,idc = [],costlist=[],nouptri = False):
@@ -102,7 +102,7 @@ def format_matrix2(data,s,sc,c,lk,co,idc = [],costlist=[],nouptri = False):
102102
cmat = np.triu(cmat,1)
103103

104104
# return boolean mask
105-
return cmat
105+
return ~(cmat == 0)
106106

107107
def threshold_adjacency_matrix(adj_matrix, cost):
108108
"""docstring for threshold_adjacency_matrix(adj_matrix, cost"""
@@ -160,12 +160,18 @@ def make_cost_thresh_lookup(adjacency_matrix):
160160
return lookup
161161

162162
def cost_size(nnodes):
163-
warnings.warn('deprecated: use make_cost_array', DeprecationWarning)
163+
"""create a list of actual costs, tot_edges, edges_short
164+
given a fixed number of nodes"""
165+
warnings.warn('this is no longer used: use make_cost_array')
166+
164167
tot_edges = 0.5 * nnodes * (nnodes - 1)
165168
costs = np.array(range(int(tot_edges) + 1), dtype=float) / tot_edges
166169
edges_short = tot_edges / 2
167170
return costs, tot_edges, edges_short
168171

172+
def test_warning():
173+
"""simple code to raise a warning"""
174+
warnings.warn('This is your warning')
169175

170176
def make_cost_array(n_nodes, cost=0.5):
171177
"""Make cost array of length cost * (the number of possible edges).
@@ -642,7 +648,7 @@ def cost2thresh2(cost, sub, axis1, axis0, lk,
642648
threshold : float
643649
threshold value for this cost"""
644650

645-
subject_lookup = slice_data(lk, sub, axis0, subcond=axis1)
651+
subject_lookup = slice_data(lk, sub, axis1, subcond=axis0)
646652
index = np.where(subject_lookup[1] == cost)
647653
threshold = subject_lookup[0][index]
648654

@@ -651,14 +657,14 @@ def cost2thresh2(cost, sub, axis1, axis0, lk,
651657
#if there are multiple thresholds, go down to the lower cost
652658
####Is this right?!!!####
653659
print('Subject %s has multiple thresholds at cost %s'%(sub, cost))
654-
print('index 1: %s, index 2: %s'%(c, sc))
660+
print('index 1: %s, index 2: %s'%(axis1, axis0))
655661
elif len(threshold) < 1:
656662
idc = idc-1
657663
newcost = costlist[idc]
658664
threshold = cost2thresh2(newcost, sub, axis1, axis0, lk,
659665
idc=idc, costlist = costlist)
660666
print(' '.join(['Subject %s does not have cost at %s'%(sub, cost),
661-
'index 1: %s, index 2: %s'%(c, sc),
667+
'index 1: %s, index 2: %s'%(axis1, axis0),
662668
'nearest cost %s being used'%(newcost)]))
663669
else:
664670
threshold = threshold[0]
@@ -811,24 +817,10 @@ def fill_diagonal(a,val):
811817
812818
See also
813819
--------
814-
- diag_indices: indices to access diagonals given shape information.
815-
- diag_indices_from: indices to access diagonals given an array.
820+
- numpy.diag_indices: indices to access diagonals given shape information.
821+
- numpy.diag_indices_from: indices to access diagonals given an array.
816822
"""
817-
if a.ndim < 2:
818-
raise ValueError("array must be at least 2-d")
819-
if a.ndim == 2:
820-
# Explicit, fast formula for the common case. For 2-d arrays, we
821-
# accept rectangular ones.
822-
step = a.shape[1] + 1
823-
else:
824-
# For more than d=2, the strided formula is only valid for arrays with
825-
# all dimensions equal, so we check first.
826-
if not np.alltrue(np.diff(a.shape)==0):
827-
raise ValueError("All dimensions of input must be of equal length")
828-
step = np.cumprod((1,)+a.shape[:-1]).sum()
829-
830-
# Write the value out into the diagonal.
831-
a.flat[::step] = val
823+
return np.fill_diagonal(a,val)
832824

833825

834826
def diag_indices(n,ndim=2):
@@ -881,11 +873,10 @@ def diag_indices(n,ndim=2):
881873
882874
See also
883875
--------
884-
- diag_indices_from: create the indices based on the shape of an existing
876+
- numpy.diag_indices_from: create the indices based on the shape of an existing
885877
array.
886878
"""
887-
idx = np.arange(n)
888-
return (idx,)*ndim
879+
return np.diag_indices(n, ndim=ndim)
889880

890881

891882
def diag_indices_from(arr):
@@ -897,16 +888,9 @@ def diag_indices_from(arr):
897888
----------
898889
arr : array, at least 2-d
899890
"""
900-
if not arr.ndim >= 2:
901-
raise ValueError("input array must be at least 2-d")
902-
# For more than d=2, the strided formula is only valid for arrays with
903-
# all dimensions equal, so we check first.
904-
if not np.alltrue(np.diff(a.shape)==0):
905-
raise ValueError("All dimensions of input must be of equal length")
891+
return np.diag_indices_from(arr)
906892

907-
return diag_indices(a.shape[0],a.ndim)
908893

909-
910894
def mask_indices(n,mask_func,k=0):
911895
"""Return the indices to access (n,n) arrays, given a masking function.
912896

0 commit comments

Comments
 (0)