
Commit 3534bd5

Merge pull request #136 from JonathanShor/adam_changes

move to poetry for dist, remove excess code, use pip installable pheno

2 parents b784798 + efdd775

File tree

11 files changed: +3006 −453 lines

.github/workflows/test.yml

Lines changed: 44 additions & 0 deletions
New file:

```yaml
# This workflow will install Python dependencies, run tests and lint with a variety of Python versions
# For more information see: https://help.github.com/actions/language-and-framework-guides/using-python-with-github-actions

name: doubletdetection

on:
  push:
    branches: [master]
  pull_request:
    branches: [master]

jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: [3.6, 3.7]

    steps:
      - uses: actions/checkout@v2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Cache pip
        uses: actions/cache@v2
        with:
          path: ~/.cache/pip
          key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
          restore-keys: |
            ${{ runner.os }}-pip-
      - name: Install dependencies
        run: |
          pip install --quiet .[dev]
      - name: Lint with flake8
        run: |
          flake8
      - name: Format with black
        run: |
          black --check .
      - name: Test with pytest
        run: |
          pytest
```

README.md

Lines changed: 11 additions & 1 deletion
````diff
@@ -2,11 +2,21 @@
 
 [![DOI](https://zenodo.org/badge/86256007.svg)](https://zenodo.org/badge/latestdoi/86256007)
 [![Documentation Status](https://readthedocs.org/projects/doubletdetection/badge/?version=latest)](https://doubletdetection.readthedocs.io/en/latest/?badge=latest)
+[![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/python/black)
+![Build Status](https://github.com/JonathanShor/DoubletDetection/workflows/doubletdetection/badge.svg)
 
 DoubletDetection is a Python3 package to detect doublets (technical errors) in single-cell RNA-seq count matrices.
 
 ## Installing DoubletDetection
 
+Install from PyPI
+
+```bash
+pip install doubletdetection
+```
+
+Install from source
+
 ```bash
 git clone https://github.com/JonathanShor/DoubletDetection.git
 cd DoubletDetection
@@ -41,7 +51,7 @@ The classifier works best when
 
 In `v2.5` we have added a new experimental clustering method (`scanpy`'s Louvain clustering) that is much faster than phenograph. We are still validating results from this new clustering. Please see the notebook below for an example of using this new feature.
 
-See our [jupyter notebook](https://nbviewer.jupyter.org/github/JonathanShor/DoubletDetection/blob/master/tests/notebooks/PBMC_8k_vignette.ipynb) for an example on 8k PBMCs from 10x.
+See our [jupyter notebook](https://nbviewer.jupyter.org/github/JonathanShor/DoubletDetection/blob/master/tests/notebooks/PBMC_10k_vignette.ipynb) for an example on 8k PBMCs from 10x.
 
 ## Obtaining data
````
doubletdetection/__init__.py

Lines changed: 13 additions & 1 deletion
```diff
@@ -1,2 +1,14 @@
-from .doubletdetection import BoostClassifier, load_mtx, load_10x_h5
+from .doubletdetection import BoostClassifier
 from . import plot
+
+
+# https://github.com/python-poetry/poetry/pull/2366#issuecomment-652418094
+# https://github.com/python-poetry/poetry/issues/144#issuecomment-623927302
+try:
+    import importlib.metadata as importlib_metadata
+except ModuleNotFoundError:
+    import importlib_metadata
+package_name = "doubletdetection"
+__version__ = importlib_metadata.version(package_name)
+
+__all__ = ["BoostClassifier", "plot"]
```

doubletdetection/doubletdetection.py

Lines changed: 35 additions & 75 deletions
```diff
@@ -1,70 +1,22 @@
 """Doublet detection in single-cell RNA-seq data."""
 
 import collections
+import io
 import warnings
+from contextlib import redirect_stdout
 
 import anndata
 import numpy as np
 import phenograph
 import scipy.sparse as sp_sparse
-import tables
 import scanpy as sc
-from scipy.io import mmread
 from scipy.sparse import csr_matrix
 from scipy.stats import hypergeom
 from sklearn.utils import check_array
 from sklearn.utils.sparsefuncs_fast import inplace_csr_row_normalize_l1
 from tqdm.auto import tqdm
 
 
-def load_10x_h5(file, genome):
-    """Load count matrix in 10x H5 format
-    Adapted from:
-    https://support.10xgenomics.com/single-cell-gene-expression/software/
-    pipelines/latest/advanced/h5_matrices
-
-    Args:
-        file (str): Path to H5 file
-        genome (str): genome, top level h5 group
-
-    Returns:
-        ndarray: Raw count matrix.
-        ndarray: Barcodes
-        ndarray: Gene names
-    """
-
-    with tables.open_file(file, "r") as f:
-        try:
-            group = f.get_node(f.root, genome)
-        except tables.NoSuchNodeError:
-            print("That genome does not exist in this file.")
-            return None
-        # gene_ids = getattr(group, 'genes').read()
-        gene_names = getattr(group, "gene_names").read()
-        barcodes = getattr(group, "barcodes").read()
-        data = getattr(group, "data").read()
-        indices = getattr(group, "indices").read()
-        indptr = getattr(group, "indptr").read()
-        shape = getattr(group, "shape").read()
-        matrix = sp_sparse.csc_matrix((data, indices, indptr), shape=shape)
-
-    return matrix, barcodes, gene_names
-
-
-def load_mtx(file):
-    """Load count matrix in mtx format
-
-    Args:
-        file (str): Path to mtx file
-
-    Returns:
-        ndarray: Raw count matrix.
-    """
-    raw_counts = np.transpose(mmread(file))
-
-    return raw_counts.tocsc()
-
-
 class BoostClassifier:
     """Classifier for doublets in single-cell RNA-seq data.
```
```diff
@@ -162,8 +114,6 @@ def __init__(
         if use_phenograph is True:
             if "prune" not in phenograph_parameters:
                 phenograph_parameters["prune"] = True
-            if ("verbosity" not in phenograph_parameters) and (not self.verbose):
-                phenograph_parameters["verbosity"] = 1
             self.phenograph_parameters = phenograph_parameters
             if (self.n_iters == 1) and (phenograph_parameters.get("prune") is True):
                 warn_msg = (
```
```diff
@@ -238,9 +188,7 @@ def fit(self, raw_counts):
         self.all_log_p_values_ = np.zeros((self.n_iters, self._num_cells))
         all_communities = np.zeros((self.n_iters, self._num_cells))
         all_parents = []
-        all_synth_communities = np.zeros(
-            (self.n_iters, int(self.boost_rate * self._num_cells))
-        )
+        all_synth_communities = np.zeros((self.n_iters, int(self.boost_rate * self._num_cells)))
 
         for i in tqdm(range(self.n_iters)):
             if self.verbose:
```
```diff
@@ -284,9 +232,7 @@ def predict(self, p_thresh=1e-7, voter_thresh=0.9):
         """
         log_p_thresh = np.log(p_thresh)
         if self.n_iters > 1:
-            with np.errstate(
-                invalid="ignore"
-            ):  # Silence numpy warning about NaN comparison
+            with np.errstate(invalid="ignore"):  # Silence numpy warning about NaN comparison
                 self.voting_average_ = np.mean(
                     np.ma.masked_invalid(self.all_log_p_values_) <= log_p_thresh, axis=0
                 )
```
```diff
@@ -298,20 +244,34 @@ def predict(self, p_thresh=1e-7, voter_thresh=0.9):
             # Find a cutoff score
             potential_cutoffs = np.unique(self.all_scores_[~np.isnan(self.all_scores_)])
             if len(potential_cutoffs) > 1:
-                max_dropoff = (
-                    np.argmax(potential_cutoffs[1:] - potential_cutoffs[:-1]) + 1
-                )
+                max_dropoff = np.argmax(potential_cutoffs[1:] - potential_cutoffs[:-1]) + 1
             else:  # Most likely pathological dataset, only one (or no) clusters
                 max_dropoff = 0
             self.suggested_score_cutoff_ = potential_cutoffs[max_dropoff]
-            with np.errstate(
-                invalid="ignore"
-            ):  # Silence numpy warning about NaN comparison
+            with np.errstate(invalid="ignore"):  # Silence numpy warning about NaN comparison
                 self.labels_ = self.all_scores_[0, :] >= self.suggested_score_cutoff_
             self.labels_[np.isnan(self.all_scores_)[0, :]] = np.nan
 
         return self.labels_
 
+    def doublet_score(self):
+        """Produce doublet scores
+
+        The doublet score is the average negative log p-value of doublet enrichment
+        averaged over the iterations. Higher means more likely to be doublet.
+
+        Returns:
+            scores (ndarray, ndims=1): Average negative log p-value over iterations
+        """
+
+        if self.n_iters > 1:
+            with np.errstate(invalid="ignore"):  # Silence numpy warning about NaN comparison
+                avg_log_p = np.mean(np.ma.masked_invalid(self.all_log_p_values_), axis=0)
+        else:
+            avg_log_p = self.all_log_p_values_[0]
+
+        return -avg_log_p
+
     def _one_fit(self):
         if self.verbose:
             print("\nCreating synthetic doublets...")
```
```diff
@@ -347,19 +307,22 @@ def _one_fit(self):
         if self.verbose:
             print("Clustering augmented data set...\n")
         if self.use_phenograph:
-            fullcommunities, _, _ = phenograph.cluster(
-                aug_counts.obsm["X_pca"], **self.phenograph_parameters
-            )
+            f = io.StringIO()
+            with redirect_stdout(f):
+                fullcommunities, _, _ = phenograph.cluster(
+                    aug_counts.obsm["X_pca"], **self.phenograph_parameters
+                )
+            out = f.getvalue()
+            if self.verbose:
+                print(out)
         else:
             sc.pp.neighbors(
                 aug_counts,
                 random_state=self.random_state,
                 method="umap",
                 n_neighbors=10,
             )
-            sc.tl.louvain(
-                aug_counts, random_state=self.random_state, resolution=4, directed=False
-            )
+            sc.tl.louvain(aug_counts, random_state=self.random_state, resolution=4, directed=False)
             fullcommunities = np.array(aug_counts.obs["louvain"], dtype=int)
         min_ID = min(fullcommunities)
         self.communities_ = fullcommunities[: self._num_cells]
```
```diff
@@ -380,8 +343,7 @@ def _one_fit(self):
         orig_cells_per_comm = collections.Counter(self.communities_)
         community_IDs = orig_cells_per_comm.keys()
         community_scores = {
-            i: float(synth_cells_per_comm[i])
-            / (synth_cells_per_comm[i] + orig_cells_per_comm[i])
+            i: float(synth_cells_per_comm[i]) / (synth_cells_per_comm[i] + orig_cells_per_comm[i])
             for i in community_IDs
         }
         scores = np.array([community_scores[i] for i in self.communities_])
```
```diff
@@ -412,9 +374,7 @@ def _createDoublets(self):
         num_synths = int(self.boost_rate * self._num_cells)
 
         # Parent indices
-        choices = np.random.choice(
-            self._num_cells, size=(num_synths, 2), replace=self.replace
-        )
+        choices = np.random.choice(self._num_cells, size=(num_synths, 2), replace=self.replace)
         parents = [list(p) for p in choices]
 
         parent0 = self._raw_counts[choices[:, 0], :]
```
