@@ -414,6 +414,44 @@ def distribution_shift(ad, reference_group, group_col="Group", use_pca=False):
414414 return ad
415415
416416
def geometric_sketch(ad, N, groups=None, group_col="Group", use_pca=True):
    """
    Applies geometric sketching (Hie, Brian et al. Cell Systems, Volume 8,
    Issue 6, 483 - 493.e7) to sample a subset of sequences that represent
    the diversity in the specified groups.

    Args:
        ad (anndata.AnnData): Anndata object containing sequence embeddings
            of shape (n_seqs x n_vars)
        N (int): Number of sequences to sample from each group. A group with
            fewer than N sequences is selected in full.
        groups (list): Names of groups from which to sample sequences. If None
            (or empty), all groups found in ad.obs[group_col] are used.
        group_col (str): Name of column in .obs containing group ID
        use_pca (bool): If True, sketch on ad.obsm["X_pca"]; otherwise on ad.X

    Returns:
        ad (anndata.AnnData): Modified anndata object containing selections in
            ad.obs['selected'].

    Raises:
        ValueError: If a requested group has no sequences in
            ad.obs[group_col].
    """
    import numpy as np
    from geosketch import gs

    group_labels = ad.obs[group_col]

    # Explicit None/empty check: `groups or ...` raises for array-like inputs
    # (numpy arrays, pandas Index/Series) whose truth value is ambiguous.
    if groups is None or len(groups) == 0:
        groups = group_labels.unique()

    X = ad.obsm["X_pca"] if use_pca else ad.X

    # Collect selections positionally so that non-unique obs names cannot
    # cause extra rows to be marked, and so that ad.obs is only modified once
    # every group has been processed successfully.
    selected = np.zeros(len(group_labels), dtype=bool)

    for group in groups:
        rows = np.flatnonzero((group_labels == group).to_numpy())
        if rows.size == 0:
            raise ValueError(
                f"No sequences found for group {group!r} in column {group_col!r}"
            )

        # Sampling without replacement cannot return more rows than the group
        # contains, so cap the request at the group size.
        n_sample = min(N, rows.size)
        sketch_index = gs(X[rows, :], N=n_sample, replace=False)

        # sketch_index holds positions within the group; map back to ad rows.
        selected[rows[np.asarray(sketch_index, dtype=int)]] = True

    ad.obs["selected"] = selected
    return ad
453+
454+
417455def embedding_analysis (
418456 matrix ,
419457 seqs ,
0 commit comments