Skip to content

Commit 9e392c2

Browse files
committed
Write apply_cost and its test.
1 parent e6d5ec1 commit 9e392c2

File tree

2 files changed

+87
-0
lines changed

2 files changed

+87
-0
lines changed

brainx/tests/test_util.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,42 @@
1616
#-----------------------------------------------------------------------------
1717
# Functions
1818
#-----------------------------------------------------------------------------
19+
def test_apply_cost():
20+
corr_mat = np.array([[0.0, 0.5, 0.3, 0.2, 0.1],
21+
[0.5, 0.0, 0.4, 0.1, 0.2],
22+
[0.3, 0.4, 0.0, 0.7, 0.2],
23+
[0.2, 0.1, 0.7, 0.0, 0.4],
24+
[0.1, 0.2, 0.2, 0.4, 0.0]])
25+
# A five-node undirected graph has ten possible edges. Thus, the result
26+
# here should be a graph with five edges.
27+
possible_edges = 10
28+
cost = 0.5
29+
thresholded_corr_mat, threshold = util.apply_cost(corr_mat, cost,
30+
possible_edges)
31+
nt.assert_true(np.allclose(thresholded_corr_mat,
32+
np.array([[0.0, 0.0, 0.0, 0.0, 0.0],
33+
[0.5, 0.0, 0.0, 0.0, 0.0],
34+
[0.3, 0.4, 0.0, 0.0, 0.0],
35+
[0.0, 0.0, 0.7, 0.0, 0.0],
36+
[0.0, 0.0, 0.0, 0.4, 0.0]])))
37+
nt.assert_almost_equal(threshold, 0.3)
38+
# Check the case in which cost requires that one of several identical edges
39+
# be kept and the others removed. apply_cost should keep all of these
40+
# identical edges.
41+
#
42+
# To test this, I need to update only a value in the lower triangle. The
43+
# function zeroes out the upper triangle immediately.
44+
corr_mat[2, 0] = 0.2
45+
thresholded_corr_mat, threshold = util.apply_cost(corr_mat, cost,
46+
possible_edges)
47+
nt.assert_true(np.allclose(thresholded_corr_mat,
48+
np.array([[0.0, 0.0, 0.0, 0.0, 0.0],
49+
[0.5, 0.0, 0.0, 0.0, 0.0],
50+
[0.2, 0.4, 0.0, 0.0, 0.0],
51+
[0.2, 0.0, 0.7, 0.0, 0.0],
52+
[0.0, 0.2, 0.2, 0.4, 0.0]])))
53+
nt.assert_almost_equal(threshold, 0.2)
54+
1955

2056
def assert_graphs_equal(g,h):
2157
"""Trivial 'equality' check for graphs"""

brainx/util.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -565,6 +565,57 @@ def cost2thresh2(cost,sub,sc,c,lk,last,idc = [],costlist=[]):
565565
#print th
566566
return th
567567

568+
569+
def apply_cost(corr_mat, cost, tot_edges):
570+
"""Threshold corr_mat to achieve cost.
571+
572+
Return the thresholded matrix and the threshold value. In the
573+
thresholded matrix, the main diagonal and upper triangle are set to
574+
0, so information is held only in the lower triangle.
575+
576+
Parameters
577+
----------
578+
corr_mat: array_like
579+
Square matrix with ROI-to-ROI correlations.
580+
581+
cost: float
582+
Fraction of all possible undirected edges desired in the
583+
thresholded matrix.
584+
585+
tot_edges: integer
586+
The number of possible undirected edges in a graph with the
587+
number of nodes in corr_mat.
588+
589+
Returns
590+
-------
591+
thresholded_mat: array_like
592+
Square matrix with correlations below threshold set to 0,
593+
making the fraction of matrix elements that are non-zero equal
594+
to cost. In addition, the main diagonal and upper triangle are
595+
set to 0.
596+
597+
threshold: float
598+
Correlations below this value have been set to 0 in
599+
thresholded_corr_mat.
600+
601+
Notes
602+
-----
603+
If not all correlations are unique, it is possible that there will
604+
be no way to achieve the cost without, e.g., arbitrarily removing
605+
one of two identical correlations while keeping the other. Instead
606+
of making such an arbitrary choice, this function retains all
607+
identical correlations equal to or greater than threshold, even if
608+
this means cost is not exactly achieved.
609+
610+
"""
611+
thresholded_mat = np.tril(corr_mat, -1)
612+
n_nonzero = cost * tot_edges
613+
elements = thresholded_mat.ravel()
614+
threshold = elements[elements.argsort()[-n_nonzero]]
615+
thresholded_mat[thresholded_mat < threshold] = 0
616+
return thresholded_mat, threshold
617+
618+
568619
def network_ind(ntwk_type,n_nodes):
569620
"""Reads in a network type, number of nodes total and returns the indices of that network"""
570621

0 commit comments

Comments
 (0)