Skip to content

Commit 9b514f6

Browse files
committed
NF: start cleanup of redundant code in util
1 parent 090de86 commit 9b514f6

File tree

2 files changed

+180
-80
lines changed

2 files changed

+180
-80
lines changed

brainx/tests/test_util.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,17 @@ def test_slice_data():
2828
npt.assert_raises(IndexError, util.slice_data, data_5d, subjects, blocks)
2929

3030

31+
def test_all_positive():
32+
jnk = np.random.random(40)
33+
npt.assert_equal(util.all_positive(jnk), True)
34+
# zeros counted as positive
35+
jnk[0] = 0
36+
npt.assert_equal(util.all_positive(jnk), True)
37+
# find real negative
38+
jnk = jnk - 0.5
39+
npt.assert_equal(util.all_positive(jnk), False)
40+
41+
3142
def test_cost_size():
3243
n_nodes = 5
3344
npt.assert_warns(DeprecationWarning, util.cost_size, n_nodes)
@@ -116,4 +127,20 @@ def test_no_empty_modules():
116127
b[2] = []
117128
util.assert_no_empty_modules(a)
118129
nt.assert_raises(ValueError, util.assert_no_empty_modules, b)
119-
130+
131+
def test_rescale_arr():
132+
array = np.arange(5)
133+
scaled = util.rescale_arr(array, 3, 6)
134+
npt.assert_equal(scaled.min(), 3)
135+
scaled = util.rescale_arr(array, -10, 10)
136+
npt.assert_equal(scaled.min(), -10)
137+
npt.assert_equal(scaled.max(), 10)
138+
139+
def test_normalize():
140+
array = np.arange(5)
141+
result = util.normalize(array)
142+
npt.assert_equal(result.min(), 0)
143+
npt.assert_equal(result.max(), 1)
144+
npt.assert_raises(ValueError, util.normalize, array, 'blueberry', (0,2))
145+
146+

brainx/util.py

Lines changed: 152 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,8 @@
44
#-----------------------------------------------------------------------------
55
# Imports
66
#-----------------------------------------------------------------------------
7-
from __future__ import print_function
87

98
import warnings
10-
119
import numpy as np
1210
import networkx as nx
1311

@@ -40,15 +38,22 @@ def slice_data(data, sub, block, subcond=None):
4038

4139

4240
def format_matrix(data,s,b,lk,co,idc = [],costlist=[],nouptri = False):
43-
""" Function which formats matrix for a particular subject and particular block (thresholds, upper-tris it) so that we can make a graph object out of it
41+
""" Function which thresholds the adjacency matrix for a particular
42+
subject and particular block, using lookuptable to find thresholds,
43+
cost value to find threshold, costlist
44+
(thresholds, upper-tris it) so that we can use it with simulated annealing
4445
45-
Parameters:
46+
Parameters
4647
-----------
47-
data = full data array
48+
data : full data array 4D (block, sub, node, node)
4849
s = subject
4950
b = block
5051
lk = lookup table for study
5152
co = cost value to threshold at
53+
idc = index of ideal cost
54+
costlist = list (size num_edges) with ordered values used to find
55+
threshold to control number of edges
56+
nouptri = if False only keeps upper tri, True yields symmetric matrix
5257
"""
5358

5459
cmat = data[b,s]
@@ -65,15 +70,28 @@ def format_matrix(data,s,b,lk,co,idc = [],costlist=[],nouptri = False):
6570
def format_matrix2(data,s,sc,c,lk,co,idc = [],costlist=[],nouptri = False):
6671
""" Function which formats matrix for a particular subject and particular block (thresholds, upper-tris it) so that we can make a graph object out of it
6772
68-
Parameters:
69-
-----------
70-
data = full data array
71-
s = subject
72-
b = block
73-
lk = lookup table for study
74-
co = cost value to threshold at
73+
Parameters
74+
----------
75+
data : numpy array
76+
full data array 5D (subcondition, condition, subject, node, node)
77+
s : int
78+
index of subject
79+
sc : int
80+
index of sub condition
81+
c : int
82+
index of condition
83+
lk : numpy array
84+
lookup table for thresholds at each possible cost
85+
co : float
86+
cost value to threshold at
87+
idc : float
88+
ideal cost
89+
costlist : list
90+
list of possible costs
91+
nouptri : bool
92+
False zeros out diag and below, True returns symmetric matrix
7593
"""
76-
94+
7795
cmat = data[sc,c,s]
7896
th = cost2thresh2(co,s,sc,c,lk,[],idc,costlist) #get the right threshold
7997

@@ -85,6 +103,46 @@ def format_matrix2(data,s,sc,c,lk,co,idc = [],costlist=[],nouptri = False):
85103

86104
return cmat
87105

106+
def threshold_adjacency_matrix(adj_matrix, cost):
107+
"""docstring for threshold_adjacency_matrix(adj_matrix, cost"""
108+
109+
pass
110+
111+
112+
def all_positive(adjacency_matrix):
113+
""" checks if edge value sin adjacency matrix are all positive
114+
or positive and negative
115+
Returns
116+
-------
117+
all_positive : bool
118+
True if all values are >=0
119+
False if values < 0
120+
"""
121+
# add 1 so 0-> 1(True) , -1 -> 0 False
122+
signs = set( np.sign(adjacency_matrix) + 1 )
123+
return bool(sorted(signs)[0])
124+
125+
126+
def make_cost_thresh_lookup(adjacency_matrix):
127+
"""takes upper triangular (offset 1, no diagonal) of summetric
128+
adjacency matrix, sorts (lowest -> highest)
129+
Returns
130+
-------
131+
lookup : numpy array
132+
3 X number_of_edges, numpy array
133+
row 0 is sorted thresholds
134+
row 1 is cost at each threshold
135+
row 2 is costs rounded to one decimal point
136+
"""
137+
138+
ind = np.triu_indices_from(adjacency_matrix, k = 1)
139+
edges = adjacency_matrix[ind]
140+
nedges = edges.shape[0]
141+
lookup = np.zeros((3, nedges))
142+
lookup[0,:] = sorted(edges)
143+
lookup[1,:] = np.arange(nedges) / float(nedges)
144+
lookup[2,:] = np.round(lookup[1,:], decimals = 1)
145+
return lookup
88146

89147
def cost_size(nnodes):
90148
warnings.warn('deprecated: use make_cost_array', DeprecationWarning)
@@ -125,9 +183,17 @@ def make_cost_array(n_nodes, cost=0.5):
125183
costs = np.array(range(int(tot_edges * cost)), dtype=float) / tot_edges
126184
return costs, tot_edges
127185

186+
def metrics_to_pandas():
187+
"""docstring for metrics_to_pandas"""
188+
pass
128189

129190
def store_metrics(b, s, co, metd, arr):
130-
"""Store a set of metrics into a structured array"""
191+
"""Store a set of metrics into a structured array
192+
b = block
193+
s = subject
194+
co = cost? float
195+
metd = dict of metrics
196+
arr : array?"""
131197

132198
if arr.ndim == 3:
133199
idx = b,s,co
@@ -148,6 +214,8 @@ def regular_lattice(n,k):
148214
149215
This type of graph is the starting point for the Watts-Strogatz small-world
150216
model, where connections are then rewired in a second phase.
217+
218+
XXX TODO Use as comparison, check networkx to see if its update worth redundancy
151219
"""
152220
# Code simplified from the networkx.watts_strogatz_graph one
153221
G = nx.Graph()
@@ -282,18 +350,23 @@ def normalize(arr,mode='direct',folding_edges=None):
282350
"""Normalize an array to [0,1] range.
283351
284352
By default, this simply rescales the input array to [0,1]. But it has a
285-
special 'folding' mode that allows for the normalization of an array with
286-
negative and positive values by mapping the negative values to their
287-
flipped sign
353+
special 'folding' mode that allong absolute value of all values, in addition
354+
values between the folding_edges (low_cutoff, high_cutoff) will be zeroed.
288355
289356
Parameters
290357
----------
291358
arr : 1d array
292-
359+
assumes dtype == float, if int32, will raise ValueError
360+
293361
mode : string, one of ['direct','folding']
362+
if direct rescale all values (pos and neg) between 0,1
363+
if folding, zeros out values between folding_values (inclusive)
364+
and normalizes absolute value of remaining values
294365
295366
folding_edges : (float,float)
296-
Only needed for folding mode, ignored in 'direct' mode.
367+
(low_cutoff, high_cutoff) lower and upper values to zero out
368+
(values are inclusive)
369+
Only needed for folding mode, ignored in 'direct' mode.
297370
298371
Examples
299372
--------
@@ -315,37 +388,23 @@ def normalize(arr,mode='direct',folding_edges=None):
315388
>>> c
316389
array([-0.8 , -0.6333, -0.4667, -0.3 , 0.3 , 0.4333, 0.5667, 0.7 ])
317390
>>> normalize(c,'folding',[-0.3,0.3])
318-
array([ 1. , 0.6667, 0.3333, 0. , 0. , 0.2667, 0.5333, 0.8 ])
391+
array([ 1. , 0.7917, 0.5833, 0. , 0. , 0.5417, 0.7083, 0.875 ])
319392
"""
320393
if mode == 'direct':
321394
return rescale_arr(arr,0,1)
322-
else:
323-
fa, fb = folding_edges
395+
elif mode == 'folding':
396+
# cast folding_edges to floats in case inputs are ints
397+
low_cutoff, high_cutoff = [float(x) for x in folding_edges]
324398
amin, amax = arr.min(), arr.max()
325-
ra,rb = float(fa-amin),float(amax-fb) # in case inputs are ints
326-
if ra<0 or rb<0:
399+
low_diff, high_diff = low_cutoff-amin, amax-high_cutoff
400+
if low_diff < 0 or high_diff < 0:
327401
raise ValueError("folding edges must be within array range")
328-
greater = arr>= fb
329-
upper_idx = greater.nonzero()
330-
lower_idx = (~greater).nonzero()
331-
# Two folding scenarios, we map the thresholds to zero but the upper
332-
# ranges must retain comparability.
333-
if ra > rb:
334-
lower = 1.0 - rescale_arr(arr[lower_idx],0,1.0)
335-
upper = rescale_arr(arr[upper_idx],0,float(rb)/ra)
336-
else:
337-
upper = rescale_arr(arr[upper_idx],0,1)
338-
# The lower range is trickier: we need to rescale it and then flip
339-
# it, so the edge goes to 0.
340-
resc_a = float(ra)/rb
341-
lower = rescale_arr(arr[lower_idx],0,resc_a)
342-
lower = resc_a - lower
343-
# Now, make output array
344-
out = np.empty_like(arr)
345-
out[lower_idx] = lower
346-
out[upper_idx] = upper
347-
return out
348-
402+
mask = np.logical_and( arr >= low_cutoff, arr <= high_cutoff)
403+
out = arr.copy()
404+
out[mask] = 0
405+
return rescale_arr(np.abs(out), 0, 1)
406+
else:
407+
raise ValueError('Unknown mode %s: valid options("direct", "folding")')
349408

350409
def mat2graph(cmat,threshold=0.0,threshold2=None):
351410
"""Make a weighted graph object out of an adjacency matrix.
@@ -559,45 +618,59 @@ def cost2thresh(cost, sub, bl, lk, idc=[], costlist=[]):
559618
return th
560619

561620

562-
def cost2thresh2(cost,sub,sc,c,lk,last,idc = [],costlist=[]):
563-
"""A definition for loading the lookup table and finding the threshold associated with a particular cost for a particular subject in a particular block
621+
def cost2thresh2(cost, sub, sc, c, lk, last = None, idc = [], costlist=[]):
622+
"""A definition for loading the lookup table and finding the threshold
623+
associated with a particular cost for a particular subject in a
624+
particular block of data
564625
565-
inputs:
566-
cost: cost value for which we need the associated threshold
567-
sub: subject number
568-
bl: block number
569-
lk: lookup table (block x subject x cost
570-
last: last threshold value
571-
572-
output:
573-
th: threshold value for this cost"""
574-
575-
#print cost,sub,bl
626+
Inputs
627+
------
628+
cost : float
629+
cost value for which we need the associated threshold
630+
sub : int
631+
(axis -2) subject number
632+
axis1 : int
633+
axis 1 into lookup (eg block number or condition)
634+
axis0 : int
635+
axis 0 into lookup (eg subcondition)
636+
lk : numpy array
637+
lookup table (axis0 x axis1 x subject x 2 )
638+
last : None
639+
NOT USED last threshold value
640+
idc : int or empty list
641+
Index in costlist corresponding to cost currently being
642+
processed. By default, idc is an empty list.
643+
costlist : array-like
644+
List of costs that are being queried with the current function
645+
in order.
576646
577-
ind=np.where(lk[sc,c,sub][1]==cost)
578-
th=lk[sc,c,sub][0][ind]
647+
Returns
648+
-------
649+
threshold : float
650+
threshold value for this cost"""
651+
652+
subject_lookup = slice_data(lk, sub, c, subcond=sc)
653+
index = np.where(subject_lookup[1] == cost)
654+
threshold = subject_lookup[0][ind]
579655

580-
if len(th)>1:
581-
th=th[0] #if there are multiple thresholds, go down to the lower cost ####Is this right?!!!####
582-
print('multiple thresh')
583-
elif len(th)<1:
584-
done = 1
585-
while done:
586-
idc = idc-1
587-
newcost = costlist[idc]
588-
print(idc,newcost)
589-
ind=np.where(lk[bl][sub][1]==newcost)
590-
th=lk[bl][sub][0][ind]
591-
if len(th) > 1:
592-
th = th[0]
593-
done = 0
594-
#th=last #if there is no associated thresh value because of repeats, just use the previous one
595-
print('use previous thresh')
656+
if len(threshold) > 1:
657+
threshold = threshold[0]
658+
#if there are multiple thresholds, go down to the lower cost
659+
####Is this right?!!!####
660+
print('Subject %s has multiple thresholds at cost %s'%(sub, cost))
661+
print('index 1: %s, index 2: %s'%(c, sc))
662+
elif len(threshold) < 1:
663+
idc = idc-1
664+
newcost = costlist[idc]
665+
threshold = cost2thresh2(newcost, sub, sc, c, lk,
666+
idc=idc, costlist = costlist)
667+
print(' '.join(['Subject %s does not have cost at %s'%(sub, cost),
668+
'index 1: %s, index 2: %s'%(c, sc),
669+
'nearest cost %s being used'%(newcost)]))
596670
else:
597-
th=th[0]
671+
threshold = threshold[0]
598672

599-
#print th
600-
return th
673+
return threshold
601674

602675

603676
def apply_cost(corr_mat, cost, tot_edges):

0 commit comments

Comments
 (0)