|
1 | 1 | import numpy as np |
2 | 2 | import pandas as pd |
3 | 3 | import scanpy as sc |
| 4 | +from hotelling.stats import hotelling_t2 |
4 | 5 | from scipy.stats import fisher_exact |
5 | 6 | from sklearn.metrics import pairwise_distances |
6 | 7 | from sklearn.neighbors import NearestNeighbors |
@@ -89,7 +90,7 @@ def differential_analysis(ad, reference_group, group_col="Group"): |
89 | 90 | return ad |
90 | 91 |
|
91 | 92 |
|
92 | | -def groupwise_1nn(ad, reference_group, group_col="Group", use_pca=False): |
| 93 | +def reference_1nn(ad, reference_group, group_col="Group", use_pca=False): |
93 | 94 | """ |
94 | 95 | For each sequence, find its nearest neighbor among its own group or |
95 | 96 | the reference group based on the sequence embeddings. |
@@ -175,7 +176,7 @@ def groupwise_1nn(ad, reference_group, group_col="Group", use_pca=False): |
175 | 176 | return ad |
176 | 177 |
|
177 | 178 |
|
178 | | -def joint_1nn(ad, reference_group, group_col="Group", use_pca=False): |
| 179 | +def all_1nn(ad, reference_group, group_col="Group", use_pca=False): |
179 | 180 | """ |
180 | 181 | Find the group ID of each sequence's 1-nearest neighbor statistics based on the |
181 | 182 | sequence embeddings. Compare all groups to all other groups. |
@@ -234,7 +235,7 @@ def joint_1nn(ad, reference_group, group_col="Group", use_pca=False): |
234 | 235 | return ad |
235 | 236 |
|
236 | 237 |
|
237 | | -def within_group_knn_dist(ad, n_neighbors=10, group_col="Group", use_pca=False): |
| 238 | +def group_diversity(ad, n_neighbors=10, group_col="Group", use_pca=False): |
238 | 239 | """ |
239 | 240 | Calculates the mean distance of each sequence to its k nearest neighbors in the |
240 | 241 | same group, in the embedding space. Metric of diversity |
@@ -366,6 +367,53 @@ def dist_to_reference(ad, reference_group, group_col="Group", use_pca=False): |
366 | 367 | return ad |
367 | 368 |
|
368 | 369 |
|
def distribution_shift(ad, reference_group, group_col="Group", use_pca=False):
    """
    Compare the distribution of sequences in each group to the distribution
    of reference sequences, in the embedding space. Performs Hotelling's T2
    test to compare multivariate distributions.

    Args:
        ad (anndata.AnnData): Anndata object containing sequence embeddings
            of shape (n_seqs x n_vars)
        reference_group (str): ID of group to use as reference
        group_col (str): Name of column in .obs containing group ID
        use_pca (bool): Whether to use the PCA embedding stored in
            .obsm['X_pca'] instead of the embedding in .X

    Returns:
        ad (anndata.AnnData): Modified anndata object containing the test
            results in .uns['dist_shift_test'], a dataframe indexed by
            group ID with columns 't2_stat', 'fval', 'pval' and 'padj'.
    """
    # The embedding matrix is the same for every group, so select it once
    # instead of re-evaluating the use_pca branch inside the loop.
    embeddings = ad.obsm["X_pca"] if use_pca else ad.X
    group_ids = ad.obs[group_col]

    # Get reference sequences
    ref_X = embeddings[group_ids == reference_group, :]

    # Test every group returned by unique() against the reference. The
    # reference group itself is included, so it also gets a row and counts
    # towards the multiple-testing correction below.
    # NOTE(review): confirm whether the reference-vs-reference row is wanted.
    rows = []
    for group in group_ids.unique():
        group_X = embeddings[group_ids == group, :]

        # Keep all but the last element of the hotelling_t2 result; the kept
        # values are labelled t2_stat, fval and pval below.
        # NOTE(review): the dropped element is presumably the pooled
        # covariance matrix - confirm against the hotelling package.
        rows.append([group] + list(hotelling_t2(group_X, ref_X)[:-1]))

    # Format dataframe
    res = pd.DataFrame(rows, columns=[group_col, "t2_stat", "fval", "pval"])

    # Adjusted p-values are the second element returned by fdrcorrection.
    # NOTE(review): fdrcorrection must be imported at module level; that
    # import is not visible in this part of the file - confirm.
    res["padj"] = fdrcorrection(res["pval"])[1]

    ad.uns["dist_shift_test"] = res.set_index(group_col)
    return ad
| 415 | + |
| 416 | + |
369 | 417 | def embedding_analysis( |
370 | 418 | matrix, |
371 | 419 | seqs, |
@@ -425,14 +473,17 @@ def embedding_analysis( |
425 | 473 | ad = differential_analysis(ad, reference_group, group_col) |
426 | 474 |
|
427 | 475 | print("1-NN statistics") |
428 | | - ad = groupwise_1nn(ad, reference_group, group_col, use_pca=use_pca) |
| 476 | + ad = reference_1nn(ad, reference_group, group_col, use_pca=use_pca) |
429 | 477 |
|
430 | 478 | print("Within-group KNN diversity") |
431 | | - ad = within_group_knn_dist(ad, n_neighbors, group_col, use_pca=use_pca) |
| 479 | + ad = group_diversity(ad, n_neighbors, group_col, use_pca=use_pca) |
432 | 480 |
|
433 | 481 | print("Euclidean distance to nearest reference") |
434 | 482 | ad = dist_to_reference(ad, reference_group, group_col, use_pca=use_pca) |
435 | 483 |
|
| 484 | + print("Distribution shift") |
| 485 | + ad = distribution_shift(ad, reference_group, group_col, use_pca=use_pca) |
| 486 | + |
436 | 487 | print("Train groupwise classifiers") |
437 | 488 | ad = groupwise_svm( |
438 | 489 | ad, |
|
0 commit comments