Skip to content

Commit ce4f340

Browse files
avantikalallala8
andauthored
Avantika dev (#14)
* added evolution, reran tutorials * added evolve script * linting * fixed tests * linting --------- Co-authored-by: lala8 <[email protected]>
1 parent 79a3011 commit ce4f340

File tree

16 files changed

+4974
-3553
lines changed

16 files changed

+4974
-3553
lines changed

src/polygraph/embedding.py

Lines changed: 131 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
import numpy as np
22
import pandas as pd
33
import scanpy as sc
4+
from scipy.stats import fisher_exact
45
from sklearn.metrics import pairwise_distances
56
from sklearn.neighbors import NearestNeighbors
7+
from statsmodels.stats.multitest import fdrcorrection
68

9+
from polygraph.classifier import groupwise_svm
710
from polygraph.stats import groupwise_fishers, groupwise_mann_whitney, kruskal_dunn
811

912

1013
def embedding_pca(ad, **kwargs):
1114
"""
12-
Perform PCA on sequence embeddings
15+
Perform PCA on sequence embeddings.
1316
1417
Args:
1518
ad (anndata.AnnData): Anndata object containing sequence embeddings
@@ -21,10 +24,7 @@ def embedding_pca(ad, **kwargs):
2124
"""
2225
sc.pp.pca(ad, **kwargs)
2326
print(
24-
"Fraction of variance explained: ", np.round(ad.uns["pca"]["variance_ratio"], 2)
25-
)
26-
print(
27-
"Fraction of total variance explained: ",
27+
"Fraction of total variance explained by PCA components: ",
2828
np.sum(ad.uns["pca"]["variance_ratio"]),
2929
)
3030
return ad
@@ -89,12 +89,10 @@ def differential_analysis(ad, reference_group, group_col="Group"):
8989
return ad
9090

9191

92-
def one_nn_stats(ad, reference_group, group_col="Group", use_pca=False):
92+
def groupwise_1nn(ad, reference_group, group_col="Group", use_pca=False):
9393
"""
94-
Calculate the following 1-nearest neighbor statistics based on the
95-
sequence embeddings:
96-
1. Group ID of nearest neighbor
97-
2. Distance to nearest neighbor
94+
For each sequence, find its nearest neighbor among its own group or
95+
the reference group based on the sequence embeddings.
9896
9997
Args:
10098
ad (anndata.AnnData): Anndata object containing sequence embeddings of
@@ -109,7 +107,92 @@ def one_nn_stats(ad, reference_group, group_col="Group", use_pca=False):
109107
probability table of nearest neighbor groups for each group
110108
in .uns
111109
"""
110+
res = pd.DataFrame()
111+
112+
# Get reference embedding
113+
in_ref = ad.obs[group_col] == reference_group
114+
115+
# List groups
116+
groups = ad.obs[group_col].unique()
117+
118+
# List nonreference groups
119+
nonreference_groups = groups[groups != reference_group]
120+
121+
for group in nonreference_groups:
122+
# Get group embeddings
123+
in_group = ad.obs[group_col] == group
124+
in_group_or_ref = (in_ref | in_group).tolist()
125+
in_group_or_ref_indices = ad.obs.index[in_group_or_ref].tolist()
126+
if use_pca:
127+
X = ad.obsm["X_pca"][in_group_or_ref, :]
128+
else:
129+
X = ad.X[in_group_or_ref, :]
130+
131+
# Calculate nearest neighbors
132+
nbrs = NearestNeighbors(n_neighbors=2, algorithm="ball_tree").fit(X)
133+
distances, indices = nbrs.kneighbors(X)
134+
indices = np.array(in_group_or_ref_indices)[indices[:, 1]]
135+
136+
# Record whether nearest neighbor is a member of the reference group
137+
ad.obs.loc[in_group_or_ref, f"{group}_one_nn_idx"] = indices
138+
ad.obs.loc[in_group_or_ref, f"{group}_one_nn_dist"] = distances[:, 1]
139+
ad.obs.loc[in_group_or_ref, f"{group}_one_nn_group"] = ad.obs.loc[
140+
indices, group_col
141+
].tolist()
142+
143+
# Make contingency table
144+
cont = ad.obs[[group_col, f"{group}_one_nn_group"]].value_counts()
145+
cont = (
146+
cont.unstack()
147+
.fillna(0)
148+
.loc[[group, reference_group], [group, reference_group]]
149+
.values
150+
)
151+
152+
# Perform tests
153+
group_prop = cont[0, 1] / cont[0, :].sum()
154+
ref_prop = cont[1, 1] / cont[1, :].sum()
155+
156+
res = pd.concat(
157+
[
158+
res,
159+
pd.DataFrame(
160+
{
161+
group_col: [group],
162+
"group_prop": [group_prop],
163+
"ref_prop": [ref_prop],
164+
"pval": [fisher_exact(cont, alternative="two-sided").pvalue],
165+
}
166+
),
167+
]
168+
)
169+
170+
# Save results
171+
res = res.set_index(group_col)
172+
res["padj"] = fdrcorrection(res.pval)[1]
173+
ad.uns["1NN_group_probs"] = res.iloc[:, :2].copy()
174+
ad.uns["1NN_ref_prop_test"] = res
175+
return ad
176+
177+
178+
def joint_1nn(ad, reference_group, group_col="Group", use_pca=False):
179+
"""
180+
Find the group ID of each sequence's 1-nearest neighbor statistics based on the
181+
sequence embeddings. Compare all groups to all other groups.
182+
183+
Args:
184+
ad (anndata.AnnData): Anndata object containing sequence embeddings of
185+
shape (n_seqs x n_vars)
186+
reference_group (str): ID of group to use as reference
187+
group_col (str): Name of column in .obs containing group ID
188+
use_pca (bool): Whether to use PCA distances
112189
190+
Returns:
191+
ad (anndata.AnnData): Modified anndata object containing index, distance and
192+
group ID of each sequence's nearest neighbor in .obs, as well as a
193+
probability table of nearest neighbor groups for each group
194+
in .uns
195+
"""
113196
# Get nearest neighbor for each sequence
114197
if use_pca:
115198
nbrs = NearestNeighbors(n_neighbors=2, algorithm="ball_tree").fit(
@@ -124,7 +207,7 @@ def one_nn_stats(ad, reference_group, group_col="Group", use_pca=False):
124207
ad.obs.loc[:, "one_nn_idx"] = indices[:, 1]
125208
ad.obs.loc[:, "one_nn_dist"] = distances[:, 1]
126209
ad.obs.loc[:, "one_nn_group"] = (
127-
ad.obs[group_col].iloc[ad.obs["one_nn_idx"].tolist()].tolist()
210+
ad.obs[group_col].iloc[indices[:, 1].tolist()].tolist()
128211
)
129212

130213
# Normalized count table of nearest neighbor groups
@@ -245,19 +328,27 @@ def dist_to_reference(ad, reference_group, group_col="Group", use_pca=False):
245328

246329
if group != reference_group:
247330
# Add to .obs
331+
ad.obs.loc[in_group, "Closest reference"] = (
332+
ad[in_ref, :].obs_names[distances.argmin(1)].tolist()
333+
)
248334
ad.obs.loc[in_group, "Distance to closest reference"] = distances.min(1)
249335
else:
250336
# If the group is the reference group, the nearest neighbor will be
251337
# the sequence itself. So we find the next nearest neighbor.
252-
dlist = []
338+
indices = []
339+
dists = []
253340
for i, row in enumerate(distances):
254341
# Drop the nearest neighbor
255-
row = np.delete(row, i)
342+
row[i] = np.Inf
256343
# Take the new minimum
257-
dlist.append(row.min())
344+
dists.append(row.min())
345+
indices.append(row.argmin())
258346

259347
# Add to .obs
260-
ad.obs.loc[in_group, "Distance to closest reference"] = dlist
348+
ad.obs.loc[in_group, "Closest reference"] = (
349+
ad[in_ref, :].obs_names[indices].tolist()
350+
)
351+
ad.obs.loc[in_group, "Distance to closest reference"] = dists
261352

262353
# Mann-whitney or Kruskal-wallis test
263354
if len(groups) == 2:
@@ -276,10 +367,16 @@ def dist_to_reference(ad, reference_group, group_col="Group", use_pca=False):
276367

277368

278369
def embedding_analysis(
279-
matrix, seqs, reference_group, group_col="Group", n_neighbors=15, use_pca=False
370+
matrix,
371+
seqs,
372+
reference_group,
373+
group_col="Group",
374+
n_neighbors=15,
375+
use_pca=False,
376+
max_iter=1000,
280377
):
281378
"""
282-
A single function to calculate all embedding distance-based metrics.
379+
A single function to calculate several embedding distance-based metrics.
283380
284381
Args:
285382
matrix (np.array, pd.DataFrame): A probability table or embedding matrix
@@ -307,12 +404,15 @@ def embedding_analysis(
307404
(.obs['Distance to closest reference'])
308405
Test for between group difference in Distance to the closest reference
309406
sequence (.uns["ref_dist_test"])
407+
Performance metrics for SVMs trained to classify each group from the
408+
reference (.uns["svm_performance"])
409+
Predictions by SVMs trained to classify each group from the reference
410+
(.obs["{group}_SVM_predicted_reference"])
310411
"""
311412
from anndata import AnnData
312413

313414
print("Creating AnnData object")
314-
ad = AnnData(matrix)
315-
ad.obs = seqs
415+
ad = AnnData(matrix, obs=seqs)
316416
ad = ad[:, ad.X.sum(0) > 0]
317417

318418
print("PCA")
@@ -325,11 +425,22 @@ def embedding_analysis(
325425
ad = differential_analysis(ad, reference_group, group_col)
326426

327427
print("1-NN statistics")
328-
ad = one_nn_stats(ad, reference_group, group_col, use_pca=use_pca)
428+
ad = groupwise_1nn(ad, reference_group, group_col, use_pca=use_pca)
329429

330430
print("Within-group KNN diversity")
331431
ad = within_group_knn_dist(ad, n_neighbors, group_col, use_pca=use_pca)
332432

333433
print("Euclidean distance to nearest reference")
334434
ad = dist_to_reference(ad, reference_group, group_col, use_pca=use_pca)
435+
436+
print("Train groupwise classifiers")
437+
ad = groupwise_svm(
438+
ad,
439+
reference_group,
440+
group_col=group_col,
441+
cv=5,
442+
is_kernel=False,
443+
max_iter=max_iter,
444+
)
445+
335446
return ad

src/polygraph/evolve.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
import anndata
2+
import pandas as pd
3+
import scanpy as sc
4+
from sklearn.metrics import pairwise_distances
5+
6+
from polygraph.models import get_embeddings, predict
7+
from polygraph.sequence import ISM, kmer_frequencies
8+
9+
10+
def evolve(
11+
start_seq,
12+
reference_seqs,
13+
iter,
14+
model,
15+
k=None,
16+
drop_last_layers=None,
17+
batch_size=512,
18+
device="cpu",
19+
task=None,
20+
alpha=3,
21+
):
22+
"""
23+
Directed evolution with an additional goal to increase similarity to
24+
reference sequences.
25+
26+
Args:
27+
start_seq (str): Start sequence
28+
reference_seqs (list): Reference sequences
29+
iter (int): Number of iterations
30+
model (nn.Sequential): Torch sequential model
31+
k (int): k-mer length for k-mer embedding.
32+
drop_last_layers (int): Number of terminal layers to drop from the
33+
model for model embedding.
34+
batch_size (int): Batch size for inference
35+
device (int, str): Index of device to use for inference
36+
task (int): Model output head. If None, average all heads.
37+
alpha (int): Relative weight for similarity
38+
39+
Returns:
40+
best_seq (str): Optimized sequence
41+
"""
42+
# Embed the reference sequences
43+
if k is not None:
44+
reference_embeddings = kmer_frequencies(reference_seqs, k=k, normalize=True)
45+
elif drop_last_layers is not None:
46+
reference_embeddings = get_embeddings(
47+
reference_seqs,
48+
model,
49+
batch_size=batch_size,
50+
drop_last_layers=drop_last_layers,
51+
device=device,
52+
)
53+
else:
54+
raise ValueError("One of k or drop_last_layers should be provided.")
55+
reference_ad = anndata.AnnData(reference_embeddings)
56+
57+
for i in range(1, iter + 1):
58+
print(f"Iter: {i}")
59+
if i == 1:
60+
curr_seqs = [start_seq]
61+
62+
# Embed the evolved sequences
63+
if k is not None:
64+
curr_embeddings = kmer_frequencies(curr_seqs, k=k, normalize=True)
65+
else:
66+
curr_embeddings = get_embeddings(
67+
curr_seqs,
68+
model,
69+
batch_size=batch_size,
70+
drop_last_layers=drop_last_layers,
71+
device=device,
72+
)
73+
curr_ad = anndata.AnnData(
74+
curr_embeddings, obs=pd.DataFrame({"Sequence": curr_seqs})
75+
)
76+
77+
# Predict on evolved sequences
78+
curr_ad.obs["pred"] = predict(
79+
curr_seqs, model, batch_size=batch_size, device=device
80+
)
81+
82+
# Combine
83+
ad = anndata.concat(
84+
[reference_ad, curr_ad], index_unique="_", keys=["ref", "curr"]
85+
)
86+
87+
# PCA
88+
sc.pp.pca(ad, n_comps=50)
89+
90+
# Get PCA embeddings for evolved and reference sequences
91+
reference_X = ad.obsm["X_pca"][: len(reference_ad), :]
92+
curr_X = ad.obsm["X_pca"][len(reference_ad) :, :]
93+
94+
# Get euclidean distance of each evolved sequence to its closest
95+
# reference sequence
96+
curr_ad.obs["distance"] = pairwise_distances(
97+
curr_X, reference_X, metric="euclidean"
98+
).min(1)
99+
100+
# Assign each sequence a total score
101+
curr_ad.obs["score"] = curr_ad.obs["pred"] - (alpha * curr_ad.obs["distance"])
102+
103+
# Select best sequence from current iteration
104+
best = curr_ad.obs.sort_values("score").tail(1)
105+
best_seq = best.Sequence.values[0]
106+
107+
# Get sequences for next round
108+
if i < iter:
109+
curr_seqs = ISM(best_seq)
110+
else:
111+
return best_seq
112+
113+
print(
114+
(
115+
f"Prediction: {best.pred.values[0]} | "
116+
+ f"Distance: {best.distance.values[0]} | "
117+
+ f"Score: {best.score.values[0]}"
118+
)
119+
)

0 commit comments

Comments
 (0)