Skip to content

Commit 6699719

Browse files
author
Muhammed Hasan Celik
committed
code review
2 parents 4aa09b9 + 87037a0 commit 6699719

File tree

7 files changed

+21
-26
lines changed

7 files changed

+21
-26
lines changed

src/decima/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import sys
2+
from decima.core.result import DecimaResult
23
from decima.interpret.save_attributions import predict_save_attributions
4+
from decima.vep import predict_variant_effect
35

46

57
if sys.version_info[:2] >= (3, 8):
@@ -16,3 +18,6 @@
1618
__version__ = "unknown"
1719
finally:
1820
del version, PackageNotFoundError
21+
22+
23+
__all__ = ["DecimaResult", "predict_variant_effect", "predict_save_attributions"]

src/decima/core/metadata.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ class GeneMetadata:
1010
Attributes:
1111
name: Gene name
1212
chrom: Chromosome where the gene is located
13-
start: Start position in the chromosome
14-
end: End position in the chromosome
13+
start: Start position, on the chromosome, of the region around the gene that is used for predictions
14+
end: End position, on the chromosome, of the region around the gene that is used for predictions
1515
strand: Strand orientation (+ or -)
1616
gene_type: Type of gene (e.g., protein_coding)
1717
frac_nan: Fraction of NaN values
@@ -91,8 +91,8 @@ class CellMetadata:
9191
disease: str
9292
study: str
9393
dataset: str
94-
region: str
95-
subregion: str
94+
region: Optional[str]
95+
subregion: Optional[str]
9696
celltype_coarse: Optional[str]
9797
n_cells: int
9898
total_counts: float

src/decima/data/dataset.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -328,9 +328,6 @@ def __getitem__(self, idx):
328328
variant = self.variants.iloc[seq_idx]
329329

330330
warnings = list()
331-
if not self.validate_allele_seq(variant.gene, variant):
332-
warnings.append(WarningType.ALLELE_MISMATCH_WITH_REFERENCE_GENOME)
333-
334331
if allele_idx:
335332
seq, mask = self.result.prepare_one_hot(
336333
variant.gene,
@@ -339,6 +336,9 @@ def __getitem__(self, idx):
339336
allele = seq[:, variant.rel_pos : variant.rel_pos + len(variant.alt)]
340337
allele_tx = variant.alt_tx
341338
else:
339+
if not self.validate_allele_seq(variant.gene, variant):
340+
warnings.append(WarningType.ALLELE_MISMATCH_WITH_REFERENCE_GENOME)
341+
342342
seq, mask = self.result.prepare_one_hot(
343343
variant.gene,
344344
variants=[{"chrom": variant.chrom, "pos": variant.pos, "ref": variant.alt, "alt": variant.ref}],

src/decima/interpret/attributions.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -400,15 +400,14 @@ def scan_motifs(self, motifs: str = "hocomoco_v12", window: int = 18, pthresh: f
400400
return df.sort_values("p-value")
401401

402402
def plot_peaks(self, overlapping_min_dist=1000, figsize=(10, 2)):
403-
"""Plot attribution scores in a window around a relative location.
403+
"""Plot attribution scores and highlight peaks.
404404
405405
Args:
406-
relative_loc: Position relative to TSS to center plot on
407-
window: Number of bases to show on each side of center
406+
overlapping_min_dist: Minimum distance between peaks to consider them overlapping
408407
figsize: Figure size in inches (width, height)
409408
410409
Returns:
411-
matplotlib.pyplot.Figure: Attribution plot showing the window around the specified location
410+
plotnine.ggplot: The plotted figure showing attribution scores with highlighted peaks
412411
"""
413412
return plot_peaks(
414413
self.attrs,

src/decima/model/lightning.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919

2020
from .decima_model import DecimaModel
2121
from .loss import TaskWisePoissonMultinomialLoss
22-
from .metrics import DiseaseLfcMSE, WarningCounter, WarningType
22+
from .metrics import DiseaseLfcMSE, WarningCounter
23+
2324

2425
default_train_params = {
2526
"lr": 4e-5,
@@ -448,11 +449,7 @@ def predict_on_dataset(
448449

449450
expression = np.mean(expression, axis=1) # B T
450451

451-
num_warnings = self.warning_counter.compute()
452-
# allele mismatch is counted twice for each variant due to the two alleles
453-
num_warnings[WarningType.ALLELE_MISMATCH_WITH_REFERENCE_GENOME.value] //= 2
454-
455-
return {"expression": expression, "warnings": num_warnings}
452+
return {"expression": expression, "warnings": self.warning_counter.compute()}
456453

457454
def get_task_idxs(
458455
self,

src/decima/plot/visualize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ def plot_peaks(attrs, tss_pos, df_peaks=None, overlapping_min_dist=1000, figsize
197197
198198
Args:
199199
attrs: Attribution scores array
200-
tss_pos: Position of TSS
200+
tss_pos: Position of TSS (relative to the gene)
201201
df_peaks: DataFrame containing peak information
202202
overlapping_min_dist: Minimum distance between peaks to consider them overlapping
203203
figsize: Figure size

src/decima/utils/inject.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -80,22 +80,16 @@ def inject(self, variant: Dict):
8080
variant_start = variant["pos"]
8181
variant_end = variant_start + len(variant["ref"])
8282

83-
if variant_end <= self.start:
83+
if variant_start < self.start:
8484
warnings.warn(
8585
f"Variant position `{variant['pos']}` is upstream of the interval `[{self.start}, {self.end}]`. Skipping..."
8686
)
8787
return self
88-
elif variant_start < self.start < variant_end:
89-
_, right_variant = self._split_variant(variant, self.start)
90-
return self.inject(right_variant)
91-
elif self.end < variant_start:
88+
elif self.end < variant_end:
9289
warnings.warn(
9390
f"Variant position `{variant['pos']}` is downstream of the interval `[{self.start}, {self.end}]`. Skipping..."
9491
)
9592
return self
96-
elif variant_start < self.end < variant_end:
97-
left_variant, _ = self._split_variant(variant, self.end)
98-
return self.inject(left_variant)
9993

10094
if variant_start < self.anchor < variant_end:
10195
left_variant, right_variant = self._split_variant(variant, self.anchor)

0 commit comments

Comments
 (0)