Skip to content

Commit 5164a08

Browse files
author
lala8
committed
added classifier code
1 parent 00dcc89 commit 5164a08

File tree

1 file changed

+84
-0
lines changed

1 file changed

+84
-0
lines changed

src/polygraph/classifier.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
from collections import defaultdict
2+
3+
import pandas as pd
4+
from sklearn import metrics
5+
from sklearn.calibration import CalibratedClassifierCV
6+
from sklearn.svm import LinearSVC
7+
8+
9+
def groupwise_svm(
    ad,
    reference_group,
    group_col="Group",
    cv=5,
    is_kernel=True,
    max_iter=1000,
    use_pca=False,
):
    """
    Train an SVM to distinguish between each non-reference group and the reference group

    For every non-reference group, a linear SVM wrapped in a cross-validated
    probability calibrator is fitted on the sequences that belong either to
    that group or to the reference group. Predictions and metrics are computed
    on the same sequences the classifier was fitted on, so the reported numbers
    are training-set metrics, not held-out estimates.

    Args:
        ad (anndata.AnnData): Anndata object containing sequence embeddings
            of shape (n_seqs x n_vars)
        reference_group (str): ID of group to use as reference
        group_col (str): Name of column in .obs containing group ID
        cv (int): Number of cross-validation folds
        is_kernel (bool): Whether ad.X is a symmetric kernel matrix. Only
            consulted when use_pca is False.
        max_iter (int): Maximum number of iterations for SVM
        use_pca (bool): Whether to use PCA distances (read from
            ad.obsm["X_pca"]) instead of ad.X

    Returns:
        ad (anndata.AnnData): Modified anndata object containing each
            sequence's predicted label in .obs (column
            "<group>_SVM_predicted_reference"; 1 = predicted reference,
            0 = predicted <group>, missing for sequences not in the
            comparison), as well as SVM performance metrics in
            ad.uns["svm_performance"]

    Raises:
        ValueError: If no sequence in ad.obs[group_col] belongs to
            reference_group.
    """
    group_labels = ad.obs[group_col]

    # Get indices of reference sequences and fail early with a clear message
    # instead of letting the classifier choke on a single-class problem.
    is_ref = group_labels == reference_group
    if not is_ref.any():
        raise ValueError(
            f"Reference group {reference_group!r} not found in "
            f"ad.obs[{group_col!r}]"
        )

    # List nonreference groups
    groups = group_labels.unique()
    nonreference_groups = groups[groups != reference_group]

    # One record of performance metrics per non-reference group
    records = []

    # Train SVM per group
    for group in nonreference_groups:
        # Select sequences for comparison
        is_group = group_labels == group
        sel = (is_ref | is_group).values
        subset = ad[sel, :]

        # Get the training matrix
        if use_pca:
            Xtrain = subset.obsm["X_pca"]
        else:
            Xtrain = subset.X
            if is_kernel:
                # A kernel matrix is (n_seqs x n_seqs): keep only the columns
                # that correspond to the selected sequences as well.
                Xtrain = Xtrain[:, sel]

        # Get group labels: code 0 = current group, code 1 = reference
        Ytrain = pd.Categorical(
            subset.obs[group_col], categories=[group, reference_group]
        ).codes

        # Train SVM
        svm = LinearSVC(C=2, max_iter=max_iter)
        clf = CalibratedClassifierCV(svm, cv=cv).fit(Xtrain, Ytrain)

        # Get predictions
        preds = clf.predict(Xtrain)
        ad.obs.loc[sel, f"{group}_SVM_predicted_reference"] = preds

        # Get metrics. AUROC needs a continuous score, so use the calibrated
        # probability of the positive class (code 1, the reference group)
        # rather than the thresholded 0/1 predictions.
        ref_prob = clf.predict_proba(Xtrain)[:, 1]
        records.append(
            {
                group_col: group,
                "Accuracy": metrics.accuracy_score(Ytrain, preds),
                "AUROC": metrics.roc_auc_score(Ytrain, ref_prob),
            }
        )

    # Explicit columns keep set_index working even when there were no
    # non-reference groups to compare against.
    ad.uns["svm_performance"] = pd.DataFrame(
        records, columns=[group_col, "Accuracy", "AUROC"]
    ).set_index(group_col)
    return ad

0 commit comments

Comments
 (0)