Skip to content

Commit 35233e0

Browse files
committed
use pytest-mock for the VariantFetcherProxy
1 parent 648598f commit 35233e0

File tree

2 files changed

+35
-24
lines changed

2 files changed

+35
-24
lines changed

kipoiseq/extractors/vcf_matching.py

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -110,21 +110,6 @@ def __iter__(self) -> Iterator[Variant]:
110110
yield from self.variants
111111

112112

113-
class VariantFetcherProxy(VariantFetcher):
114-
115-
def __init__(self, variant_fetcher: VariantFetcher):
116-
self.variant_fetcher = variant_fetcher
117-
118-
def fetch_variants(self, interval: Union[Interval, Iterable[Interval]]) -> Iterator[Variant]:
119-
yield from self.variant_fetcher.fetch_variants(interval)
120-
121-
def batch_iter(self, batch_size=10000) -> Iterator[List[Variant]]:
122-
yield from self.variant_fetcher.batch_iter(batch_size)
123-
124-
def __iter__(self) -> Iterator[Variant]:
125-
yield from self.variant_fetcher
126-
127-
128113
class BaseVariantMatcher:
129114
"""
130115
Base variant intervals matcher
@@ -168,16 +153,10 @@ def _read_variants(
168153
variants=None,
169154
variant_fetcher=None,
170155
vcf_lazy: bool = True,
171-
):
156+
) -> VariantFetcher:
172157
if vcf_file is not None:
173158
from kipoiseq.extractors import MultiSampleVCF
174-
vcf = MultiSampleVCF(vcf_file, lazy=vcf_lazy)
175-
176-
if os.environ.get('PYTEST_RUNNING', '') == 'true':
177-
# Ensure that none of the methods actually uses MultiSampleVCF methods by accident
178-
vcf = VariantFetcherProxy(vcf)
179-
180-
return vcf
159+
return MultiSampleVCF(vcf_file, lazy=vcf_lazy)
181160
elif variant_fetcher is not None:
182161
assert isinstance(variant_fetcher, VariantFetcher), \
183162
"Wrong type of variant fetcher: %s" % type(variant_fetcher)

tests/extractors/test_vcf_matching.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1+
from typing import Union, Iterable, Iterator, List
12
import pytest
23
from conftest import vcf_file, gtf_file, example_intervals_bed
34
import pyranges
45
from kipoiseq.dataclasses import Interval, Variant
56
from kipoiseq.extractors.vcf import MultiSampleVCF
67
from kipoiseq.extractors.vcf_matching import variants_to_pyranges, \
78
pyranges_to_intervals, intervals_to_pyranges, BaseVariantMatcher, \
8-
SingleVariantMatcher, MultiVariantsMatcher
9+
SingleVariantMatcher, MultiVariantsMatcher, VariantFetcher
910

1011
intervals = [
1112
Interval('chr1', 1, 10, strand='+'),
@@ -26,6 +27,36 @@
2627
)
2728

2829

30+
class VariantFetcherProxy(VariantFetcher):
31+
32+
def __init__(self, variant_fetcher: VariantFetcher):
33+
self.variant_fetcher = variant_fetcher
34+
35+
def fetch_variants(self, interval: Union[Interval, Iterable[Interval]]) -> Iterator[Variant]:
36+
yield from self.variant_fetcher.fetch_variants(interval)
37+
38+
def batch_iter(self, batch_size=10000) -> Iterator[List[Variant]]:
39+
yield from self.variant_fetcher.batch_iter(batch_size)
40+
41+
def __iter__(self) -> Iterator[Variant]:
42+
yield from self.variant_fetcher
43+
44+
45+
# make sure that kipoiseq only uses the VariantFetcher API
46+
read_variants_fn = BaseVariantMatcher._read_variants
47+
48+
49+
@staticmethod
50+
def proxy_fn(*args, **kwargs):
51+
vf = VariantFetcherProxy(
52+
read_variants_fn(*args, **kwargs)
53+
)
54+
return vf
55+
56+
57+
BaseVariantMatcher._read_variants = proxy_fn
58+
59+
2960
def test_variants_to_pyranges():
3061
vcf = MultiSampleVCF(vcf_file)
3162
variants = list(vcf)
@@ -140,6 +171,7 @@ def test_SingleVariantMatcher__iter__():
140171
assert (inters[2], variants[2]) in pairs
141172
assert len(pairs) == 4
142173

174+
143175
def test_MultiVariantMatcher__iter__():
144176
matcher = MultiVariantsMatcher(vcf_file, intervals=intervals)
145177
pairs = list(matcher)

0 commit comments

Comments
 (0)