11import numpy as np
22import pandas as pd
33import scanpy as sc
4+ from scipy .stats import fisher_exact
45from sklearn .metrics import pairwise_distances
56from sklearn .neighbors import NearestNeighbors
7+ from statsmodels .stats .multitest import fdrcorrection
68
9+ from polygraph .classifier import groupwise_svm
710from polygraph .stats import groupwise_fishers , groupwise_mann_whitney , kruskal_dunn
811
912
def embedding_pca(ad, **kwargs):
    """
    Run PCA on the sequence embeddings and report the variance captured.

    Args:
        ad (anndata.AnnData): Anndata object containing sequence embeddings
        **kwargs: Additional arguments forwarded unchanged to scanpy.pp.pca

    Returns:
        ad (anndata.AnnData): The object that was passed in, after
            scanpy.pp.pca has been called on it
    """
    sc.pp.pca(ad, **kwargs)

    # The per-component variance ratios are read from .uns["pca"]; only
    # their sum is reported.
    variance_ratio = ad.uns["pca"]["variance_ratio"]
    total_explained = np.sum(variance_ratio)
    print(
        "Fraction of total variance explained by PCA components: ",
        total_explained,
    )
    return ad
@@ -89,12 +89,10 @@ def differential_analysis(ad, reference_group, group_col="Group"):
8989 return ad
9090
9191
92- def one_nn_stats (ad , reference_group , group_col = "Group" , use_pca = False ):
def groupwise_1nn(ad, reference_group, group_col="Group", use_pca=False):
    """
    For each sequence, find its nearest neighbor among its own group or
    the reference group based on the sequence embeddings.

    Every non-reference group is processed separately: the sequences of that
    group are pooled with the reference sequences, each pooled sequence's
    nearest neighbor (other than itself) is found, and a Fisher's exact test
    is run on the 2x2 table of (own group) x (nearest neighbor's group).

    Args:
        ad (anndata.AnnData): Anndata object containing sequence embeddings of
            shape (n_seqs x n_vars)
        reference_group (str): ID of group to use as reference
        group_col (str): Name of column in .obs containing group ID
        use_pca (bool): Whether to use PCA distances (.obsm["X_pca"]) instead
            of .X

    Returns:
        ad (anndata.AnnData): Modified anndata object. For each non-reference
            group, .obs gains the columns "{group}_one_nn_idx",
            "{group}_one_nn_dist" and "{group}_one_nn_group" (filled only for
            sequences in that group or the reference group).
            .uns["1NN_ref_prop_test"] holds one row per non-reference group
            with the fraction of the group's sequences whose nearest neighbor
            is a reference sequence ("group_prop"), the same fraction for the
            reference sequences ("ref_prop"), the Fisher's exact test p-value
            ("pval") and its FDR-adjusted value ("padj").
            .uns["1NN_group_probs"] holds a copy of the two proportion columns.

    Raises:
        KeyError: If no sequence in .obs[group_col] belongs to reference_group.
    """
    # Boolean mask over .obs marking the reference sequences
    in_ref = ad.obs[group_col] == reference_group
    if not in_ref.any():
        # Fail early with a readable message instead of producing NaN
        # proportions further down.
        raise KeyError(
            f"Reference group {reference_group!r} not found in .obs[{group_col!r}]"
        )

    # List groups, then keep only the non-reference ones
    groups = ad.obs[group_col].unique()
    nonreference_groups = groups[groups != reference_group]

    # One result row per non-reference group; collected in a list so the
    # result table is built once rather than re-concatenated every iteration.
    rows = []

    for group in nonreference_groups:
        # Pool this group's sequences with the reference sequences
        in_group = ad.obs[group_col] == group
        in_group_or_ref = (in_ref | in_group).tolist()
        in_group_or_ref_indices = ad.obs.index[in_group_or_ref].tolist()
        if use_pca:
            X = ad.obsm["X_pca"][in_group_or_ref, :]
        else:
            X = ad.X[in_group_or_ref, :]

        # Calculate nearest neighbors. Two neighbors are requested because
        # the query points are the fitted points themselves, so column 0 is
        # expected to be the point itself and column 1 its true neighbor.
        # NOTE(review): with exact duplicate embeddings column 0 is not
        # guaranteed to be the query point - confirm this is acceptable.
        nbrs = NearestNeighbors(n_neighbors=2, algorithm="ball_tree").fit(X)
        distances, indices = nbrs.kneighbors(X)

        # Translate positions within the pooled subset into .obs index labels
        nn_names = np.array(in_group_or_ref_indices)[indices[:, 1]]

        # Record the neighbor, its distance and the group it belongs to
        nn_group_col = f"{group}_one_nn_group"
        ad.obs.loc[in_group_or_ref, f"{group}_one_nn_idx"] = nn_names
        ad.obs.loc[in_group_or_ref, f"{group}_one_nn_dist"] = distances[:, 1]
        ad.obs.loc[in_group_or_ref, nn_group_col] = ad.obs.loc[
            nn_names, group_col
        ].tolist()

        # Make contingency table: rows = own group, columns = neighbor's
        # group, both ordered (group, reference). reindex (rather than .loc)
        # inserts a zero row/column when a label never occurs, e.g. when no
        # sequence has a reference sequence as its nearest neighbor.
        labels = [group, reference_group]
        cont = (
            ad.obs[[group_col, nn_group_col]]
            .value_counts()
            .unstack()
            .fillna(0)
            .reindex(index=labels, columns=labels, fill_value=0)
            .values
        )

        # Fraction of sequences whose nearest neighbor is a reference
        # sequence, for the group and for the reference respectively
        group_prop = cont[0, 1] / cont[0, :].sum()
        ref_prop = cont[1, 1] / cont[1, :].sum()

        rows.append(
            {
                group_col: group,
                "group_prop": group_prop,
                "ref_prop": ref_prop,
                "pval": fisher_exact(cont, alternative="two-sided").pvalue,
            }
        )

    # Save results. Explicit columns keep the table well-formed even when
    # there are no non-reference groups.
    res = pd.DataFrame(
        rows, columns=[group_col, "group_prop", "ref_prop", "pval"]
    ).set_index(group_col)
    # Skip the FDR call on an empty table; there is nothing to correct.
    res["padj"] = fdrcorrection(res["pval"])[1] if len(res) > 0 else []
    ad.uns["1NN_group_probs"] = res[["group_prop", "ref_prop"]].copy()
    ad.uns["1NN_ref_prop_test"] = res
    return ad
176+
177+
178+ def joint_1nn (ad , reference_group , group_col = "Group" , use_pca = False ):
179+ """
180+ Find the group ID of each sequence's 1-nearest neighbor statistics based on the
181+ sequence embeddings. Compare all groups to all other groups.
182+
183+ Args:
184+ ad (anndata.AnnData): Anndata object containing sequence embeddings of
185+ shape (n_seqs x n_vars)
186+ reference_group (str): ID of group to use as reference
187+ group_col (str): Name of column in .obs containing group ID
188+ use_pca (bool): Whether to use PCA distances
112189
190+ Returns:
191+ ad (anndata.AnnData): Modified anndata object containing index, distance and
192+ group ID of each sequence's nearest neighbor in .obs, as well as a
193+ probability table of nearest neighbor groups for each group
194+ in .uns
195+ """
113196 # Get nearest neighbor for each sequence
114197 if use_pca :
115198 nbrs = NearestNeighbors (n_neighbors = 2 , algorithm = "ball_tree" ).fit (
@@ -124,7 +207,7 @@ def one_nn_stats(ad, reference_group, group_col="Group", use_pca=False):
124207 ad .obs .loc [:, "one_nn_idx" ] = indices [:, 1 ]
125208 ad .obs .loc [:, "one_nn_dist" ] = distances [:, 1 ]
126209 ad .obs .loc [:, "one_nn_group" ] = (
127- ad .obs [group_col ].iloc [ad . obs [ "one_nn_idx" ].tolist ()].tolist ()
210+ ad .obs [group_col ].iloc [indices [:, 1 ].tolist ()].tolist ()
128211 )
129212
130213 # Normalized count table of nearest neighbor groups
@@ -245,19 +328,27 @@ def dist_to_reference(ad, reference_group, group_col="Group", use_pca=False):
245328
246329 if group != reference_group :
247330 # Add to .obs
331+ ad .obs .loc [in_group , "Closest reference" ] = (
332+ ad [in_ref , :].obs_names [distances .argmin (1 )].tolist ()
333+ )
248334 ad .obs .loc [in_group , "Distance to closest reference" ] = distances .min (1 )
249335 else :
250336 # If the group is the reference group, the nearest neighbor will be
251337 # the sequence itself. So we find the next nearest neighbor.
252- dlist = []
338+ indices = []
339+ dists = []
253340 for i , row in enumerate (distances ):
254341 # Drop the nearest neighbor
255- row = np .delete ( row , i )
342+ row [ i ] = np .Inf
256343 # Take the new minimum
257- dlist .append (row .min ())
344+ dists .append (row .min ())
345+ indices .append (row .argmin ())
258346
259347 # Add to .obs
260- ad .obs .loc [in_group , "Distance to closest reference" ] = dlist
348+ ad .obs .loc [in_group , "Closest reference" ] = (
349+ ad [in_ref , :].obs_names [indices ].tolist ()
350+ )
351+ ad .obs .loc [in_group , "Distance to closest reference" ] = dists
261352
262353 # Mann-whitney or Kruskal-wallis test
263354 if len (groups ) == 2 :
@@ -276,10 +367,16 @@ def dist_to_reference(ad, reference_group, group_col="Group", use_pca=False):
276367
277368
278369def embedding_analysis (
279- matrix , seqs , reference_group , group_col = "Group" , n_neighbors = 15 , use_pca = False
370+ matrix ,
371+ seqs ,
372+ reference_group ,
373+ group_col = "Group" ,
374+ n_neighbors = 15 ,
375+ use_pca = False ,
376+ max_iter = 1000 ,
280377):
281378 """
282- A single function to calculate all embedding distance-based metrics.
379+ A single function to calculate several embedding distance-based metrics.
283380
284381 Args:
285382 matrix (np.array, pd.DataFrame): A probability table or embedding matrix
@@ -307,12 +404,15 @@ def embedding_analysis(
307404 (.obs['Distance to closest reference'])
308405 Test for between group difference in Distance to the closest reference
309406 sequence (.uns["ref_dist_test"])
407+ Performance metrics for SVMs trained to classify each group from the
408+ reference (.uns["svm_performance"])
409+ Predictions by SVMs trained to classify each group from the reference
410+ (.obs["{group}_SVM_predicted_reference"])
310411 """
311412 from anndata import AnnData
312413
313414 print ("Creating AnnData object" )
314- ad = AnnData (matrix )
315- ad .obs = seqs
415+ ad = AnnData (matrix , obs = seqs )
316416 ad = ad [:, ad .X .sum (0 ) > 0 ]
317417
318418 print ("PCA" )
@@ -325,11 +425,22 @@ def embedding_analysis(
325425 ad = differential_analysis (ad , reference_group , group_col )
326426
327427 print ("1-NN statistics" )
328- ad = one_nn_stats (ad , reference_group , group_col , use_pca = use_pca )
428+ ad = groupwise_1nn (ad , reference_group , group_col , use_pca = use_pca )
329429
330430 print ("Within-group KNN diversity" )
331431 ad = within_group_knn_dist (ad , n_neighbors , group_col , use_pca = use_pca )
332432
333433 print ("Euclidean distance to nearest reference" )
334434 ad = dist_to_reference (ad , reference_group , group_col , use_pca = use_pca )
435+
436+ print ("Train groupwise classifiers" )
437+ ad = groupwise_svm (
438+ ad ,
439+ reference_group ,
440+ group_col = group_col ,
441+ cv = 5 ,
442+ is_kernel = False ,
443+ max_iter = max_iter ,
444+ )
445+
335446 return ad
0 commit comments