1+ '''
2+ agg_wt_avg.py
3+
4+ Aggregate a matrix of replicate profiles into a single signature using
5+ a weighted average based on the correlation between replicates. That is, if
6+ one replicate is less correlated with the other replicates, its values will
7+ not be weighted as highly in the aggregated signature.
8+
9+ Equivalent to the 'modz' method in mortar.
10+ '''
11+
12+ import numpy as np
13+
14+ rounding_precision = 4
15+
16+
17+ def get_upper_triangle (correlation_matrix ):
18+ ''' Extract upper triangle from a square matrix. Negative values are
19+ set to 0.
20+
21+ Args:
22+ correlation_matrix (pandas df): Correlations between all replicates
23+
24+ Returns:
25+ upper_tri_df (pandas df): Upper triangle extracted from
26+ correlation_matrix; rid is the row index, cid is the column index,
27+ corr is the extracted correlation value
28+ '''
29+ upper_triangle = correlation_matrix .where (np .triu (np .ones (correlation_matrix .shape ), k = 1 ).astype (np .bool ))
30+
31+ # convert matrix into long form description
32+ upper_tri_df = upper_triangle .stack ().reset_index (level = 1 )
33+ upper_tri_df .columns = ['rid' , 'corr' ]
34+
35+ # Index at this point is cid, it now becomes a column
36+ upper_tri_df .reset_index (level = 0 , inplace = True )
37+
38+ # Get rid of negative values
39+ upper_tri_df ['corr' ] = upper_tri_df ['corr' ].clip (lower = 0 )
40+
41+ return upper_tri_df .round (rounding_precision )
42+
43+
44+ def calculate_weights (correlation_matrix , min_wt ):
45+ ''' Calculate a weight for each profile based on its correlation to other
46+ replicates. Negative correlations are clipped to 0, and weights are clipped
47+ to be min_wt at the least.
48+
49+ Args:
50+ correlation_matrix (pandas df): Correlations between all replicates
51+ min_wt (float): Minimum raw weight when calculating weighted average
52+
53+ Returns:
54+ raw weights (pandas series): Mean correlation to other replicates
55+ weights (pandas series): raw_weights normalized such that they add to 1
56+ '''
57+ # fill diagonal of correlation_matrix with np.nan
58+ np .fill_diagonal (correlation_matrix .values , np .nan )
59+
60+ # remove negative values
61+ correlation_matrix = correlation_matrix .clip (lower = 0 )
62+
63+ # get average correlation for each profile (will ignore NaN)
64+ raw_weights = correlation_matrix .mean (axis = 1 )
65+
66+ # threshold weights
67+ raw_weights = raw_weights .clip (lower = min_wt )
68+
69+ # normalize raw_weights so that they add to 1
70+ weights = raw_weights / sum (raw_weights )
71+
72+ return raw_weights .round (rounding_precision ), weights .round (rounding_precision )
73+
74+
75+ def agg_wt_avg (mat , min_wt = 0.01 , corr_metric = 'spearman' ):
76+ ''' Aggregate a set of replicate profiles into a single signature using
77+ a weighted average.
78+
79+ Args:
80+ mat (pandas df): a matrix of replicate profiles, where the columns are
81+ samples and the rows are features; columns correspond to the
82+ replicates of a single perturbagen
83+ min_wt (float): Minimum raw weight when calculating weighted average
84+ corr_metric (string): Spearman or Pearson; the correlation method
85+
86+ Returns:
87+ out_sig (pandas series): weighted average values
88+ upper_tri_df (pandas df): the correlations between each profile that went into the signature
89+ raw weights (pandas series): weights before normalization
90+ weights (pandas series): weights after normalization
91+ '''
92+ assert mat .shape [1 ] > 0 , "mat is empty! mat: {}" .format (mat )
93+
94+ if mat .shape [1 ] == 1 :
95+
96+ out_sig = mat
97+ upper_tri_df = None
98+ raw_weights = None
99+ weights = None
100+
101+ else :
102+
103+ assert corr_metric in ["spearman" , "pearson" ]
104+
105+ # Make correlation matrix column wise
106+ corr_mat = mat .corr (method = corr_metric )
107+
108+ # Save the values in the upper triangle
109+ upper_tri_df = get_upper_triangle (corr_mat )
110+
111+ # Calculate weight per replicate
112+ raw_weights , weights = calculate_weights (corr_mat , min_wt )
113+
114+ # Apply weights to values
115+ weighted_values = mat * weights
116+ out_sig = weighted_values .sum (axis = 1 )
117+
118+ return out_sig , upper_tri_df , raw_weights , weights
0 commit comments