Skip to content

Commit 6a18905

Browse files
MuhammedHasan and Muhammed Hasan Celik authored
bug fix for no overlap and argument fix (#51)
* bug fix for no overlap and argument fix
* bug fix for passing metadata
---------
Co-authored-by: Muhammed Hasan Celik <celik.muhammed_hasan@gene.com>
1 parent fff776d commit 6a18905

File tree

5 files changed

+24
-6
lines changed

5 files changed

+24
-6
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,3 +189,4 @@ docs/wandb
189189
run_tutorial.s*
190190

191191
logs/
192+
*.sbatch

src/decima/data/dataset.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -868,6 +868,8 @@ def __len__(self):
868868
def validate_allele_seq(self, gene, variant):
869869
seq = self.result.gene_sequence(gene, genome=self.genome)
870870
pos = variant.rel_pos
871+
if variant.strand == "-":
872+
pos = pos - len(variant.ref) + 1
871873
ref_match = seq[pos : pos + len(variant.ref)] == variant.ref_tx
872874
alt_match = seq[pos : pos + len(variant.alt)] == variant.alt_tx
873875
return ref_match, alt_match
@@ -889,6 +891,9 @@ def __getitem__(self, idx):
889891
variant = self.variants.iloc[seq_idx]
890892
rel_pos = variant.rel_pos + self.max_seq_shift
891893

894+
if variant.strand == "-":
895+
rel_pos = rel_pos - len(variant.ref) + 1
896+
892897
# by default cache values are nan if matched with reference genome
893898
# then it will be replaced with the predicted expression from cache.
894899
pred_expr = {model_name: torch.full((self.result.shape[0],), torch.nan) for model_name in self.model_names}

src/decima/utils/dataframe.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
5050
self.writer.close()
5151
else:
5252
warnings.warn("NoDataFrameWrittenError: No dataframe was written to the parquet file.")
53+
pd.DataFrame({}).to_parquet(self.output_path)
5354
self.first_chunk = True
5455

5556
def write(self, chunk: pd.DataFrame) -> None:

src/decima/vep/attributions.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from decima.utils.io import read_vcf_chunks, VariantAttributionWriter
3232
from decima.core.result import DecimaResult
3333
from decima.data.dataset import VariantDataset
34+
from decima.hub import load_decima_model
3435
from decima.interpret.attributer import DecimaAttributer
3536
from decima.model.metrics import WarningCounter
3637
from decima.vep.vep import _log_vep_warnings, _write_vep_warnings
@@ -158,23 +159,22 @@ def variant_effect_attribution(
158159
f"Unsupported input type: {type(variants)}. Must be pd.DataFrame or str (path to .tsv or .vcf)."
159160
)
160161

161-
result = DecimaResult.load(metadata_anndata)
162-
162+
model = load_decima_model(model, device=device)
163+
result = DecimaResult.load(metadata_anndata or model.name)
163164
tasks, off_tasks = _get_on_off_tasks(result, tasks, off_tasks)
164-
attributer = DecimaAttributer.load_decima_attributer(
165-
model_name=model,
165+
attributer = DecimaAttributer(
166+
model=model,
166167
tasks=tasks,
167168
off_tasks=off_tasks,
168169
method=method,
169170
transform=transform,
170-
device=device,
171171
)
172172

173173
warning_counter = WarningCounter()
174174

175175
dataset = VariantDataset(
176176
variants,
177-
metadata_anndata=metadata_anndata,
177+
metadata_anndata=result,
178178
gene_col=gene_col,
179179
distance_type=distance_type,
180180
min_distance=min_distance,

tests/test_vep.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,17 @@ def test_VariantDataset_overlap_genes(df_variant):
8282
})
8383
df = VariantDataset.overlap_genes(df_variant, df_genes)
8484

85+
def test_VariantDataset_validate_allele_seq():
86+
df_variant = pd.DataFrame({
87+
"chrom": ["chr15"],
88+
"pos": [44715509],
89+
"ref": ["CC"],
90+
"alt": ["TT"]
91+
})
92+
dataset = VariantDataset(df_variant)
93+
ref_match, _ = dataset.validate_allele_seq("SPG11", dataset.variants.iloc[1])
94+
assert ref_match
95+
8596
def test_VariantDataset(df_variant):
8697

8798
dataset = VariantDataset(df_variant, model_name="v1_rep0")

0 commit comments

Comments
 (0)