Skip to content

Commit b02eb55

Browse files
committed
added cellsweep base
1 parent 0391b45 commit b02eb55

File tree

2 files changed

+284
-5
lines changed

2 files changed

+284
-5
lines changed

kb_python/count.py

Lines changed: 94 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,60 @@
8686

8787
INSPECT_PARSER = re.compile(r'^.*?(?P<count>[0-9]+)')
8888

89+
def check_kwargs(kwargs):
90+
if not kwargs:
91+
return
92+
try:
93+
import inspect as inspect_module
94+
import cellsweep
95+
sig = inspect_module.signature(cellsweep.denoise_count_matrix.__wrapped__)
96+
cellsweep_arg_names = list(sig.parameters.keys())
97+
if any(param not in cellsweep_arg_names for param in kwargs):
98+
invalid_params = [
99+
param for param in kwargs
100+
if param not in cellsweep_arg_names
101+
]
102+
raise TypeError(
103+
f"count() got an unexpected keyword argument(s): {', '.join(invalid_params)}"
104+
)
105+
except Exception as e:
106+
pass
107+
108+
def run_cellsweep(counts_dir, out_dir, threads=2, kwargs=None):
109+
try:
110+
import inspect as inspect_module
111+
import cellsweep
112+
sig = inspect_module.signature(cellsweep.denoise_count_matrix.__wrapped__)
113+
cellsweep_arg_names = list(sig.parameters.keys())
114+
cellsweep_kwargs = {}
115+
if kwargs:
116+
cellsweep_kwargs = {k: v for k, v in kwargs.items() if k in cellsweep_arg_names}
117+
cellsweep_counts_dir = os.path.join(out_dir, "counts_swept")
118+
cellsweep_adata_path = os.path.join(cellsweep_counts_dir, "swept_adata.h5ad")
119+
120+
matrix_path = os.path.join(counts_dir, "counts_unfiltered", "cells_x_genes.mtx")
121+
barcodes_path = os.path.join(counts_dir, "counts_unfiltered", "cells_x_genes.barcodes.txt")
122+
genes_path = os.path.join(counts_dir, "counts_unfiltered", "cells_x_genes.genes.names.txt")
123+
adata = import_matrix_as_anndata(matrix_path, barcodes_path, genes_path)
124+
# adata = cellsweep.utils.read_kb_mtx_as_adata(counts_dir)
125+
126+
# TODO:
127+
#* 1. think of how to do automatic celltyping
128+
#* 2. implement the requirement for expected_cells or umi_cutoff, or have a way to auto-detect
129+
130+
# add celltypes
131+
# adata = cs_utils.determine_cell_types(adata, model_pkl=model_pkl, filter_empty=True, expected_cells=expected_cells, celltypist_convert=celltypist_convert, celltypist_map_file=celltypist_map_file, verbose=verbose)
132+
133+
_ = cellsweep.denoise_count_matrix(
134+
adata=adata,
135+
adata_out=cellsweep_adata_path,
136+
threads=threads,
137+
**cellsweep_kwargs
138+
)
139+
return cellsweep_adata_path
140+
except Exception as e:
141+
logger.error(f"Error running cellsweep: {e}")
142+
return None
89143

90144
def make_transcript_t2g(txnames_path: str, out_path: str) -> str:
91145
"""Make a two-column t2g file from a transcripts file
@@ -1255,6 +1309,8 @@ def count(
12551309
quant_umis: bool = False,
12561310
keep_flags: bool = False,
12571311
exact_barcodes: bool = False,
1312+
remove_ambient: bool = False,
1313+
**kwargs
12581314
) -> Dict[str, Union[str, Dict[str, str]]]:
12591315
"""Generates count matrices for single-cell RNA seq.
12601316
@@ -1332,13 +1388,18 @@ def count(
13321388
quant_umis: Whether to run quant-tcc when there are UMIs, defaults to `False`
13331389
keep_flags: Preserve flag column when sorting BUS file, defaults to `False`
13341390
exact_barcodes: Use exact match for 'correcting' barcodes to on-list, defaults to `False`
1391+
remove_ambient: Whether to remove ambient RNA using CellSweep, defaults to `False`
1392+
**kwargs: Additional keyword arguments to pass to CellSweep
13351393
13361394
Returns:
13371395
Dictionary containing paths to generated files
13381396
"""
13391397
STATS.start()
13401398
is_batch = isinstance(fastqs, str)
13411399

1400+
#* kwargs is only added for cellsweep, so check accordingly
1401+
check_kwargs(kwargs)
1402+
13421403
results = {}
13431404
make_directory(out_dir)
13441405
unfiltered_results = results.setdefault('unfiltered', {})
@@ -1755,6 +1816,10 @@ def update_results_with_suffix(current_results, new_results, suffix):
17551816
temp_dir=temp_dir
17561817
)
17571818
unfiltered_results.update(report_result)
1819+
1820+
if remove_ambient:
1821+
logger.info('Removing ambient RNA using CellSweep')
1822+
results['swept_counts'] = run_cellsweep(counts_dir=counts_dir, out_dir=out_dir, threads=threads, kwargs=kwargs)
17581823

17591824
# Delete intermediate BUS files if requested
17601825
if delete_bus:
@@ -1841,6 +1906,8 @@ def count_nac(
18411906
quant_umis: bool = False,
18421907
keep_flags: bool = False,
18431908
exact_barcodes: bool = False,
1909+
remove_ambient: bool = False,
1910+
**kwargs
18441911
) -> Dict[str, Union[Dict[str, str], str]]:
18451912
"""Generates RNA velocity matrices for single-cell RNA seq.
18461913
@@ -1917,13 +1984,18 @@ def count_nac(
19171984
quant_umis: Whether to run quant-tcc when there are UMIs, defaults to `False`
19181985
keep_flags: Preserve flag column when sorting BUS file, defaults to `False`
19191986
exact_barcodes: Use exact match for 'correcting' barcodes to on-list, defaults to `False`
1987+
remove_ambient: Whether to remove ambient RNA using CellSweep, defaults to `False`
1988+
**kwargs: Additional keyword arguments to pass to CellSweep
19201989
19211990
Returns:
19221991
Dictionary containing path to generated index
19231992
"""
19241993
STATS.start()
19251994
is_batch = isinstance(fastqs, str)
19261995

1996+
#* kwargs is only added for cellsweep, so check accordingly
1997+
check_kwargs(kwargs)
1998+
19271999
results = {}
19282000
make_directory(out_dir)
19292001
unfiltered_results = results.setdefault('unfiltered', {})
@@ -2483,6 +2555,10 @@ def update_results_with_suffix(current_results, new_results, suffix):
24832555
logger.warning(
24842556
'Plots for TCC matrices have not yet been implemented. The HTML report will not contain any plots.'
24852557
)
2558+
2559+
if remove_ambient:
2560+
logger.info('Removing ambient RNA using CellSweep')
2561+
results['swept_counts'] = run_cellsweep(counts_dir=counts_dir, out_dir=out_dir, threads=threads, kwargs=kwargs)
24862562

24872563
# Delete intermediate BUS files if requested
24882564
if delete_bus:
@@ -2541,6 +2617,8 @@ def count_velocity(
25412617
strand: Optional[Literal['unstranded', 'forward', 'reverse']] = None,
25422618
umi_gene: bool = False,
25432619
em: bool = False,
2620+
remove_ambient: bool = False,
2621+
**kwargs
25442622
) -> Dict[str, Union[Dict[str, str], str]]:
25452623
"""Generates RNA velocity matrices (DEPRECATED).
25462624
@@ -2588,7 +2666,8 @@ def count_velocity(
25882666
`False`
25892667
em: Whether to estimate gene abundances using EM algorithm, defaults to
25902668
`False`
2591-
2669+
remove_ambient: Whether to remove ambient RNA using CellSweep, defaults to `False`
2670+
**kwargs: Additional keyword arguments to pass to CellSweep
25922671
Returns:
25932672
Dictionary containing path to generated index
25942673
"""
@@ -2597,6 +2676,9 @@ def count_velocity(
25972676
BUS_CDNA_PREFIX = 'spliced'
25982677
BUS_INTRON_PREFIX = 'unspliced'
25992678

2679+
#* kwargs is only added for cellsweep, so check accordingly
2680+
check_kwargs(kwargs)
2681+
26002682
results = {}
26012683
make_directory(out_dir)
26022684
unfiltered_results = results.setdefault('unfiltered', {})
@@ -2893,6 +2975,10 @@ def count_velocity(
28932975
stats_path = STATS.save(os.path.join(out_dir, KB_INFO_FILENAME))
28942976
results.update({'stats': stats_path})
28952977

2978+
if remove_ambient:
2979+
logger.info('Removing ambient RNA using CellSweep')
2980+
results['swept_counts'] = run_cellsweep(counts_dir=counts_dir, out_dir=out_dir, threads=threads, kwargs=kwargs)
2981+
28962982
# Reports
28972983
nb_path = os.path.join(out_dir, REPORT_NOTEBOOK_FILENAME)
28982984
html_path = os.path.join(out_dir, REPORT_HTML_FILENAME)
@@ -2962,6 +3048,8 @@ def count_velocity_smartseq3(
29623048
by_name: bool = False,
29633049
inspect: bool = True,
29643050
strand: Optional[Literal['unstranded', 'forward', 'reverse']] = None,
3051+
remove_ambient: bool = False,
3052+
**kwargs
29653053
) -> Dict[str, Union[str, Dict[str, str]]]:
29663054
"""Generates count matrices for Smartseq3 (DEPRECATED).
29673055
@@ -2988,6 +3076,8 @@ def count_velocity_smartseq3(
29883076
inspect: Whether or not to inspect the output BUS file and generate
29893077
the inspect.json
29903078
strand: Strandedness, defaults to `None`
3079+
remove_ambient: Whether to remove ambient RNA using CellSweep, defaults to `False`
3080+
**kwargs: Additional keyword arguments to pass to CellSweep
29913081
29923082
Returns:
29933083
Dictionary containing paths to generated files
@@ -2997,6 +3087,9 @@ def count_velocity_smartseq3(
29973087
BUS_CDNA_PREFIX = 'spliced'
29983088
BUS_INTRON_PREFIX = 'unspliced'
29993089

3090+
#* kwargs is only added for cellsweep, so check accordingly
3091+
check_kwargs(kwargs)
3092+
30003093
results = {}
30013094
make_directory(out_dir)
30023095
unfiltered_results = results.setdefault('unfiltered', {})

0 commit comments

Comments
 (0)