
Commit f1fc2f1

Authored by Muhammed Hasan Celik (with co-author)
0_3_0: cache prediction for the reference sequence to speed up predictions (#20)
* ensemble vep init
* backward compatibility of grelu, ensembling, testcases, custom fasta
* gene dataset
* gene expression prediction and sequence shifting
* fix testcase
* conflig
* cached reference prediction
* dtype reference caching safetensors
* bug fix for h5 dataloader

---------

Co-authored-by: Muhammed Hasan Celik <[email protected]>
1 parent f739160 commit f1fc2f1

File tree: 17 files changed (+345, -101 lines)


src/decima/cli/predict_genes.py

Lines changed: 18 additions & 1 deletion
@@ -38,8 +38,24 @@
     is_flag=True,
     help="Save the replicates in the output parquet file. Default: False.",
 )
+@click.option(
+    "--float-precision",
+    type=str,
+    default="32",
+    help="Floating-point precision to be used in calculations. Available options include: '16-true', '16-mixed', 'bf16-true', 'bf16-mixed', '32-true', '64-true', '32', '16', and 'bf16'.",
+)
 def cli_predict_genes(
-    output, genes, model, metadata, device, batch_size, num_workers, max_seq_shift, genome, save_replicates
+    output,
+    genes,
+    model,
+    metadata,
+    device,
+    batch_size,
+    num_workers,
+    max_seq_shift,
+    genome,
+    save_replicates,
+    float_precision,
 ):
     if model in ["0", "1", "2", "3"]:
         model = int(model)
@@ -63,5 +79,6 @@ def cli_predict_genes(
         max_seq_shift=max_seq_shift,
         genome=genome,
         save_replicates=save_replicates,
+        float_precision=float_precision,
     )
     ad.write_h5ad(output)
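The accepted --float-precision values match PyTorch Lightning's Trainer precision identifiers, so presumably the string is forwarded to the inference Trainer. A minimal sketch under that assumption (not Decima's actual code; assumes Lightning 2.x):

import lightning.pytorch as pl

def build_trainer(float_precision: str = "32") -> pl.Trainer:
    # '16-true', '16-mixed', 'bf16-true', 'bf16-mixed', '32-true', '64-true',
    # '32', '16', and 'bf16' are all valid Lightning precision settings.
    return pl.Trainer(accelerator="auto", devices=1, precision=float_precision, logger=False)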

src/decima/cli/vep.py

Lines changed: 17 additions & 0 deletions
@@ -68,6 +68,17 @@
     is_flag=True,
     help="Save the replicates in the output parquet file. Default: False.",
 )
+@click.option(
+    "--disable-reference-cache",
+    is_flag=True,
+    help="Disable the reference cache; the cache significantly speeds up computation by reusing the reference expression predictions stored in the metadata.",
+)
+@click.option(
+    "--float-precision",
+    type=str,
+    default="32",
+    help="Floating-point precision to be used in calculations. Available options include: '16-true', '16-mixed', 'bf16-true', 'bf16-mixed', '32-true', '64-true', '32', '16', and 'bf16'.",
+)
 def cli_predict_variant_effect(
     variants,
     output_pq,
@@ -85,6 +96,8 @@ def cli_predict_variant_effect(
     gene_col,
     genome,
     save_replicates,
+    disable_reference_cache,
+    float_precision,
 ):
     """Predict variant effect and save to parquet
 
@@ -108,6 +121,8 @@ def cli_predict_variant_effect(
 
     >>> decima vep -v "data/sample.vcf" -o "vep_results.parquet" --genome "path/to/fasta/hg38.fa"  # use custom genome build
     """
+    reference_cache = not disable_reference_cache
+
     if model in ["0", "1", "2", "3"]:  # replicate index
         model = int(model)
 
@@ -137,6 +152,8 @@ def cli_predict_variant_effect(
         gene_col=gene_col,
         genome=genome,
         save_replicates=save_replicates,
+        reference_cache=reference_cache,
+        float_precision=float_precision,
     )
 
 
src/decima/constants.py

Lines changed: 8 additions & 0 deletions
@@ -1,3 +1,11 @@
+import os
+
+
 DECIMA_CONTEXT_SIZE = 524288
 SUPPORTED_GENOMES = {"hg38"}
 NUM_CELLS = 8856
+
+if "DECIMA_ENSEMBLE_MODELS_NAMES" in os.environ:
+    ENSEMBLE_MODELS_NAMES = os.environ["DECIMA_ENSEMBLE_MODELS_NAMES"].split(",")
+else:
+    ENSEMBLE_MODELS_NAMES = ["v1_rep0", "v1_rep1", "v1_rep2", "v1_rep3"]
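A usage sketch for the new environment override (the variable name and the default replicate names come from this diff; restricting the ensemble to a subset is only an illustration):

import os

# Must be set before decima.constants is imported, since the value is read at import time.
os.environ["DECIMA_ENSEMBLE_MODELS_NAMES"] = "v1_rep0,v1_rep2"

from decima.constants import ENSEMBLE_MODELS_NAMES

print(ENSEMBLE_MODELS_NAMES)  # ['v1_rep0', 'v1_rep2']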

src/decima/core/metadata.py

Lines changed: 1 addition & 0 deletions
@@ -50,6 +50,7 @@ class GeneMetadata:
     gene_id: str
     pearson: float
     size_factor_pearson: float
+    ensembl_canonical_tss: Optional[bool]
 
     @classmethod
     def from_series(cls, name: str, series: pd.Series) -> "GeneMetadata":

src/decima/core/result.py

Lines changed: 12 additions & 3 deletions
@@ -3,6 +3,7 @@
 import numpy as np
 import torch
 import pandas as pd
+
 from grelu.sequence.format import intervals_to_strings, strings_to_one_hot
 
 from decima.constants import DECIMA_CONTEXT_SIZE
@@ -143,7 +144,9 @@ def get_cell_metadata(self, cell: str) -> CellMetadata:
             raise KeyError(f"Cell {cell} not found in dataset. See avaliable cells with `result.cells`.")
         return CellMetadata.from_series(cell, self.cell_metadata.loc[cell])
 
-    def predicted_expression_matrix(self, genes: Optional[List[str]] = None) -> pd.DataFrame:
+    def predicted_expression_matrix(
+        self, genes: Optional[List[str]] = None, model_name: Optional[str] = None
+    ) -> pd.DataFrame:
         """Get predicted expression matrix for all or specific genes.
 
         Args:
@@ -152,10 +155,14 @@ def predicted_expression_matrix(self, genes: Optional[List[str]] = None) -> pd.D
         Returns:
             pd.DataFrame: Predicted expression matrix (cells x genes)
         """
+        model_name = "preds" if (model_name is None) or (model_name == "ensemble") else model_name
         if genes is None:
-            return pd.DataFrame(self.anndata.layers["preds"], index=self.cells, columns=self.genes)
+            return pd.DataFrame(self.anndata.layers[model_name], index=self.cells, columns=self.genes)
         else:
-            return pd.DataFrame(self.anndata[:, genes].layers["preds"], index=self.cells, columns=genes)
+            return pd.DataFrame(self.anndata[:, genes].layers[model_name], index=self.cells, columns=genes)
+
+    def predicted_gene_expression(self, gene, model_name):
+        return torch.from_numpy(self.anndata[:, gene].layers[model_name].ravel())
 
     def _pad_gene_metadata(self, gene_meta: pd.Series, padding: int = 0) -> pd.Series:
         """
@@ -184,6 +191,7 @@ def prepare_one_hot(self, gene: str, variants: Optional[List[Dict]] = None, padd
         Returns:
             torch.Tensor: One-hot encoding of the gene
         """
+
         assert gene in self.genes, f"{gene} is not in the anndata object"
         gene_meta = self._pad_gene_metadata(self.gene_metadata.loc[gene], padding)
 
@@ -201,6 +209,7 @@ def prepare_one_hot(self, gene: str, variants: Optional[List[Dict]] = None, padd
 
     def gene_sequence(self, gene: str, stranded: bool = True) -> str:
         """Get sequence for a gene."""
+
         try:
             assert gene in self.genes, f"{gene} is not in the anndata object"
         except AssertionError:
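A hedged usage sketch of the new accessors, assuming a metadata AnnData whose layers include the per-replicate names from ENSEMBLE_MODELS_NAMES (the file path and the gene symbol below are hypothetical):

from decima.core.result import DecimaResult

result = DecimaResult.load("metadata.h5ad")  # hypothetical path to the metadata AnnData

# Ensemble-averaged predictions stored in the "preds" layer (cells x genes).
ens_df = result.predicted_expression_matrix(genes=["GATA1"], model_name="ensemble")

# Per-replicate predictions, if the metadata carries one layer per ensemble member.
rep0_df = result.predicted_expression_matrix(genes=["GATA1"], model_name="v1_rep0")

# 1-D tensor of stored predictions for a single gene and model, as used by the reference cache.
rep0_expr = result.predicted_gene_expression("GATA1", "v1_rep0")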

src/decima/data/dataset.py

Lines changed: 50 additions & 11 deletions
@@ -1,16 +1,17 @@
 import warnings
 import torch
 import h5py
-import numpy as np
 import bioframe
+import numpy as np
+import pandas as pd
 from more_itertools import flatten
 from torch.utils.data import Dataset, default_collate
 from grelu.sequence.format import indices_to_strings
 from grelu.data.augment import Augmenter, _split_overall_idx
 from grelu.sequence.utils import reverse_complement
 
-from decima.constants import DECIMA_CONTEXT_SIZE
-from decima.data.read_hdf5 import _extract_center, index_genes, indices_to_one_hot
+from decima.constants import DECIMA_CONTEXT_SIZE, ENSEMBLE_MODELS_NAMES
+from decima.data.read_hdf5 import _extract_center
 from decima.core.result import DecimaResult
 
 from decima.model.metrics import WarningType
@@ -221,9 +222,12 @@ def __init__(
         distance_type="tss",
         min_distance=0,
         max_distance=float("inf"),
+        model_name=None,
+        reference_cache=True,
     ):
         super().__init__()
 
+        self.reference_cache = reference_cache
         self.result = DecimaResult.load(metadata_anndata)
 
         self.variants = self._overlap_genes(
@@ -253,6 +257,19 @@ def __init__(
         self.n_augmented = len(self.augmenter)
         self.padded_seq_len = DECIMA_CONTEXT_SIZE + (2 * self.max_seq_shift)
 
+        if (model_name is None) or (not reference_cache):
+            self.model_names = list()  # no reference caching
+        elif model_name == "ensemble":
+            self.model_names = ENSEMBLE_MODELS_NAMES
+        else:
+            self.model_names = [model_name]
+
+        for model_name in self.model_names:
+            assert model_name in self.result.anndata.layers.keys(), (
+                f"Model {model_name} not found in the metadata annotation. "
+                "You may not be using the correct metadata file for this model."
+            )
+
     @staticmethod
     def overlap_genes(
         df_variants,
@@ -372,16 +389,25 @@ def __len__(self):
 
     def validate_allele_seq(self, gene, variant):
         seq = self.result.gene_sequence(gene)
-        vstart = variant.rel_pos
-        vend = vstart + len(variant.ref)
-        return (seq[vstart:vend] == variant.ref_tx) or (seq[vstart:vend] == variant.alt_tx)
+        pos = variant.rel_pos
+        ref_match = seq[pos : pos + len(variant.ref)] == variant.ref_tx
+        alt_match = seq[pos : pos + len(variant.alt)] == variant.alt_tx
+        return ref_match, alt_match
+
+    def predicted_expression_cache(self, gene):
+        return {model_name: self.result.predicted_gene_expression(gene, model_name) for model_name in self.model_names}
 
     def __getitem__(self, idx):
         seq_idx, augment_idx, allele_idx = _split_overall_idx(idx, (self.n_seqs, self.n_augmented, self.n_alleles))
 
         variant = self.variants.iloc[seq_idx]
         rel_pos = variant.rel_pos + self.max_seq_shift
 
+        # By default the cache values are NaN; if the allele matches the reference genome,
+        # they are replaced below with the predicted expression from the cache.
+        pred_expr = {model_name: torch.full((self.result.shape[0],), torch.nan) for model_name in self.model_names}
+        ref_match, alt_match = self.validate_allele_seq(variant.gene, variant)
+
         warnings = list()
         if allele_idx:
             seq, mask = self.result.prepare_one_hot(
@@ -391,8 +417,11 @@ def __getitem__(self, idx):
             )
             allele = seq[:, rel_pos : rel_pos + len(variant.alt)]
             allele_tx = variant.alt_tx
+
+            if alt_match:
+                pred_expr = self.predicted_expression_cache(variant.gene)
         else:
-            if not self.validate_allele_seq(variant.gene, variant):
+            if (not ref_match) and (not alt_match):
                 warnings.append(WarningType.ALLELE_MISMATCH_WITH_REFERENCE_GENOME)
 
             seq, mask = self.result.prepare_one_hot(
@@ -403,23 +432,33 @@ def __getitem__(self, idx):
             allele = seq[:, rel_pos : rel_pos + len(variant.ref)]
             allele_tx = variant.ref_tx
 
-        if len(variant.ref_tx) == len(variant.alt_tx):  # not SNV there would be shifts
+            if ref_match:
+                pred_expr = self.predicted_expression_cache(variant.gene)
+
+        if len(variant.ref) == len(variant.alt):  # not SNV there would be shifts
             assert indices_to_strings(allele.argmax(axis=0)) == allele_tx
 
         inputs = torch.vstack([seq, mask])
-
         inputs = _extract_center(inputs, seq_len=self.padded_seq_len)
         inputs = self.augmenter(seq=inputs, idx=augment_idx)
-        return {
+
+        data = {
            "seq": inputs,
            "warning": warnings,
        }
+        if len(self.model_names) > 0:
+            data["pred_expr"] = pred_expr
+
+        return data
 
     def collate_fn(self, batch):
-        return {
+        _batch = {
            "seq": default_collate([i["seq"] for i in batch]),
            "warning": list(flatten([b["warning"] for b in batch])),
        }
+        if "pred_expr" in batch[0]:
+            _batch["pred_expr"] = default_collate([b["pred_expr"] for b in batch])
+        return _batch
 
     def __str__(self):
         return (
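For orientation, a minimal sketch (not Decima's actual prediction loop) of how the cached pred_expr tensors returned by the dataset can be consumed: when the allele matched the reference genome the cached tensor is fully populated and the forward pass for that sequence can be skipped; otherwise the NaN placeholders signal a cache miss.

import torch

def predict_with_cache(model, seq: torch.Tensor, cached: torch.Tensor) -> torch.Tensor:
    # Cache hit: the dataset filled every cell with a stored reference prediction.
    if not torch.isnan(cached).any():
        return cached
    # Cache miss: fall back to running the model on the sequence.
    with torch.no_grad():
        return model(seq.unsqueeze(0)).squeeze(0)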

src/decima/hub/__init__.py

Lines changed: 7 additions & 25 deletions
@@ -15,17 +15,6 @@ def login_wandb():
     wandb.login(host=os.environ.get("WANDB_HOST", DEFAULT_WANDB_HOST), relogin=True, anonymous="must", timeout=0)
 
 
-def get_model_name(model: Union[str, int] = 0) -> str:
-    if isinstance(model, int):
-        return f"decima_rep{model}"
-    elif isinstance(model, str):
-        return model
-    else:
-        raise ValueError(
-            f"Invalid model: {model} it need to be a string of model_name on wandb or an integer of replicate number {0, 1, 2, 3}"
-        )
-
-
 def load_decima_model(model: Union[str, int] = 0, device: Optional[str] = None):
     """Load a pre-trained Decima model from wandb or local path.
 
@@ -45,39 +34,32 @@ def load_decima_model(model: Union[str, int] = 0, device: Optional[str] = None):
     if isinstance(model, LightningModel):
         return model
     elif model == "ensemble":
-        model = EnsembleLightningModel(
+        return EnsembleLightningModel(
             [
                 load_decima_model(0, device),
                 load_decima_model(1, device),
                 load_decima_model(2, device),
                 load_decima_model(3, device),
             ]
         )
-        model.name = "ensemble"
-        return model
     elif isinstance(model, str):
-        model_name = get_model_name(model)
         if Path(model).exists():
-            model = LightningModel.load_from_checkpoint(model, map_location=device)
-            model.name = model_name
-            return model
-    elif isinstance(model, int):
-        model_name = get_model_name(model)
+            return LightningModel.load_safetensor(model, device=device)
+    elif model in {0, 1, 2, 3}:
+        model_name = f"rep{model}"
     else:
         raise ValueError(
             f"Invalid model: {model} it need to be a string of model_name on wandb "
             "or an integer of replicate number {0, 1, 2, 3}, or a path to a local model"
         )
 
     if model_name.upper() in os.environ:
-        return LightningModel.load_from_checkpoint(os.environ[model_name.upper()], map_location=device)
+        return LightningModel.load_safetensor(os.environ[model_name.upper()], device=device)
 
     art = get_artifact(model_name, project="decima")
     with TemporaryDirectory() as d:
         art.download(d)
-        model = LightningModel.load_from_checkpoint(Path(d) / "model.ckpt", map_location=device)
-        model.name = str(model_name)
-        return model
+        return LightningModel.load_safetensor(Path(d) / f"{model_name}.safetensors", device=device)
 
 
 def load_decima_metadata(path: Optional[str] = None):
@@ -95,7 +77,7 @@ def load_decima_metadata(path: Optional[str] = None):
     if "DECIMA_METADATA" in os.environ:
         return anndata.read_h5ad(os.environ["DECIMA_METADATA"])
 
-    art = get_artifact("decima_metadata", project="decima")
+    art = get_artifact("metadata", project="decima")
     with TemporaryDirectory() as d:
         art.download(d)
         return anndata.read_h5ad(Path(d) / "metadata.h5ad")
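A hedged usage sketch based on the call patterns visible in this diff (the local safetensors path is hypothetical):

from decima.hub import load_decima_model, load_decima_metadata

model = load_decima_model(0)              # single replicate via the wandb "rep0" artifact
ensemble = load_decima_model("ensemble")  # EnsembleLightningModel over replicates 0-3
local = load_decima_model("/path/to/rep0.safetensors", device="cpu")  # hypothetical local file

metadata = load_decima_metadata()         # pulls the "metadata" artifact unless DECIMA_METADATA is set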

src/decima/interpret/attributions.py

Lines changed: 2 additions & 1 deletion
@@ -5,7 +5,6 @@
 import torch
 import pyBigWig
 from pyfaidx import Faidx
-import genomepy
 from captum.attr import InputXGradient, Saliency, IntegratedGradients
 from grelu.interpret.motifs import scan_sequences
 from grelu.sequence.format import convert_input_type, strings_to_one_hot
@@ -451,6 +450,8 @@ def save_bigwig(self, bigwig_path: str):
 
         if self._chrom is not None:
             name = self.chrom
+            import genomepy
+
             sizes = genomepy.Genome("hg38").sizes
             bw.addHeader([(chrom, size) for chrom, size in sizes.items()])
         else:
