Skip to content

Commit 4fff6ee

Browse files
bug fix training and gene mask lightning model (#21)
1 parent f1fc2f1 commit 4fff6ee

File tree

3 files changed

+40
-3
lines changed

3 files changed

+40
-3
lines changed

src/decima/data/dataset.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
import pandas as pd
77
from more_itertools import flatten
88
from torch.utils.data import Dataset, default_collate
9-
from grelu.sequence.format import indices_to_strings
9+
from grelu.sequence.format import indices_to_strings, indices_to_one_hot
1010
from grelu.data.augment import Augmenter, _split_overall_idx
1111
from grelu.sequence.utils import reverse_complement
1212

1313
from decima.constants import DECIMA_CONTEXT_SIZE, ENSEMBLE_MODELS_NAMES
14-
from decima.data.read_hdf5 import _extract_center
14+
from decima.data.read_hdf5 import _extract_center, index_genes
1515
from decima.core.result import DecimaResult
1616

1717
from decima.model.metrics import WarningType

src/decima/model/lightning.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -614,3 +614,29 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0) -> Union[dict, Tensor
614614
return {"expression": expression, "warnings": batch["warning"]}
615615
else:
616616
return self(batch)
617+
618+
619+
class GeneMaskLightningModel(LightningModel):
    """LightningModel variant that appends a gene-mask channel to its input.

    ``forward`` builds one extra channel that is 1 inside
    ``[gene_mask_start, gene_mask_end)`` along the last axis and 0 elsewhere,
    concatenates it to the input along dim 1, and passes the result to
    ``LightningModel.forward``.

    Args:
        gene_mask_start: Start index (inclusive) of the span set to 1 on the
            last axis of the input.
        gene_mask_end: End index (exclusive) of the span set to 1.
        model_params: Forwarded unchanged to ``LightningModel.__init__``.
        train_params: Forwarded to ``LightningModel.__init__``. A new empty
            dict is created when omitted.
        data_params: Forwarded to ``LightningModel.__init__``. A new empty
            dict is created when omitted.
        name: Forwarded unchanged to ``LightningModel.__init__``.
    """

    def __init__(
        self,
        gene_mask_start,
        gene_mask_end,
        model_params: dict,
        train_params: Union[dict, None] = None,
        data_params: Union[dict, None] = None,
        name: str = "",
    ):
        # A literal `{}` default is a single dict object shared by every
        # instance constructed without the argument, so a mutation made
        # through one model would leak into the others. Build a fresh dict
        # per instance instead.
        if train_params is None:
            train_params = {}
        if data_params is None:
            data_params = {}

        super().__init__(
            model_params=model_params,
            train_params=train_params,
            data_params=data_params,
            name=name,
        )
        self.gene_mask_start = gene_mask_start
        self.gene_mask_end = gene_mask_end

    def forward(self, x: Union[Tuple[Tensor, Tensor], Tensor, str, List[str]], logits: bool = False):
        """Append the gene-mask channel to ``x`` and run the parent forward pass.

        Args:
            x: Input batch. Although the annotation mirrors a broader input
                contract, this implementation reads ``x.shape[0]``,
                ``x.shape[2]``, ``x.device`` and ``x.dtype`` and concatenates
                on dim 1, so ``x`` must be a 3-D tensor.
            logits: Passed through unchanged to ``LightningModel.forward``.

        Returns:
            Whatever ``LightningModel.forward`` returns for the input with
            the extra mask channel appended.
        """
        # One extra channel with the same batch size, length, device and
        # dtype as the input, so the concatenation below is valid.
        mask = torch.zeros((x.shape[0], 1, x.shape[2]), device=x.device, dtype=x.dtype)
        mask[:, :, self.gene_mask_start : self.gene_mask_end] = 1
        x = torch.cat([x, mask], dim=1)
        return super().forward(x, logits)

tests/test_lightning.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import torch
33
from decima.constants import DECIMA_CONTEXT_SIZE, NUM_CELLS
44
from decima.data.dataset import VariantDataset
5-
from decima.model.lightning import LightningModel
5+
from decima.model.lightning import LightningModel, GeneMaskLightningModel
66
from decima.model.metrics import WarningType
77

88
from conftest import device
@@ -54,3 +54,14 @@ def test_LightningModel_predict_on_dataset_ensemble(lightning_model, df_variant)
5454
assert results["expression"].shape == (82, NUM_CELLS)
5555
assert results["warnings"]['unknown'] == 0
5656
assert results["warnings"]['allele_mismatch_with_reference_genome'] == 13
57+
58+
59+
@pytest.mark.long_running
def test_GeneMaskLightningModel_forward():
    """Check the output shape of a GeneMaskLightningModel forward pass.

    Feeds a random ``(1, 4, DECIMA_CONTEXT_SIZE)`` tensor through a model
    built with ``init_borzoi=False`` and asserts that the prediction has
    shape ``(1, NUM_CELLS, 1)``.
    """
    # The input is created before the model, as in the original test.
    batch = torch.randn(1, 4, DECIMA_CONTEXT_SIZE)
    batch = batch.to(device)

    masked_model = GeneMaskLightningModel(
        gene_mask_start=200_000,
        gene_mask_end=300_000,
        model_params={"n_tasks": NUM_CELLS, "init_borzoi": False},
        name="v1_rep0",
    )
    masked_model = masked_model.to(device)

    output = masked_model(batch)
    expected_shape = (1, NUM_CELLS, 1)
    assert output.shape == expected_shape

0 commit comments

Comments
 (0)