Skip to content

Commit c6ce5cb

Browse files
authored
Merge pull request #160 from Genentech/ad-input
fixed filter_overlapping with anndata input
2 parents 0f702f2 + 04aafa9 commit c6ce5cb

File tree

2 files changed

+77
-9
lines changed

2 files changed

+77
-9
lines changed

src/grelu/data/preprocess.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -296,8 +296,10 @@ def filter_overlapping(
296296
from grelu.sequence.format import check_intervals
297297
from grelu.variant import variants_to_intervals
298298

299+
# Get genomic intervals
299300
if isinstance(data, AnnData):
300-
intervals = data.var
301+
intervals = data.var.reset_index(drop=True)
302+
301303
elif isinstance(data, pd.DataFrame):
302304
if check_intervals(data):
303305
intervals = data
@@ -319,14 +321,16 @@ def filter_overlapping(
319321
bf.expand(ref_intervals, pad=window),
320322
how="inner",
321323
return_index=True,
322-
return_input=True,
323324
)
324325
overlap = overlap[
325326
(overlap.start >= overlap.start_) & ((overlap.end <= overlap.end_))
326327
]
327328

328329
# list intervals to keep
329-
keep = intervals.index.isin(overlap["index"])
330+
if isinstance(data, AnnData):
331+
keep = data.var.index.isin(data.var.index[overlap["index"].values])
332+
else:
333+
keep = intervals.index.isin(overlap["index"])
330334
if invert:
331335
keep = ~keep
332336

tests/test_preprocess.py

Lines changed: 70 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -83,11 +83,12 @@ def test_filter_cells():
8383

8484

8585
def test_filter_overlapping():
86+
# Test data for 'any' overlap method
8687
intervals = pd.DataFrame(
8788
{
88-
"chrom": ["chr10", "chr10", "chr10"],
89-
"start": [10, 1000, 45000],
90-
"end": [1010, 2000, 46000],
89+
"chrom": ["chr10", "chr10", "chr10", "chr10"],
90+
"start": [10, 150, 1000, 45000],
91+
"end": [1010, 180, 2000, 46000],
9192
}
9293
)
9394
ref_intervals = pd.DataFrame(
@@ -98,14 +99,41 @@ def test_filter_overlapping():
9899
}
99100
)
100101

102+
# Test with DataFrame input
103+
104+
# method='any'
101105
# No window, overlapping
102-
assert filter_overlapping(intervals, ref_intervals).equals(intervals.iloc[[0], :])
106+
assert filter_overlapping(intervals, ref_intervals, method="any"
107+
).equals(intervals.iloc[[0, 1], :])
103108

109+
# method='any'
104110
# Window, non-overlapping
105-
assert filter_overlapping(intervals, ref_intervals, window=50, invert=True).equals(
106-
intervals.iloc[[2], :]
111+
assert filter_overlapping(intervals, ref_intervals, window=50, invert=True,
112+
method="any").equals(intervals.iloc[[3], :])
113+
114+
# method='all'
115+
assert filter_overlapping(intervals, ref_intervals, method="all").equals(
116+
intervals.iloc[[1], :]
107117
)
108118

119+
# Test with anndata input
120+
ad = anndata.AnnData(np.random.rand(2, 4), dtype=np.float32)
121+
ad.var = intervals.copy()
122+
ad.var.index = ad.var.index.astype(str)
123+
124+
# method='any'
125+
# AnnData, No window, overlapping
126+
ad_filtered = filter_overlapping(ad, ref_intervals, method="any")
127+
assert ad_filtered.var.equals(ad.var.iloc[[0, 1], :])
128+
129+
# AnnData, Window, non-overlapping
130+
ad_filtered = filter_overlapping(ad, ref_intervals, window=50, invert=True, method="any")
131+
assert ad_filtered.var.equals(ad.var.iloc[[3], :])
132+
133+
# method='all'
134+
ad_filtered = filter_overlapping(ad, ref_intervals, method="all")
135+
assert ad_filtered.var.equals(ad.var.iloc[[1], :])
136+
109137

110138
def test_filter_blacklist():
111139
intervals = pd.DataFrame(
@@ -115,8 +143,17 @@ def test_filter_blacklist():
115143
"end": [1010, 2000, 46000, 47000, 49000],
116144
}
117145
)
146+
147+
# DataFrame input
118148
assert filter_blacklist(intervals, genome="hg38").equals(intervals.iloc[-2:, :])
119149

150+
# AnnData input
151+
ad = anndata.AnnData(np.random.rand(2, 5), dtype=np.float32)
152+
ad.var = intervals.copy()
153+
ad.var.index = ad.var.index.astype(str)
154+
ad_filtered = filter_blacklist(ad, genome="hg38")
155+
assert ad_filtered.var.equals(ad.var.iloc[-2:, :])
156+
120157

121158
chrom_end_intervals = pd.DataFrame(
122159
{
@@ -129,15 +166,29 @@ def test_filter_blacklist():
129166

130167
def test_filter_chrom_ends():
131168

169+
# DataFrame input
132170
assert filter_chrom_ends(chrom_end_intervals, genome="hg38").equals(
133171
chrom_end_intervals.iloc[[1, 2, 3], :]
134172
)
135173
assert filter_chrom_ends(chrom_end_intervals, genome="hg38", pad=100).equals(
136174
chrom_end_intervals.iloc[[2], :]
137175
)
138176

177+
# AnnData input
178+
ad = anndata.AnnData(np.random.rand(2, 5), dtype=np.float32)
179+
ad.var = chrom_end_intervals.copy()
180+
ad.var.index = ad.var.index.astype(str)
181+
182+
ad_filtered = filter_chrom_ends(ad, genome="hg38")
183+
assert ad_filtered.var.equals(ad.var.iloc[[1, 2, 3], :])
184+
185+
ad_filtered = filter_chrom_ends(ad, genome="hg38", pad=100)
186+
assert ad_filtered.var.equals(ad.var.iloc[[2], :])
187+
139188

140189
def test_check_chrom_ends():
190+
191+
# DataFrame input
141192
with pytest.raises(Exception) as e_info:
142193
check_chrom_ends(chrom_end_intervals, genome="hg38")
143194
assert (
@@ -147,6 +198,19 @@ def test_check_chrom_ends():
147198

148199
check_chrom_ends(chrom_end_intervals.iloc[1:2], genome="hg38")
149200

201+
# AnnData input
202+
ad = anndata.AnnData(np.random.rand(2, 5), dtype=np.float32)
203+
ad.var = chrom_end_intervals.copy()
204+
ad.var.index = ad.var.index.astype(str)
205+
206+
with pytest.raises(Exception) as e_info:
207+
check_chrom_ends(ad, genome="hg38")
208+
assert (
209+
str(e_info.value)
210+
== "Indices of intervals that extend beyond the chromosome ends: 0,4."
211+
)
212+
check_chrom_ends(ad[:, 1:2], genome="hg38")
213+
150214

151215
def test_merge_intervals_by_column():
152216
intervals = pd.DataFrame(

0 commit comments

Comments
 (0)