11"""Doublet detection in single-cell RNA-seq data."""
22
33import collections
4+ from collections .abc import Callable
45import io
56import warnings
67from contextlib import redirect_stdout
78
89import anndata
910import numpy as np
11+ from numpy .typing import NDArray
1012import phenograph
1113import scanpy as sc
1214import scipy .sparse as sp_sparse
@@ -83,20 +85,20 @@ class BoostClassifier:
8385
8486 def __init__ (
8587 self ,
86- boost_rate = 0.25 ,
87- n_components = 30 ,
88- n_top_var_genes = 10000 ,
89- replace = False ,
90- clustering_algorithm = "phenograph" ,
91- clustering_kwargs = None ,
92- n_iters = 10 ,
93- normalizer = None ,
94- pseudocount = 0.1 ,
95- random_state = 0 ,
96- verbose = False ,
97- standard_scaling = False ,
98- n_jobs = 1 ,
99- ):
88+ boost_rate : float = 0.25 ,
89+ n_components : int = 30 ,
90+ n_top_var_genes : int = 10000 ,
91+ replace : bool = False ,
92+ clustering_algorithm : str = "phenograph" ,
93+ clustering_kwargs : dict | None = None ,
94+ n_iters : int = 10 ,
95+ normalizer : Callable | None = None ,
96+ pseudocount : float = 0.1 ,
97+ random_state : int = 0 ,
98+ verbose : bool = False ,
99+ standard_scaling : bool = False ,
100+ n_jobs : int = 1 ,
101+ ) -> None :
100102 self .boost_rate = boost_rate
101103 self .replace = replace
102104 self .clustering_algorithm = clustering_algorithm
@@ -145,7 +147,7 @@ def __init__(
145147 n_components , n_top_var_genes
146148 )
147149
148- def fit (self , raw_counts ) :
150+ def fit (self , raw_counts : NDArray | sp_sparse . csr_matrix ) -> "BoostClassifier" :
149151 """Fits the classifier on raw_counts.
150152
151153 Args:
@@ -226,7 +228,7 @@ def fit(self, raw_counts):
226228
227229 return self
228230
229- def predict (self , p_thresh = 1e-7 , voter_thresh = 0.9 ):
231+ def predict (self , p_thresh : float = 1e-7 , voter_thresh : float = 0.9 ) -> NDArray :
230232 """Produce doublet calls from fitted classifier
231233
232234 Args:
@@ -266,7 +268,7 @@ def predict(self, p_thresh=1e-7, voter_thresh=0.9):
266268
267269 return self .labels_
268270
269- def doublet_score (self ):
271+ def doublet_score (self ) -> NDArray :
270272 """Produce doublet scores
271273
272274 The doublet score is the average negative log p-value of doublet enrichment
@@ -284,7 +286,7 @@ def doublet_score(self):
284286
285287 return - avg_log_p
286288
287- def _one_fit (self ):
289+ def _one_fit (self ) -> tuple [ NDArray , NDArray ] :
288290 if self .verbose :
289291 print ("\n Creating synthetic doublets..." )
290292 self ._createDoublets ()
@@ -395,7 +397,7 @@ def _one_fit(self):
395397
396398 return scores , log_p_values
397399
398- def _createDoublets (self ):
400+ def _createDoublets (self ) -> None :
399401 """Create synthetic doublets.
400402
401403 Sets .parents_
@@ -414,7 +416,7 @@ def _createDoublets(self):
414416 self ._raw_synthetics = synthetic
415417 self .parents_ = parents
416418
417- def _set_clustering_kwargs (self ):
419+ def _set_clustering_kwargs (self ) -> None :
418420 """Sets .clustering_kwargs"""
419421 if self .clustering_algorithm == "phenograph" :
420422 if "prune" not in self .clustering_kwargs :
0 commit comments