11"""Doublet detection in single-cell RNA-seq data."""
22
33import collections
4+ import io
45import warnings
6+ from contextlib import redirect_stdout
57
68import anndata
79import numpy as np
810import phenograph
911import scipy .sparse as sp_sparse
10- import tables
1112import scanpy as sc
12- from scipy .io import mmread
1313from scipy .sparse import csr_matrix
1414from scipy .stats import hypergeom
1515from sklearn .utils import check_array
1616from sklearn .utils .sparsefuncs_fast import inplace_csr_row_normalize_l1
1717from tqdm .auto import tqdm
1818
1919
def load_10x_h5(file, genome):
    """Load a count matrix stored in 10x Genomics H5 format.

    Adapted from:
    https://support.10xgenomics.com/single-cell-gene-expression/software/
    pipelines/latest/advanced/h5_matrices

    Args:
        file (str): Path to H5 file
        genome (str): genome, top level h5 group

    Returns:
        ndarray: Raw count matrix.
        ndarray: Barcodes
        ndarray: Gene names
    """
    with tables.open_file(file, "r") as h5:
        try:
            node = h5.get_node(h5.root, genome)
        except tables.NoSuchNodeError:
            # NOTE(review): on a missing genome this returns a bare None
            # instead of a 3-tuple; callers must check before unpacking.
            print("That genome does not exist in this file.")
            return None
        gene_names = getattr(node, "gene_names").read()
        barcodes = getattr(node, "barcodes").read()
        # Reassemble the CSC matrix from its three stored component arrays.
        data, indices, indptr = (
            getattr(node, part).read() for part in ("data", "indices", "indptr")
        )
        shape = getattr(node, "shape").read()
        counts = sp_sparse.csc_matrix((data, indices, indptr), shape=shape)

    return counts, barcodes, gene_names
52-
53-
def load_mtx(file):
    """Load a count matrix stored in Matrix Market (mtx) format.

    The on-disk matrix is transposed so that rows become cells and
    columns become genes, then converted to CSC sparse format.

    Args:
        file (str): Path to mtx file

    Returns:
        ndarray: Raw count matrix.
    """
    matrix = mmread(file)
    transposed = np.transpose(matrix)
    return transposed.tocsc()
66-
67-
6820class BoostClassifier :
6921 """Classifier for doublets in single-cell RNA-seq data.
7022
@@ -162,8 +114,6 @@ def __init__(
162114 if use_phenograph is True :
163115 if "prune" not in phenograph_parameters :
164116 phenograph_parameters ["prune" ] = True
165- if ("verbosity" not in phenograph_parameters ) and (not self .verbose ):
166- phenograph_parameters ["verbosity" ] = 1
167117 self .phenograph_parameters = phenograph_parameters
168118 if (self .n_iters == 1 ) and (phenograph_parameters .get ("prune" ) is True ):
169119 warn_msg = (
@@ -238,9 +188,7 @@ def fit(self, raw_counts):
238188 self .all_log_p_values_ = np .zeros ((self .n_iters , self ._num_cells ))
239189 all_communities = np .zeros ((self .n_iters , self ._num_cells ))
240190 all_parents = []
241- all_synth_communities = np .zeros (
242- (self .n_iters , int (self .boost_rate * self ._num_cells ))
243- )
191+ all_synth_communities = np .zeros ((self .n_iters , int (self .boost_rate * self ._num_cells )))
244192
245193 for i in tqdm (range (self .n_iters )):
246194 if self .verbose :
@@ -284,9 +232,7 @@ def predict(self, p_thresh=1e-7, voter_thresh=0.9):
284232 """
285233 log_p_thresh = np .log (p_thresh )
286234 if self .n_iters > 1 :
287- with np .errstate (
288- invalid = "ignore"
289- ): # Silence numpy warning about NaN comparison
235+ with np .errstate (invalid = "ignore" ): # Silence numpy warning about NaN comparison
290236 self .voting_average_ = np .mean (
291237 np .ma .masked_invalid (self .all_log_p_values_ ) <= log_p_thresh , axis = 0
292238 )
@@ -298,20 +244,34 @@ def predict(self, p_thresh=1e-7, voter_thresh=0.9):
298244 # Find a cutoff score
299245 potential_cutoffs = np .unique (self .all_scores_ [~ np .isnan (self .all_scores_ )])
300246 if len (potential_cutoffs ) > 1 :
301- max_dropoff = (
302- np .argmax (potential_cutoffs [1 :] - potential_cutoffs [:- 1 ]) + 1
303- )
247+ max_dropoff = np .argmax (potential_cutoffs [1 :] - potential_cutoffs [:- 1 ]) + 1
304248 else : # Most likely pathological dataset, only one (or no) clusters
305249 max_dropoff = 0
306250 self .suggested_score_cutoff_ = potential_cutoffs [max_dropoff ]
307- with np .errstate (
308- invalid = "ignore"
309- ): # Silence numpy warning about NaN comparison
251+ with np .errstate (invalid = "ignore" ): # Silence numpy warning about NaN comparison
310252 self .labels_ = self .all_scores_ [0 , :] >= self .suggested_score_cutoff_
311253 self .labels_ [np .isnan (self .all_scores_ )[0 , :]] = np .nan
312254
313255 return self .labels_
314256
257+ def doublet_score (self ):
258+ """Produce doublet scores
259+
260+ The doublet score is the average negative log p-value of doublet enrichment
261+ averaged over the iterations. Higher means more likely to be doublet.
262+
263+ Returns:
264+ scores (ndarray, ndims=1): Average negative log p-value over iterations
265+ """
266+
267+ if self .n_iters > 1 :
268+ with np .errstate (invalid = "ignore" ): # Silence numpy warning about NaN comparison
269+ avg_log_p = np .mean (np .ma .masked_invalid (self .all_log_p_values_ ), axis = 0 )
270+ else :
271+ avg_log_p = self .all_log_p_values_ [0 ]
272+
273+ return - avg_log_p
274+
315275 def _one_fit (self ):
316276 if self .verbose :
317277 print ("\n Creating synthetic doublets..." )
@@ -347,19 +307,22 @@ def _one_fit(self):
347307 if self .verbose :
348308 print ("Clustering augmented data set...\n " )
349309 if self .use_phenograph :
350- fullcommunities , _ , _ = phenograph .cluster (
351- aug_counts .obsm ["X_pca" ], ** self .phenograph_parameters
352- )
310+ f = io .StringIO ()
311+ with redirect_stdout (f ):
312+ fullcommunities , _ , _ = phenograph .cluster (
313+ aug_counts .obsm ["X_pca" ], ** self .phenograph_parameters
314+ )
315+ out = f .getvalue ()
316+ if self .verbose :
317+ print (out )
353318 else :
354319 sc .pp .neighbors (
355320 aug_counts ,
356321 random_state = self .random_state ,
357322 method = "umap" ,
358323 n_neighbors = 10 ,
359324 )
360- sc .tl .louvain (
361- aug_counts , random_state = self .random_state , resolution = 4 , directed = False
362- )
325+ sc .tl .louvain (aug_counts , random_state = self .random_state , resolution = 4 , directed = False )
363326 fullcommunities = np .array (aug_counts .obs ["louvain" ], dtype = int )
364327 min_ID = min (fullcommunities )
365328 self .communities_ = fullcommunities [: self ._num_cells ]
@@ -380,8 +343,7 @@ def _one_fit(self):
380343 orig_cells_per_comm = collections .Counter (self .communities_ )
381344 community_IDs = orig_cells_per_comm .keys ()
382345 community_scores = {
383- i : float (synth_cells_per_comm [i ])
384- / (synth_cells_per_comm [i ] + orig_cells_per_comm [i ])
346+ i : float (synth_cells_per_comm [i ]) / (synth_cells_per_comm [i ] + orig_cells_per_comm [i ])
385347 for i in community_IDs
386348 }
387349 scores = np .array ([community_scores [i ] for i in self .communities_ ])
@@ -412,9 +374,7 @@ def _createDoublets(self):
412374 num_synths = int (self .boost_rate * self ._num_cells )
413375
414376 # Parent indices
415- choices = np .random .choice (
416- self ._num_cells , size = (num_synths , 2 ), replace = self .replace
417- )
377+ choices = np .random .choice (self ._num_cells , size = (num_synths , 2 ), replace = self .replace )
418378 parents = [list (p ) for p in choices ]
419379
420380 parent0 = self ._raw_counts [choices [:, 0 ], :]