@@ -32,7 +32,7 @@ def slice_data(data, sub, block, subcond=None):
32
32
adjacency_matrix : numpy array
33
33
symmetric numpy array (innode, nnode)
34
34
"""
35
- if subcond :
35
+ if not subcond is None :
36
36
return data [subcond , block , sub ]
37
37
return data [block , sub ]
38
38
@@ -91,16 +91,17 @@ def format_matrix2(data,s,sc,c,lk,co,idc = [],costlist=[],nouptri = False):
91
91
nouptri : bool
92
92
False zeros out diag and below, True returns symmetric matrix
93
93
"""
94
-
95
- cmat = data [sc ,c ,s ]
94
+ cmat = slice_data ( data , s , c , sc )
95
+ # cmat = data[sc,c,s]
96
96
th = cost2thresh2 (co ,s ,sc ,c ,lk ,[],idc ,costlist ) #get the right threshold
97
97
98
98
#cmat = replace_diag(cmat) #replace diagonals with zero
99
99
cmat = thresholded_arr (cmat ,th ,fill_val = 0 )
100
100
101
101
if not nouptri :
102
102
cmat = np .triu (cmat ,1 )
103
-
103
+
104
+ # return boolean mask
104
105
return cmat
105
106
106
107
def threshold_adjacency_matrix (adj_matrix , cost ):
@@ -606,32 +607,8 @@ def cost2thresh(cost, sub, bl, lk, idc=[], costlist=[]):
606
607
be registered.
607
608
608
609
"""
609
- return cost2thresh2 (cost , sub , bl , axis0 = None , lk = lk , last = None , idc = idc ,costlist = costlist )
610
- # For this subject and block, find the indices corresponding to this cost.
611
- # Note there may be more than one such index. There will be no such
612
- # indices if cost is not a value in the array.
613
- ind = np .where (lk [bl ][sub ][1 ] == cost )
614
- # The possibility of multiple (or no) indices implies multiple (or no)
615
- # thresholds may be acquired here.
616
- th = lk [bl ][sub ][0 ][ind ]
617
- n_thresholds = len (th )
618
- if n_thresholds > 1 :
619
- th = th [0 ]
620
- print ('' .join (['Subject %s has multiple thresholds in block %d ' ,
621
- 'corresponding to a cost of %f. The smallest is being' ,
622
- ' used.' ]) % (sub , bl , cost ))
623
- elif n_thresholds < 1 :
624
- idc = idc - 1
625
- newcost = costlist [idc ]
626
- th = cost2thresh (newcost , sub , bl , lk , idc , costlist )
627
- print ('' .join (['Subject %s does not have a threshold in block %d ' ,
628
- 'corresponding to a cost of %f. The threshold ' ,
629
- 'matching the nearest previous cost in costlist is ' ,
630
- 'being used.' ]) % (sub , block , cost ))
631
- else :
632
- th = th [0 ]
633
- return th
634
-
610
+ return cost2thresh2 (cost , sub , bl , axis0 = None ,
611
+ lk = lk , last = None , idc = idc ,costlist = costlist )
635
612
636
613
def cost2thresh2 (cost , sub , axis1 , axis0 , lk ,
637
614
last = None , idc = [], costlist = []):
@@ -667,7 +644,7 @@ def cost2thresh2(cost, sub, axis1, axis0, lk,
667
644
668
645
subject_lookup = slice_data (lk , sub , axis0 , subcond = axis1 )
669
646
index = np .where (subject_lookup [1 ] == cost )
670
- threshold = subject_lookup [0 ][ind ]
647
+ threshold = subject_lookup [0 ][index ]
671
648
672
649
if len (threshold ) > 1 :
673
650
threshold = threshold [0 ]
0 commit comments