Skip to content

Commit 7628d31

Browse files
committed
added checks for variant intervals exceeding chromosome ends
1 parent c27fb34 commit 7628d31

File tree

2 files changed

+15
-10
lines changed

2 files changed

+15
-10
lines changed

src/grelu/data/dataset.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from torch.utils.data import Dataset
1717

1818
from grelu.data.augment import Augmenter, _split_overall_idx
19+
from grelu.data.preprocess import filter_chrom_ends
1920
from grelu.data.utils import _check_multiclass, _create_task_data
2021
from grelu.sequence.format import (
2122
INDEX_TO_BASE_HASH,
@@ -152,6 +153,7 @@ def _load_seqs(self, seqs: Union[str, Sequence, pd.DataFrame, np.ndarray]) -> No
152153
seqs = resize(seqs, seq_len=self.padded_seq_len, end=self.end)
153154

154155
if get_input_type(seqs) == "intervals":
156+
seqs = filter_chrom_ends(seqs, genome=self.genome)
155157
self.intervals = seqs
156158
self.chroms = list(set(self.intervals.chrom))
157159
else:
@@ -603,7 +605,8 @@ def _load_seqs(self, variants: pd.DataFrame) -> None:
603605
from grelu.variant import variants_to_intervals
604606

605607
self.padded_seq_len = self.seq_len + (2 * self.max_seq_shift)
606-
self.intervals = variants_to_intervals(variants, seq_len=self.padded_seq_len)
608+
intervals = variants_to_intervals(variants, seq_len=self.padded_seq_len)
609+
self.intervals = filter_chrom_ends(intervals, genome=self.genome)
607610
self.seqs = convert_input_type(self.intervals, "indices", genome=self.genome)
608611

609612
def __len__(self) -> int:
@@ -710,7 +713,8 @@ def _load_seqs(self, variants: pd.DataFrame) -> None:
710713
from grelu.variant import variants_to_intervals
711714

712715
self.padded_seq_len = self.seq_len + (2 * self.max_seq_shift)
713-
self.intervals = variants_to_intervals(variants, seq_len=self.padded_seq_len)
716+
intervals = variants_to_intervals(variants, seq_len=self.padded_seq_len)
717+
self.intervals = filter_chrom_ends(intervals, genome=self.genome)
714718
self.seqs = convert_input_type(self.intervals, "indices", genome=self.genome)
715719
self.n_seqs = self.seqs.shape[0]
716720

src/grelu/data/preprocess.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,15 @@ def filter_intervals(
3333
Returns:
3434
Filtered intervals in the same format (if inplace = False)
3535
"""
36-
print("Keeping {} intervals".format(sum(keep)))
37-
if isinstance(data, pd.DataFrame):
38-
return data.drop(index=data.index[~keep], inplace=inplace)
39-
elif isinstance(data, AnnData):
40-
if inplace:
41-
data._inplace_subset_var(index=data.var_names[keep])
42-
else:
43-
return data[:, keep]
36+
if sum(keep) < data.shape[0]:
37+
print("Keeping {} of {} intervals".format(sum(keep), data.shape[0]))
38+
if isinstance(data, pd.DataFrame):
39+
return data.drop(index=data.index[~keep], inplace=inplace)
40+
elif isinstance(data, AnnData):
41+
if inplace:
42+
data._inplace_subset_var(index=data.var_names[keep])
43+
else:
44+
return data[:, keep]
4445

4546

4647
def filter_obs(

0 commit comments

Comments
 (0)