Skip to content

Commit b0a0f48

Browse files
avantikalallala8
andauthored
Geometric sketching (#18)
* added robustness analysis * added geometric sketching * fix --------- Co-authored-by: lala8 <[email protected]>
1 parent 9c4a37c commit b0a0f48

File tree

2 files changed

+39
-0
lines changed

2 files changed

+39
-0
lines changed

setup.cfg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ install_requires =
6464
hotelling
6565
# fastsk
6666
upsetplot
67+
geosketch
6768

6869

6970
[options.packages.find]

src/polygraph/embedding.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,44 @@ def distribution_shift(ad, reference_group, group_col="Group", use_pca=False):
414414
return ad
415415

416416

417+
def geometric_sketch(ad, N, groups=None, group_col="Group", use_pca=True):
418+
"""
419+
Applies geometric sketching (Hie, Brian et al. Cell Systems, Volume 8,
420+
Issue 6, 483 - 493.e7) to sample a subset of sequences that represent
421+
the diversity in the specified groups.
422+
423+
Args:
424+
ad (anndata.AnnData): Anndata object containing sequence embeddings
425+
of shape (n_seqs x n_vars)
426+
N (int): Number of sequences to sample from each group
427+
groups (list): Names of groups from which to sample sequences. If None,
428+
all groups are used.
429+
group_col (str): Name of column in .obs containing group ID
430+
use_pca (bool): Whether to use PCA distances
431+
432+
Returns:
433+
ad (anndata.AnnData): Modified anndata object containing selections in
434+
ad.obs['selected'].
435+
"""
436+
from geosketch import gs
437+
438+
ad.obs["selected"] = False
439+
groups = groups or ad.obs[group_col].unique()
440+
441+
for group in groups:
442+
in_group = ad.obs[group_col] == group
443+
group_idx = ad.obs_names[in_group].tolist()
444+
if use_pca:
445+
group_X = ad.obsm["X_pca"][in_group, :]
446+
else:
447+
group_X = ad.X[in_group, :]
448+
449+
sketch_index = gs(group_X, N=N, replace=False)
450+
ad.obs.loc[[group_idx[x] for x in sketch_index], "selected"] = True
451+
452+
return ad
453+
454+
417455
def embedding_analysis(
418456
matrix,
419457
seqs,

0 commit comments

Comments
 (0)