Skip to content

Commit c4a4f42

Browse files
authored
Merge pull request #162 from JonathanShor/typing
Add python typing
2 parents a4156e5 + 4a10465 commit c4a4f42

File tree

2 files changed

+42
-32
lines changed

2 files changed

+42
-32
lines changed

doubletdetection/doubletdetection.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
"""Doublet detection in single-cell RNA-seq data."""
22

33
import collections
4+
from collections.abc import Callable
45
import io
56
import warnings
67
from contextlib import redirect_stdout
78

89
import anndata
910
import numpy as np
11+
from numpy.typing import NDArray
1012
import phenograph
1113
import scanpy as sc
1214
import scipy.sparse as sp_sparse
@@ -83,20 +85,20 @@ class BoostClassifier:
8385

8486
def __init__(
8587
self,
86-
boost_rate=0.25,
87-
n_components=30,
88-
n_top_var_genes=10000,
89-
replace=False,
90-
clustering_algorithm="phenograph",
91-
clustering_kwargs=None,
92-
n_iters=10,
93-
normalizer=None,
94-
pseudocount=0.1,
95-
random_state=0,
96-
verbose=False,
97-
standard_scaling=False,
98-
n_jobs=1,
99-
):
88+
boost_rate: float = 0.25,
89+
n_components: int = 30,
90+
n_top_var_genes: int = 10000,
91+
replace: bool = False,
92+
clustering_algorithm: str = "phenograph",
93+
clustering_kwargs: dict | None = None,
94+
n_iters: int = 10,
95+
normalizer: Callable | None = None,
96+
pseudocount: float = 0.1,
97+
random_state: int = 0,
98+
verbose: bool = False,
99+
standard_scaling: bool = False,
100+
n_jobs: int = 1,
101+
) -> None:
100102
self.boost_rate = boost_rate
101103
self.replace = replace
102104
self.clustering_algorithm = clustering_algorithm
@@ -145,7 +147,7 @@ def __init__(
145147
n_components, n_top_var_genes
146148
)
147149

148-
def fit(self, raw_counts):
150+
def fit(self, raw_counts: NDArray | sp_sparse.csr_matrix) -> "BoostClassifier":
149151
"""Fits the classifier on raw_counts.
150152
151153
Args:
@@ -226,7 +228,7 @@ def fit(self, raw_counts):
226228

227229
return self
228230

229-
def predict(self, p_thresh=1e-7, voter_thresh=0.9):
231+
def predict(self, p_thresh: float = 1e-7, voter_thresh: float = 0.9) -> NDArray:
230232
"""Produce doublet calls from fitted classifier
231233
232234
Args:
@@ -266,7 +268,7 @@ def predict(self, p_thresh=1e-7, voter_thresh=0.9):
266268

267269
return self.labels_
268270

269-
def doublet_score(self):
271+
def doublet_score(self) -> NDArray:
270272
"""Produce doublet scores
271273
272274
The doublet score is the average negative log p-value of doublet enrichment
@@ -284,7 +286,7 @@ def doublet_score(self):
284286

285287
return -avg_log_p
286288

287-
def _one_fit(self):
289+
def _one_fit(self) -> tuple[NDArray, NDArray]:
288290
if self.verbose:
289291
print("\nCreating synthetic doublets...")
290292
self._createDoublets()
@@ -395,7 +397,7 @@ def _one_fit(self):
395397

396398
return scores, log_p_values
397399

398-
def _createDoublets(self):
400+
def _createDoublets(self) -> None:
399401
"""Create synthetic doublets.
400402
401403
Sets .parents_
@@ -414,7 +416,7 @@ def _createDoublets(self):
414416
self._raw_synthetics = synthetic
415417
self.parents_ = parents
416418

417-
def _set_clustering_kwargs(self):
419+
def _set_clustering_kwargs(self) -> None:
418420
"""Sets .clustering_kwargs"""
419421
if self.clustering_algorithm == "phenograph":
420422
if "prune" not in self.clustering_kwargs:

doubletdetection/plot.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
import os
22
import warnings
3+
from typing import Any
34

45
import matplotlib
56
import numpy as np
7+
from numpy.typing import NDArray
8+
from matplotlib.figure import Figure
69

710
try:
811
os.environ["DISPLAY"]
@@ -11,7 +14,7 @@
1114
import matplotlib.pyplot as plt
1215

1316

14-
def normalize_counts(raw_counts, pseudocount=0.1):
17+
def normalize_counts(raw_counts: NDArray, pseudocount: float = 0.1) -> NDArray:
1518
"""Normalize count array. Default normalizer used by BoostClassifier.
1619
1720
Args:
@@ -22,7 +25,6 @@ def normalize_counts(raw_counts, pseudocount=0.1):
2225
ndarray: Normalized data.
2326
"""
2427
# Sum across cells
25-
2628
cell_sums = np.sum(raw_counts, axis=1)
2729

2830
# Mutiply by median and divide each cell by cell sum
@@ -34,7 +36,13 @@ def normalize_counts(raw_counts, pseudocount=0.1):
3436
return normed
3537

3638

37-
def convergence(clf, show=False, save=None, p_thresh=1e-7, voter_thresh=0.9):
39+
def convergence(
40+
clf: Any,
41+
show: bool = False,
42+
save: str | None = None,
43+
p_thresh: float = 1e-7,
44+
voter_thresh: float = 0.9,
45+
) -> Figure:
3846
"""Produce a plot showing number of cells called doublet per iter
3947
4048
Args:
@@ -81,15 +89,15 @@ def convergence(clf, show=False, save=None, p_thresh=1e-7, voter_thresh=0.9):
8189

8290

8391
def threshold(
84-
clf,
85-
show=False,
86-
save=None,
87-
log10=True,
88-
log_p_grid=None,
89-
voter_grid=None,
90-
v_step=2,
91-
p_step=5,
92-
):
92+
clf: Any,
93+
show: bool = False,
94+
save: str | None = None,
95+
log10: bool = True,
96+
log_p_grid: NDArray | None = None,
97+
voter_grid: NDArray | None = None,
98+
v_step: int = 2,
99+
p_step: int = 5,
100+
) -> Figure:
93101
"""Produce a plot showing number of cells called doublet across
94102
various thresholds
95103

0 commit comments

Comments
 (0)