Skip to content

Commit fc083a6

Browse files
MuhammedHasan and Muhammed Hasan Celik authored
faster modisco for decima (#23)
* modisco added * modisco slow version * io * add modisco-lite to setup dependencies * bug fix in bigwig writer * faster version * check path in env variables * avoid zero division warning and load metadata once * seqlet bed files * setup fix * fix for modisco cli doc * printing issue * motif fix * motif functions and seqlet calling update * attributions --------- Co-authored-by: Muhammed Hasan Celik <celik.muhammed_hasan@gene.com>
1 parent f9a5203 commit fc083a6

File tree

12 files changed

+399
-103
lines changed

12 files changed

+399
-103
lines changed

setup.cfg

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,9 +69,9 @@ install_requires =
6969
h5py
7070
pyBigWig
7171
pyarrow
72-
tangermeme>=1.0.0
7372
safetensors
74-
modisco-lite
73+
tangermeme>=1.0.0
74+
modisco-lite @ git+https://github.com/MuhammedHasan/tfmodisco-lite.git@faster-modisco
7575

7676
[options.packages.find]
7777
where = src

src/decima/cli/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,13 @@
88
from decima.cli.vep import cli_predict_variant_effect
99
from decima.cli.finetune import cli_finetune
1010
from decima.cli.vep import cli_vep_ensemble
11-
from decima.cli.modisco import cli_modisco_attributions, cli_modisco_patterns, cli_modisco_reports, cli_modisco
11+
from decima.cli.modisco import (
12+
cli_modisco_attributions,
13+
cli_modisco_patterns,
14+
cli_modisco_reports,
15+
cli_modisco_seqlet_bed,
16+
cli_modisco,
17+
)
1218

1319

1420
logger = logging.getLogger("decima")
@@ -28,7 +34,6 @@ def main():
2834
pass
2935

3036

31-
# main.add_command(cli_finetune, name="finetune")
3237
main.add_command(cli_predict_genes, name="predict-genes")
3338
main.add_command(cli_download, name="download")
3439
main.add_command(cli_attributions, name="attributions")
@@ -39,6 +44,7 @@ def main():
3944
main.add_command(cli_modisco_attributions, name="modisco-attributions")
4045
main.add_command(cli_modisco_patterns, name="modisco-patterns")
4146
main.add_command(cli_modisco_reports, name="modisco-reports")
47+
main.add_command(cli_modisco_seqlet_bed, name="modisco-seqlet-bed")
4248
main.add_command(cli_modisco, name="modisco")
4349

4450

src/decima/cli/modisco.py

Lines changed: 83 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
import click
22
from typing import List, Optional, Union
3-
from decima.interpret.modisco import predict_save_modisco_attributions, modisco_patterns, modisco_reports, modisco
3+
from decima.interpret.modisco import (
4+
predict_save_modisco_attributions,
5+
modisco_patterns,
6+
modisco_reports,
7+
modisco_seqlet_bed,
8+
modisco,
9+
)
410

511

612
@click.command()
@@ -31,7 +37,7 @@
3137
type=click.Choice(["specificity", "aggregate"]),
3238
default="specificity",
3339
show_default=True,
34-
help="Transform to use for the prediction.",
40+
help="Transform to use for attribution analysis.",
3541
)
3642
@click.option("--batch-size", type=int, default=2, show_default=True, help="Batch size for the prediction.")
3743
@click.option("--genes", type=str, default=None, help="Genes to predict. If not provided, all genes will be predicted.")
@@ -41,6 +47,10 @@
4147
default=None,
4248
help="Top n markers to predict. If not provided, all markers will be predicted.",
4349
)
50+
@click.option("--disable-bigwig", is_flag=True, help="Whether to disable bigwig file.")
51+
@click.option(
52+
"--disable-correct-grad-bigwig", is_flag=True, help="Whether to disable correct gradient for bigwig file."
53+
)
4454
@click.option("--device", type=str, default=None, help="Device to use. If not provided, the best device will be used.")
4555
@click.option(
4656
"--genome", type=str, default="hg38", show_default=True, help="Genome name or path to the genome fasta file."
@@ -57,6 +67,8 @@ def cli_modisco_attributions(
5767
batch_size: int = 4,
5868
genes: Optional[str] = None,
5969
top_n_markers: Optional[int] = None,
70+
disable_bigwig: bool = False,
71+
disable_correct_grad_bigwig: bool = False,
6072
device: Optional[str] = None,
6173
num_workers: int = 4,
6274
genome: str = "hg38",
@@ -81,6 +93,8 @@ def cli_modisco_attributions(
8193
batch_size=batch_size,
8294
genes=genes,
8395
top_n_markers=top_n_markers,
96+
bigwig=not disable_bigwig,
97+
correct_grad_bigwig=not disable_correct_grad_bigwig,
8498
num_workers=num_workers,
8599
device=device,
86100
genome=genome,
@@ -116,7 +130,8 @@ def cli_modisco_attributions(
116130
default=None,
117131
help="Top n markers to predict. If not provided, all markers will be predicted.",
118132
)
119-
# @click.option("--num-workers", type=int, default=4, show_default=True, help="Number of workers for the prediction.")
133+
@click.option("--correct-grad", type=bool, default=True, show_default=True, help="Whether to correct gradient.")
134+
@click.option("--num-workers", type=int, default=4, show_default=True, help="Number of workers for the prediction.")
120135
@click.option(
121136
"--max-seqlets", type=int, default=20_000, show_default=True, help="The maximum number of seqlets per metacluster."
122137
)
@@ -140,14 +155,14 @@ def cli_modisco_attributions(
140155
show_default=True,
141156
help="Additional flank added at the end of motif discovery.",
142157
)
143-
# @click.option("--stranded", is_flag=True, help="Treat input as stranded so do not add reverse-complement.")
144-
# @click.option(
145-
# "--pattern-type",
146-
# type=click.Choice(["both", "pos", "neg"]),
147-
# default="both",
148-
# show_default=True,
149-
# help="Which pattern signs to compute: both, pos, or neg.",
150-
# )
158+
@click.option("--stranded", is_flag=True, help="Treat input as stranded so do not add reverse-complement.")
159+
@click.option(
160+
"--pattern-type",
161+
type=click.Choice(["both", "pos", "neg"]),
162+
default="both",
163+
show_default=True,
164+
help="Which pattern signs to compute: both, pos, or neg.",
165+
)
151166
def cli_modisco_patterns(
152167
output_prefix: str,
153168
attributions: str,
@@ -157,7 +172,8 @@ def cli_modisco_patterns(
157172
metadata: Optional[str] = None,
158173
genes: Optional[List[str]] = None,
159174
top_n_markers: Optional[int] = None,
160-
# num_workers: int = 4,
175+
correct_grad: bool = True,
176+
num_workers: int = 4,
161177
# modisco parameters
162178
max_seqlets: int = 20_000,
163179
n_leiden: int = 16,
@@ -166,8 +182,8 @@ def cli_modisco_patterns(
166182
flank_size: int = 5,
167183
initial_flank_to_add: int = 10,
168184
final_flank_to_add: int = 0,
169-
# stranded: bool = False,
170-
# pattern_type: str = "both",
185+
stranded: bool = False,
186+
pattern_type: str = "both",
171187
):
172188
if isinstance(attributions, str):
173189
attributions = attributions.split(",")
@@ -184,7 +200,8 @@ def cli_modisco_patterns(
184200
metadata_anndata=metadata,
185201
genes=genes,
186202
top_n_markers=top_n_markers,
187-
# num_workers=num_workers,
203+
correct_grad=correct_grad,
204+
num_workers=num_workers,
188205
# modisco parameters
189206
max_seqlets_per_metacluster=max_seqlets,
190207
n_leiden_runs=n_leiden,
@@ -193,8 +210,8 @@ def cli_modisco_patterns(
193210
flank_size=flank_size,
194211
initial_flank_to_add=initial_flank_to_add,
195212
final_flank_to_add=final_flank_to_add,
196-
# stranded=stranded,
197-
# pattern_type=pattern_type,
213+
stranded=stranded,
214+
pattern_type=pattern_type,
198215
)
199216

200217

@@ -210,7 +227,7 @@ def cli_modisco_patterns(
210227
@click.option("--trim-threshold", type=float, default=0.3, show_default=True, help="Trim threshold.")
211228
@click.option("--trim-min-length", type=int, default=3, show_default=True, help="Trim minimum length.")
212229
@click.option("--tomtomlite", type=bool, default=False, show_default=True, help="Whether to use TomtomLite.")
213-
# @click.option("--num-workers", type=int, default=4, show_default=True, help="Number of workers for the prediction.")
230+
@click.option("--num-workers", type=int, default=4, show_default=True, help="Number of workers for the prediction.")
214231
def cli_modisco_reports(
215232
output_prefix: str,
216233
modisco_h5: str,
@@ -221,7 +238,7 @@ def cli_modisco_reports(
221238
trim_threshold: float,
222239
trim_min_length: int,
223240
tomtomlite: bool,
224-
# num_workers: int,
241+
num_workers: int,
225242
):
226243
modisco_reports(
227244
output_prefix=output_prefix,
@@ -233,7 +250,26 @@ def cli_modisco_reports(
233250
trim_threshold=trim_threshold,
234251
trim_min_length=trim_min_length,
235252
tomtomlite=tomtomlite,
236-
# num_workers=num_workers,
253+
num_workers=num_workers,
254+
)
255+
256+
257+
@click.command()
@click.option("-o", "--output-prefix", type=str, required=True, help="Prefix path to the output files.")
@click.option(
    # Kebab-case spelling matches the other options in this file; the original
    # ``--modisco_h5`` spelling is kept as an alias so existing invocations keep
    # working. The bare third name pins the Python parameter name.
    "--modisco-h5",
    "--modisco_h5",
    "modisco_h5",
    type=click.Path(exists=True),
    required=True,
    help="Path to the modisco HDF5 file.",
)
@click.option("--metadata", type=str, default=None, help="Path to the metadata anndata file.")
@click.option("--trim-threshold", type=float, default=0.2, show_default=True, help="Trim threshold.")
def cli_modisco_seqlet_bed(
    output_prefix: str,
    modisco_h5: str,
    metadata: Optional[str] = None,
    trim_threshold: float = 0.2,
):
    """Write seqlet bed output for a modisco HDF5 file."""
    # Thin CLI wrapper: every option is forwarded unchanged. ``--metadata`` is
    # passed under the keyword ``metadata_anndata`` expected by the library call.
    # NOTE(review): the exact files written under ``output_prefix`` are decided by
    # ``modisco_seqlet_bed`` and are not visible from this module — confirm there.
    modisco_seqlet_bed(
        output_prefix=output_prefix,
        modisco_h5=modisco_h5,
        metadata_anndata=metadata,
        trim_threshold=trim_threshold,
    )
238274

239275

@@ -276,6 +312,7 @@ def cli_modisco_reports(
276312
default=None,
277313
help="Top n markers to predict. If not provided, all markers will be predicted.",
278314
)
315+
@click.option("--correct-grad", type=bool, default=True, show_default=True, help="Whether to correct gradient.")
279316
@click.option(
280317
"--device",
281318
type=str,
@@ -310,14 +347,14 @@ def cli_modisco_reports(
310347
show_default=True,
311348
help="Additional flank added at the end of motif discovery.",
312349
)
313-
# @click.option("--stranded", is_flag=True, help="Treat input as stranded so do not add reverse-complement.")
314-
# @click.option(
315-
# "--pattern-type",
316-
# type=click.Choice(["both", "pos", "neg"]),
317-
# default="both",
318-
# show_default=True,
319-
# help="Which pattern signs to compute: both, pos, or neg.",
320-
# )
350+
@click.option("--stranded", is_flag=True, help="Treat input as stranded so do not add reverse-complement.")
351+
@click.option(
352+
"--pattern-type",
353+
type=click.Choice(["both", "pos", "neg"]),
354+
default="both",
355+
show_default=True,
356+
help="Which pattern signs to compute: both, pos, or neg.",
357+
)
321358
@click.option(
322359
"--meme-motif-db", type=str, default="hocomoco_v13", show_default=True, help="Path to the MEME motif database."
323360
)
@@ -327,6 +364,13 @@ def cli_modisco_reports(
327364
@click.option("--trim-threshold", type=float, default=0.3, show_default=True, help="Trim threshold.")
328365
@click.option("--trim-min-length", type=int, default=3, show_default=True, help="Trim minimum length.")
329366
@click.option("--tomtomlite", type=bool, default=False, show_default=True, help="Whether to use TomtomLite.")
367+
@click.option(
368+
"--seqlet-motif-trim-threshold",
369+
type=float,
370+
default=0.2,
371+
show_default=True,
372+
help="Trim threshold for motifs in seqlets bed file.",
373+
)
330374
def cli_modisco(
331375
output_prefix: str,
332376
tasks: Optional[List[str]] = None,
@@ -336,8 +380,9 @@ def cli_modisco(
336380
metadata: Optional[str] = None,
337381
method: str = "saliency",
338382
batch_size: int = 4,
339-
genes: Optional[List[str]] = None, # TODO: list of genes
383+
genes: Optional[str] = None,
340384
top_n_markers: Optional[int] = None,
385+
correct_grad: bool = True,
341386
device: Optional[str] = None,
342387
num_workers: int = 4,
343388
genome: str = "hg38",
@@ -349,8 +394,8 @@ def cli_modisco(
349394
flank_size: int = 5,
350395
initial_flank_to_add: int = 10,
351396
final_flank_to_add: int = 0,
352-
# stranded: bool = False,
353-
# pattern_type: str = "both",
397+
stranded: bool = False,
398+
pattern_type: str = "both",
354399
# reports parameters
355400
meme_motif_db: str = "hocomoco_v13",
356401
img_path_suffix: str = "",
@@ -359,6 +404,8 @@ def cli_modisco(
359404
trim_threshold: float = 0.3,
360405
trim_min_length: int = 3,
361406
tomtomlite: bool = False,
407+
# seqlet thresholds
408+
seqlet_motif_trim_threshold: float = 0.2,
362409
):
363410
if model in ["0", "1", "2", "3"]:
364411
model = int(model)
@@ -380,6 +427,7 @@ def cli_modisco(
380427
batch_size=batch_size,
381428
genes=genes,
382429
top_n_markers=top_n_markers,
430+
correct_grad=correct_grad,
383431
device=device,
384432
num_workers=num_workers,
385433
genome=genome,
@@ -391,8 +439,8 @@ def cli_modisco(
391439
flank_size=flank_size,
392440
initial_flank_to_add=initial_flank_to_add,
393441
final_flank_to_add=final_flank_to_add,
394-
# stranded=stranded,
395-
# pattern_type=pattern_type,
442+
stranded=stranded,
443+
pattern_type=pattern_type,
396444
# reports parameters
397445
img_path_suffix=img_path_suffix,
398446
meme_motif_db=meme_motif_db,
@@ -401,4 +449,6 @@ def cli_modisco(
401449
trim_threshold=trim_threshold,
402450
trim_min_length=trim_min_length,
403451
tomtomlite=tomtomlite,
452+
# seqlet thresholds
453+
seqlet_motif_trim_threshold=seqlet_motif_trim_threshold,
404454
)

0 commit comments

Comments
 (0)