Skip to content

Commit 21227de

Browse files
authored
Merge pull request #95 from Genentech/filter-variants-chrom
Prevent padded intervals from exceeding chromosome ends
2 parents f9d07d2 + 3f2a45f commit 21227de

File tree

3 files changed

+76
-11
lines changed

3 files changed

+76
-11
lines changed

src/grelu/data/dataset.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from torch.utils.data import Dataset
1717

1818
from grelu.data.augment import Augmenter, _split_overall_idx
19+
from grelu.data.preprocess import check_chrom_ends
1920
from grelu.data.utils import _check_multiclass, _create_task_data
2021
from grelu.sequence.format import (
2122
INDEX_TO_BASE_HASH,
@@ -152,6 +153,7 @@ def _load_seqs(self, seqs: Union[str, Sequence, pd.DataFrame, np.ndarray]) -> No
152153
seqs = resize(seqs, seq_len=self.padded_seq_len, end=self.end)
153154

154155
if get_input_type(seqs) == "intervals":
156+
check_chrom_ends(seqs, genome=self.genome)
155157
self.intervals = seqs
156158
self.chroms = list(set(self.intervals.chrom))
157159
else:
@@ -604,6 +606,7 @@ def _load_seqs(self, variants: pd.DataFrame) -> None:
604606

605607
self.padded_seq_len = self.seq_len + (2 * self.max_seq_shift)
606608
self.intervals = variants_to_intervals(variants, seq_len=self.padded_seq_len)
609+
check_chrom_ends(self.intervals, genome=self.genome)
607610
self.seqs = convert_input_type(self.intervals, "indices", genome=self.genome)
608611

609612
def __len__(self) -> int:
@@ -711,6 +714,7 @@ def _load_seqs(self, variants: pd.DataFrame) -> None:
711714

712715
self.padded_seq_len = self.seq_len + (2 * self.max_seq_shift)
713716
self.intervals = variants_to_intervals(variants, seq_len=self.padded_seq_len)
717+
check_chrom_ends(self.intervals, genome=self.genome)
714718
self.seqs = convert_input_type(self.intervals, "indices", genome=self.genome)
715719
self.n_seqs = self.seqs.shape[0]
716720

src/grelu/data/preprocess.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import os
66
import subprocess
7+
import warnings
78
from typing import Callable, List, Optional, Union
89

910
import bioframe as bf
@@ -373,6 +374,51 @@ def filter_blacklist(
373374
)
374375

375376

377+
def check_chrom_ends(
378+
data: Union[pd.DataFrame, AnnData],
379+
genome: Optional[str] = None,
380+
):
381+
"""
382+
Check that intervals do not exceed the ends of the chromosome.
383+
384+
Args:
385+
data: Either a pandas dataframe of genomic intervals or an Anndata
386+
object with intervals in .var
387+
genome: name of the genome corresponding to intervals
388+
389+
Raises:
390+
ValueError if any interval exceeds the chtomosome ends
391+
"""
392+
from grelu.io.genome import read_sizes
393+
394+
# Get genomic intervals
395+
if isinstance(data, AnnData):
396+
intervals = data.var
397+
elif isinstance(data, pd.DataFrame):
398+
intervals = data
399+
400+
# Check start
401+
fail = intervals[intervals.start < 0].index
402+
403+
# Filter end if the genome is provided
404+
if genome is None:
405+
warnings.warn(
406+
"No genome is provided; only intervals with negative start values will be flagged."
407+
)
408+
else:
409+
sizes = read_sizes(genome)
410+
for chrom, size in sizes.values:
411+
fail = fail.append(
412+
intervals[(intervals.chrom == chrom) & (intervals.end > size)].index
413+
)
414+
415+
fail = np.unique(fail)
416+
if len(fail) > 0:
417+
raise ValueError(
418+
f"Indices of intervals that extend beyond the chromosome ends: {','.join(fail.astype(str))}."
419+
)
420+
421+
376422
def filter_chrom_ends(
377423
data: Union[pd.DataFrame, AnnData],
378424
genome: Optional[str] = None,

tests/test_preprocess.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pytest
55

66
from grelu.data.preprocess import (
7+
check_chrom_ends,
78
filter_blacklist,
89
filter_cells,
910
filter_chrom_ends,
@@ -117,22 +118,36 @@ def test_filter_blacklist():
117118
assert filter_blacklist(intervals, genome="hg38").equals(intervals.iloc[-2:, :])
118119

119120

121+
chrom_end_intervals = pd.DataFrame(
122+
{
123+
"chrom": ["chr1", "chr1", "chr1", "chr1", "chr1"],
124+
"start": [-10, 10, 1000, 248956300, 248956350],
125+
"end": [90, 110, 1100, 248956400, 248956450],
126+
}
127+
)
128+
129+
120130
def test_filter_chrom_ends():
121-
intervals = pd.DataFrame(
122-
{
123-
"chrom": ["chr1", "chr1", "chr1", "chr1", "chr1"],
124-
"start": [-10, 10, 1000, 248956300, 248956350],
125-
"end": [90, 110, 1100, 248956400, 248956450],
126-
}
127-
)
128-
assert filter_chrom_ends(intervals, genome="hg38").equals(
129-
intervals.iloc[[1, 2, 3], :]
131+
132+
assert filter_chrom_ends(chrom_end_intervals, genome="hg38").equals(
133+
chrom_end_intervals.iloc[[1, 2, 3], :]
130134
)
131-
assert filter_chrom_ends(intervals, genome="hg38", pad=100).equals(
132-
intervals.iloc[[2], :]
135+
assert filter_chrom_ends(chrom_end_intervals, genome="hg38", pad=100).equals(
136+
chrom_end_intervals.iloc[[2], :]
133137
)
134138

135139

140+
def test_check_chrom_ends():
141+
with pytest.raises(Exception) as e_info:
142+
check_chrom_ends(chrom_end_intervals, genome="hg38")
143+
assert (
144+
str(e_info.value)
145+
== "Indices of intervals that extend beyond the chromosome ends: 0,4."
146+
)
147+
148+
check_chrom_ends(chrom_end_intervals.iloc[1:2], genome="hg38")
149+
150+
136151
def test_merge_intervals_by_column():
137152
intervals = pd.DataFrame(
138153
{

0 commit comments

Comments
 (0)