Skip to content

Commit 44ad44c

Browse files
committed
added check_chrom_ends function and tests
1 parent 7628d31 commit 44ad44c

File tree

3 files changed

+89
-21
lines changed

3 files changed

+89
-21
lines changed

src/grelu/data/dataset.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from torch.utils.data import Dataset
1717

1818
from grelu.data.augment import Augmenter, _split_overall_idx
19-
from grelu.data.preprocess import filter_chrom_ends
19+
from grelu.data.preprocess import check_chrom_ends
2020
from grelu.data.utils import _check_multiclass, _create_task_data
2121
from grelu.sequence.format import (
2222
INDEX_TO_BASE_HASH,
@@ -153,7 +153,7 @@ def _load_seqs(self, seqs: Union[str, Sequence, pd.DataFrame, np.ndarray]) -> No
153153
seqs = resize(seqs, seq_len=self.padded_seq_len, end=self.end)
154154

155155
if get_input_type(seqs) == "intervals":
156-
seqs = filter_chrom_ends(seqs, genome=self.genome)
156+
check_chrom_ends(seqs, genome=self.genome)
157157
self.intervals = seqs
158158
self.chroms = list(set(self.intervals.chrom))
159159
else:
@@ -605,8 +605,8 @@ def _load_seqs(self, variants: pd.DataFrame) -> None:
605605
from grelu.variant import variants_to_intervals
606606

607607
self.padded_seq_len = self.seq_len + (2 * self.max_seq_shift)
608-
intervals = variants_to_intervals(variants, seq_len=self.padded_seq_len)
609-
self.intervals = filter_chrom_ends(intervals, genome=self.genome)
608+
self.intervals = variants_to_intervals(variants, seq_len=self.padded_seq_len)
609+
check_chrom_ends(self.intervals, genome=self.genome)
610610
self.seqs = convert_input_type(self.intervals, "indices", genome=self.genome)
611611

612612
def __len__(self) -> int:
@@ -713,8 +713,8 @@ def _load_seqs(self, variants: pd.DataFrame) -> None:
713713
from grelu.variant import variants_to_intervals
714714

715715
self.padded_seq_len = self.seq_len + (2 * self.max_seq_shift)
716-
intervals = variants_to_intervals(variants, seq_len=self.padded_seq_len)
717-
self.intervals = filter_chrom_ends(intervals, genome=self.genome)
716+
self.intervals = variants_to_intervals(variants, seq_len=self.padded_seq_len)
717+
check_chrom_ends(self.intervals, genome=self.genome)
718718
self.seqs = convert_input_type(self.intervals, "indices", genome=self.genome)
719719
self.n_seqs = self.seqs.shape[0]
720720

src/grelu/data/preprocess.py

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import os
66
import subprocess
7+
import warnings
78
from typing import Callable, List, Optional, Union
89

910
import bioframe as bf
@@ -33,15 +34,22 @@ def filter_intervals(
3334
Returns:
3435
Filtered intervals in the same format (if inplace = False)
3536
"""
36-
if sum(keep) < data.shape[0]:
37-
print("Keeping {} of {} intervals".format(sum(keep), data.shape[0]))
38-
if isinstance(data, pd.DataFrame):
37+
38+
if isinstance(data, pd.DataFrame):
39+
if sum(keep) < data.shape[0]:
40+
print("Keeping {} of {} intervals".format(sum(keep), data.shape[0]))
3941
return data.drop(index=data.index[~keep], inplace=inplace)
40-
elif isinstance(data, AnnData):
42+
else:
43+
return data
44+
elif isinstance(data, AnnData):
45+
if sum(keep) < data.shape[1]:
46+
print("Keeping {} of {} intervals".format(sum(keep), data.shape[1]))
4147
if inplace:
4248
data._inplace_subset_var(index=data.var_names[keep])
4349
else:
4450
return data[:, keep]
51+
else:
52+
return data
4553

4654

4755
def filter_obs(
@@ -374,6 +382,51 @@ def filter_blacklist(
374382
)
375383

376384

385+
def check_chrom_ends(
386+
data: Union[pd.DataFrame, AnnData],
387+
genome: Optional[str] = None,
388+
):
389+
"""
390+
Check that intervals do not exceed the ends of the chromosome.
391+
392+
Args:
393+
data: Either a pandas dataframe of genomic intervals or an Anndata
394+
object with intervals in .var
395+
genome: name of the genome corresponding to intervals
396+
397+
Raises:
398+
ValueError if any interval exceeds the chtomosome ends
399+
"""
400+
from grelu.io.genome import read_sizes
401+
402+
# Get genomic intervals
403+
if isinstance(data, AnnData):
404+
intervals = data.var
405+
elif isinstance(data, pd.DataFrame):
406+
intervals = data
407+
408+
# Check start
409+
fail = intervals[intervals.start < 0].index
410+
411+
# Filter end if the genome is provided
412+
if genome is None:
413+
warnings.warn(
414+
"No genome is provided; only intervals with negative start values will be flagged."
415+
)
416+
else:
417+
sizes = read_sizes(genome)
418+
for chrom, size in sizes.values:
419+
fail = fail.append(
420+
intervals[(intervals.chrom == chrom) & (intervals.end > size)].index
421+
)
422+
423+
fail = np.unique(fail)
424+
if len(fail) > 0:
425+
raise ValueError(
426+
f"Indices of intervals that extend beyond the chromosome ends: {','.join(fail.astype(str))}."
427+
)
428+
429+
377430
def filter_chrom_ends(
378431
data: Union[pd.DataFrame, AnnData],
379432
genome: Optional[str] = None,

tests/test_preprocess.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pytest
55

66
from grelu.data.preprocess import (
7+
check_chrom_ends,
78
filter_blacklist,
89
filter_cells,
910
filter_chrom_ends,
@@ -117,22 +118,36 @@ def test_filter_blacklist():
117118
assert filter_blacklist(intervals, genome="hg38").equals(intervals.iloc[-2:, :])
118119

119120

121+
chrom_end_intervals = pd.DataFrame(
122+
{
123+
"chrom": ["chr1", "chr1", "chr1", "chr1", "chr1"],
124+
"start": [-10, 10, 1000, 248956300, 248956350],
125+
"end": [90, 110, 1100, 248956400, 248956450],
126+
}
127+
)
128+
129+
120130
def test_filter_chrom_ends():
121-
intervals = pd.DataFrame(
122-
{
123-
"chrom": ["chr1", "chr1", "chr1", "chr1", "chr1"],
124-
"start": [-10, 10, 1000, 248956300, 248956350],
125-
"end": [90, 110, 1100, 248956400, 248956450],
126-
}
127-
)
128-
assert filter_chrom_ends(intervals, genome="hg38").equals(
129-
intervals.iloc[[1, 2, 3], :]
131+
132+
assert filter_chrom_ends(chrom_end_intervals, genome="hg38").equals(
133+
chrom_end_intervals.iloc[[1, 2, 3], :]
130134
)
131-
assert filter_chrom_ends(intervals, genome="hg38", pad=100).equals(
132-
intervals.iloc[[2], :]
135+
assert filter_chrom_ends(chrom_end_intervals, genome="hg38", pad=100).equals(
136+
chrom_end_intervals.iloc[[2], :]
133137
)
134138

135139

140+
def test_check_chrom_ends():
141+
with pytest.raises(Exception) as e_info:
142+
check_chrom_ends(chrom_end_intervals, genome="hg38")
143+
assert (
144+
str(e_info.value)
145+
== "Indices of intervals that extend beyond the chromosome ends: 0,4."
146+
)
147+
148+
check_chrom_ends(chrom_end_intervals.iloc[1:2], genome="hg38")
149+
150+
136151
def test_merge_intervals_by_column():
137152
intervals = pd.DataFrame(
138153
{

0 commit comments

Comments
 (0)