Skip to content
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,9 @@ install_requires =
h5py
pyBigWig
pyarrow
tangermeme>=1.0.0
safetensors
modisco-lite
tangermeme>=1.0.0
modisco-lite @ git+https://github.com/MuhammedHasan/tfmodisco-lite.git@faster-modisco

[options.packages.find]
where = src
Expand Down
10 changes: 8 additions & 2 deletions src/decima/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,13 @@
from decima.cli.vep import cli_predict_variant_effect
from decima.cli.finetune import cli_finetune
from decima.cli.vep import cli_vep_ensemble
from decima.cli.modisco import cli_modisco_attributions, cli_modisco_patterns, cli_modisco_reports, cli_modisco
from decima.cli.modisco import (
cli_modisco_attributions,
cli_modisco_patterns,
cli_modisco_reports,
cli_modisco_seqlet_bed,
cli_modisco,
)


logger = logging.getLogger("decima")
Expand All @@ -28,7 +34,6 @@ def main():
pass


# main.add_command(cli_finetune, name="finetune")
main.add_command(cli_predict_genes, name="predict-genes")
main.add_command(cli_download, name="download")
main.add_command(cli_attributions, name="attributions")
Expand All @@ -39,6 +44,7 @@ def main():
main.add_command(cli_modisco_attributions, name="modisco-attributions")
main.add_command(cli_modisco_patterns, name="modisco-patterns")
main.add_command(cli_modisco_reports, name="modisco-reports")
main.add_command(cli_modisco_seqlet_bed, name="modisco-seqlet-bed")
main.add_command(cli_modisco, name="modisco")


Expand Down
116 changes: 83 additions & 33 deletions src/decima/cli/modisco.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
import click
from typing import List, Optional, Union
from decima.interpret.modisco import predict_save_modisco_attributions, modisco_patterns, modisco_reports, modisco
from decima.interpret.modisco import (
predict_save_modisco_attributions,
modisco_patterns,
modisco_reports,
modisco_seqlet_bed,
modisco,
)


@click.command()
Expand Down Expand Up @@ -31,7 +37,7 @@
type=click.Choice(["specificity", "aggregate"]),
default="specificity",
show_default=True,
help="Transform to use for the prediction.",
help="Transform to use for attribution analysis.",
)
@click.option("--batch-size", type=int, default=2, show_default=True, help="Batch size for the prediction.")
@click.option("--genes", type=str, default=None, help="Genes to predict. If not provided, all genes will be predicted.")
Expand All @@ -41,6 +47,10 @@
default=None,
help="Top n markers to predict. If not provided, all markers will be predicted.",
)
@click.option("--disable-bigwig", is_flag=True, help="Whether to disable bigwig file.")
@click.option(
"--disable-correct-grad-bigwig", is_flag=True, help="Whether to disable correct gradient for bigwig file."
)
@click.option("--device", type=str, default=None, help="Device to use. If not provided, the best device will be used.")
@click.option(
"--genome", type=str, default="hg38", show_default=True, help="Genome name or path to the genome fasta file."
Expand All @@ -57,6 +67,8 @@ def cli_modisco_attributions(
batch_size: int = 4,
genes: Optional[str] = None,
top_n_markers: Optional[int] = None,
disable_bigwig: bool = False,
disable_correct_grad_bigwig: bool = False,
device: Optional[str] = None,
num_workers: int = 4,
genome: str = "hg38",
Expand All @@ -81,6 +93,8 @@ def cli_modisco_attributions(
batch_size=batch_size,
genes=genes,
top_n_markers=top_n_markers,
bigwig=not disable_bigwig,
correct_grad_bigwig=not disable_correct_grad_bigwig,
num_workers=num_workers,
device=device,
genome=genome,
Expand Down Expand Up @@ -116,7 +130,8 @@ def cli_modisco_attributions(
default=None,
help="Top n markers to predict. If not provided, all markers will be predicted.",
)
# @click.option("--num-workers", type=int, default=4, show_default=True, help="Number of workers for the prediction.")
@click.option("--correct-grad", type=bool, default=True, show_default=True, help="Whether to correct gradient.")
@click.option("--num-workers", type=int, default=4, show_default=True, help="Number of workers for the prediction.")
@click.option(
"--max-seqlets", type=int, default=20_000, show_default=True, help="The maximum number of seqlets per metacluster."
)
Expand All @@ -140,14 +155,14 @@ def cli_modisco_attributions(
show_default=True,
help="Additional flank added at the end of motif discovery.",
)
# @click.option("--stranded", is_flag=True, help="Treat input as stranded so do not add reverse-complement.")
# @click.option(
# "--pattern-type",
# type=click.Choice(["both", "pos", "neg"]),
# default="both",
# show_default=True,
# help="Which pattern signs to compute: both, pos, or neg.",
# )
@click.option("--stranded", is_flag=True, help="Treat input as stranded so do not add reverse-complement.")
@click.option(
"--pattern-type",
type=click.Choice(["both", "pos", "neg"]),
default="both",
show_default=True,
help="Which pattern signs to compute: both, pos, or neg.",
)
def cli_modisco_patterns(
output_prefix: str,
attributions: str,
Expand All @@ -157,7 +172,8 @@ def cli_modisco_patterns(
metadata: Optional[str] = None,
genes: Optional[List[str]] = None,
top_n_markers: Optional[int] = None,
# num_workers: int = 4,
correct_grad: bool = True,
num_workers: int = 4,
# modisco parameters
max_seqlets: int = 20_000,
n_leiden: int = 16,
Expand All @@ -166,8 +182,8 @@ def cli_modisco_patterns(
flank_size: int = 5,
initial_flank_to_add: int = 10,
final_flank_to_add: int = 0,
# stranded: bool = False,
# pattern_type: str = "both",
stranded: bool = False,
pattern_type: str = "both",
):
if isinstance(attributions, str):
attributions = attributions.split(",")
Expand All @@ -184,7 +200,8 @@ def cli_modisco_patterns(
metadata_anndata=metadata,
genes=genes,
top_n_markers=top_n_markers,
# num_workers=num_workers,
correct_grad=correct_grad,
num_workers=num_workers,
# modisco parameters
max_seqlets_per_metacluster=max_seqlets,
n_leiden_runs=n_leiden,
Expand All @@ -193,8 +210,8 @@ def cli_modisco_patterns(
flank_size=flank_size,
initial_flank_to_add=initial_flank_to_add,
final_flank_to_add=final_flank_to_add,
# stranded=stranded,
# pattern_type=pattern_type,
stranded=stranded,
pattern_type=pattern_type,
)


Expand All @@ -210,7 +227,7 @@ def cli_modisco_patterns(
@click.option("--trim-threshold", type=float, default=0.3, show_default=True, help="Trim threshold.")
@click.option("--trim-min-length", type=int, default=3, show_default=True, help="Trim minimum length.")
@click.option("--tomtomlite", type=bool, default=False, show_default=True, help="Whether to use TomtomLite.")
# @click.option("--num-workers", type=int, default=4, show_default=True, help="Number of workers for the prediction.")
@click.option("--num-workers", type=int, default=4, show_default=True, help="Number of workers for the prediction.")
def cli_modisco_reports(
output_prefix: str,
modisco_h5: str,
Expand All @@ -221,7 +238,7 @@ def cli_modisco_reports(
trim_threshold: float,
trim_min_length: int,
tomtomlite: bool,
# num_workers: int,
num_workers: int,
):
modisco_reports(
output_prefix=output_prefix,
Expand All @@ -233,7 +250,26 @@ def cli_modisco_reports(
trim_threshold=trim_threshold,
trim_min_length=trim_min_length,
tomtomlite=tomtomlite,
# num_workers=num_workers,
num_workers=num_workers,
)


@click.command()
@click.option("-o", "--output-prefix", type=str, required=True, help="Prefix path to the output files.")
# The hyphenated spelling matches every other option of these commands; the
# original underscore spelling is kept as an alias so existing invocations
# keep working. The bare "modisco_h5" pins the Python parameter name.
@click.option(
    "--modisco-h5",
    "--modisco_h5",
    "modisco_h5",
    type=click.Path(exists=True),
    required=True,
    help="Path to the modisco HDF5 file.",
)
@click.option("--metadata", type=str, default=None, help="Path to the metadata anndata file.")
@click.option("--trim-threshold", type=float, default=0.2, show_default=True, help="Trim threshold.")
def cli_modisco_seqlet_bed(
    output_prefix: str,
    modisco_h5: str,
    metadata: Optional[str] = None,
    trim_threshold: float = 0.2,
):
    """Run `modisco_seqlet_bed` on a modisco HDF5 file.

    Thin command-line wrapper: every option is forwarded unchanged to
    `decima.interpret.modisco.modisco_seqlet_bed`, with `--metadata` passed
    as its `metadata_anndata` argument. Output files are written under the
    given output prefix.
    """
    # NOTE(review): the exact files produced (presumably a BED file of
    # seqlets) are decided by modisco_seqlet_bed, not by this wrapper.
    modisco_seqlet_bed(
        output_prefix=output_prefix,
        modisco_h5=modisco_h5,
        metadata_anndata=metadata,
        trim_threshold=trim_threshold,
    )


Expand Down Expand Up @@ -276,6 +312,7 @@ def cli_modisco_reports(
default=None,
help="Top n markers to predict. If not provided, all markers will be predicted.",
)
@click.option("--correct-grad", type=bool, default=True, show_default=True, help="Whether to correct gradient.")
@click.option(
"--device",
type=str,
Expand Down Expand Up @@ -310,14 +347,14 @@ def cli_modisco_reports(
show_default=True,
help="Additional flank added at the end of motif discovery.",
)
# @click.option("--stranded", is_flag=True, help="Treat input as stranded so do not add reverse-complement.")
# @click.option(
# "--pattern-type",
# type=click.Choice(["both", "pos", "neg"]),
# default="both",
# show_default=True,
# help="Which pattern signs to compute: both, pos, or neg.",
# )
@click.option("--stranded", is_flag=True, help="Treat input as stranded so do not add reverse-complement.")
@click.option(
"--pattern-type",
type=click.Choice(["both", "pos", "neg"]),
default="both",
show_default=True,
help="Which pattern signs to compute: both, pos, or neg.",
)
@click.option(
"--meme-motif-db", type=str, default="hocomoco_v13", show_default=True, help="Path to the MEME motif database."
)
Expand All @@ -327,6 +364,13 @@ def cli_modisco_reports(
@click.option("--trim-threshold", type=float, default=0.3, show_default=True, help="Trim threshold.")
@click.option("--trim-min-length", type=int, default=3, show_default=True, help="Trim minimum length.")
@click.option("--tomtomlite", type=bool, default=False, show_default=True, help="Whether to use TomtomLite.")
@click.option(
"--seqlet-motif-trim-threshold",
type=float,
default=0.2,
show_default=True,
help="Trim threshold for motifs in seqlets bed file.",
)
def cli_modisco(
output_prefix: str,
tasks: Optional[List[str]] = None,
Expand All @@ -336,8 +380,9 @@ def cli_modisco(
metadata: Optional[str] = None,
method: str = "saliency",
batch_size: int = 4,
genes: Optional[List[str]] = None, # TODO: list of genes
genes: Optional[str] = None,
top_n_markers: Optional[int] = None,
correct_grad: bool = True,
device: Optional[str] = None,
num_workers: int = 4,
genome: str = "hg38",
Expand All @@ -349,8 +394,8 @@ def cli_modisco(
flank_size: int = 5,
initial_flank_to_add: int = 10,
final_flank_to_add: int = 0,
# stranded: bool = False,
# pattern_type: str = "both",
stranded: bool = False,
pattern_type: str = "both",
# reports parameters
meme_motif_db: str = "hocomoco_v13",
img_path_suffix: str = "",
Expand All @@ -359,6 +404,8 @@ def cli_modisco(
trim_threshold: float = 0.3,
trim_min_length: int = 3,
tomtomlite: bool = False,
# seqlet thresholds
seqlet_motif_trim_threshold: float = 0.2,
):
if model in ["0", "1", "2", "3"]:
model = int(model)
Expand All @@ -380,6 +427,7 @@ def cli_modisco(
batch_size=batch_size,
genes=genes,
top_n_markers=top_n_markers,
correct_grad=correct_grad,
device=device,
num_workers=num_workers,
genome=genome,
Expand All @@ -391,8 +439,8 @@ def cli_modisco(
flank_size=flank_size,
initial_flank_to_add=initial_flank_to_add,
final_flank_to_add=final_flank_to_add,
# stranded=stranded,
# pattern_type=pattern_type,
stranded=stranded,
pattern_type=pattern_type,
# reports parameters
img_path_suffix=img_path_suffix,
meme_motif_db=meme_motif_db,
Expand All @@ -401,4 +449,6 @@ def cli_modisco(
trim_threshold=trim_threshold,
trim_min_length=trim_min_length,
tomtomlite=tomtomlite,
# seqlet thresholds
seqlet_motif_trim_threshold=seqlet_motif_trim_threshold,
)
Loading