Skip to content

Commit 4c8d2ff

Browse files
avantikalallala8
andauthored
Ttest (#16)
* added hotellings t2 test * formatting --------- Co-authored-by: lala8 <[email protected]>
1 parent 3cf32ef commit 4c8d2ff

File tree

3 files changed

+5549
-1136
lines changed

3 files changed

+5549
-1136
lines changed

setup.cfg

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,9 @@ install_requires =
6161
scikit_posthocs
6262
scipy
6363
plotnine
64+
hotelling
6465
# fastsk
66+
upsetplot
6567

6668

6769
[options.packages.find]

src/polygraph/embedding.py

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
import pandas as pd
33
import scanpy as sc
4+
from hotelling.stats import hotelling_t2
45
from scipy.stats import fisher_exact
56
from sklearn.metrics import pairwise_distances
67
from sklearn.neighbors import NearestNeighbors
@@ -89,7 +90,7 @@ def differential_analysis(ad, reference_group, group_col="Group"):
8990
return ad
9091

9192

92-
def groupwise_1nn(ad, reference_group, group_col="Group", use_pca=False):
93+
def reference_1nn(ad, reference_group, group_col="Group", use_pca=False):
9394
"""
9495
For each sequence, find its nearest neighbor among its own group or
9596
the reference group based on the sequence embeddings.
@@ -175,7 +176,7 @@ def groupwise_1nn(ad, reference_group, group_col="Group", use_pca=False):
175176
return ad
176177

177178

178-
def joint_1nn(ad, reference_group, group_col="Group", use_pca=False):
179+
def all_1nn(ad, reference_group, group_col="Group", use_pca=False):
179180
"""
180181
Find the group ID of each sequence's 1-nearest neighbor statistics based on the
181182
sequence embeddings. Compare all groups to all other groups.
@@ -234,7 +235,7 @@ def joint_1nn(ad, reference_group, group_col="Group", use_pca=False):
234235
return ad
235236

236237

237-
def within_group_knn_dist(ad, n_neighbors=10, group_col="Group", use_pca=False):
238+
def group_diversity(ad, n_neighbors=10, group_col="Group", use_pca=False):
238239
"""
239240
Calculates the mean distance of each sequence to its k nearest neighbors in the
240241
same group, in the embedding space. Metric of diversity
@@ -366,6 +367,53 @@ def dist_to_reference(ad, reference_group, group_col="Group", use_pca=False):
366367
return ad
367368

368369

370+
def distribution_shift(ad, reference_group, group_col="Group", use_pca=False):
371+
"""
372+
Compare the distribution of sequences in each group to the distribution
373+
of reference sequences, in the embedding space. Performs Hotelling's T2
374+
test to compare multivariate distributions.
375+
376+
Args:
377+
ad (anndata.AnnData): Anndata object containing sequence embeddings
378+
of shape (n_seqs x n_vars)
379+
reference_group (str): ID of group to use as reference
380+
group_col (str): Name of column in .obs containing group ID
381+
use_pca (bool): Whether to use PCA distances
382+
383+
Returns:
384+
ad (anndata.AnnData): Modified anndata object containing distance to
385+
reference in .uns['distribution_shift'].
386+
"""
387+
rows = []
388+
389+
# Get reference sequences
390+
in_ref = ad.obs[group_col] == reference_group
391+
if use_pca:
392+
ref_X = ad.obsm["X_pca"][in_ref, :]
393+
else:
394+
ref_X = ad.X[in_ref, :]
395+
396+
# List groups
397+
groups = ad.obs[group_col].unique()
398+
399+
for group in groups:
400+
# Get group sequences
401+
in_group = ad.obs[group_col] == group
402+
if use_pca:
403+
group_X = ad.obsm["X_pca"][in_group, :]
404+
else:
405+
group_X = ad.X[in_group, :]
406+
407+
# Perform Hotelling's T2 test to compare to the reference
408+
rows.append([group] + list(hotelling_t2(group_X, ref_X)[:-1]))
409+
410+
# Format dataframe
411+
res = pd.DataFrame(rows, columns=[group_col, "t2_stat", "fval", "pval"])
412+
res["padj"] = fdrcorrection(res.pval)[1]
413+
ad.uns["dist_shift_test"] = res.set_index(group_col)
414+
return ad
415+
416+
369417
def embedding_analysis(
370418
matrix,
371419
seqs,
@@ -425,14 +473,17 @@ def embedding_analysis(
425473
ad = differential_analysis(ad, reference_group, group_col)
426474

427475
print("1-NN statistics")
428-
ad = groupwise_1nn(ad, reference_group, group_col, use_pca=use_pca)
476+
ad = reference_1nn(ad, reference_group, group_col, use_pca=use_pca)
429477

430478
print("Within-group KNN diversity")
431-
ad = within_group_knn_dist(ad, n_neighbors, group_col, use_pca=use_pca)
479+
ad = group_diversity(ad, n_neighbors, group_col, use_pca=use_pca)
432480

433481
print("Euclidean distance to nearest reference")
434482
ad = dist_to_reference(ad, reference_group, group_col, use_pca=use_pca)
435483

484+
print("Distribution shift")
485+
ad = distribution_shift(ad, reference_group, group_col, use_pca=use_pca)
486+
436487
print("Train groupwise classifiers")
437488
ad = groupwise_svm(
438489
ad,

tutorials/1_yeast_tutorial.ipynb

Lines changed: 5491 additions & 1131 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)