|
| 1 | +from typing import Union, Iterable, Iterator, List |
1 | 2 | import pytest |
2 | 3 | from conftest import vcf_file, gtf_file, example_intervals_bed |
3 | 4 | import pyranges |
4 | 5 | from kipoiseq.dataclasses import Interval, Variant |
5 | 6 | from kipoiseq.extractors.vcf import MultiSampleVCF |
6 | 7 | from kipoiseq.extractors.vcf_matching import variants_to_pyranges, \ |
7 | 8 | pyranges_to_intervals, intervals_to_pyranges, BaseVariantMatcher, \ |
8 | | - SingleVariantMatcher, MultiVariantsMatcher |
| 9 | + SingleVariantMatcher, MultiVariantsMatcher, VariantFetcher |
9 | 10 |
|
10 | 11 | intervals = [ |
11 | 12 | Interval('chr1', 1, 10, strand='+'), |
|
26 | 27 | ) |
27 | 28 |
|
28 | 29 |
|
class VariantFetcherProxy(VariantFetcher):
    """Thin ``VariantFetcher`` that forwards every call to a wrapped fetcher.

    Only ``fetch_variants``, ``batch_iter`` and ``__iter__`` are defined
    here; each one simply re-yields what the wrapped object produces.
    """

    def __init__(self, variant_fetcher: VariantFetcher):
        """Remember the fetcher that all calls are delegated to."""
        self.variant_fetcher = variant_fetcher

    def fetch_variants(
            self, interval: Union[Interval, Iterable[Interval]]
    ) -> Iterator[Variant]:
        """Yield whatever the wrapped ``fetch_variants(interval)`` yields."""
        wrapped = self.variant_fetcher
        yield from wrapped.fetch_variants(interval)

    def batch_iter(self, batch_size=10000) -> Iterator[List[Variant]]:
        """Yield whatever the wrapped ``batch_iter(batch_size)`` yields."""
        wrapped = self.variant_fetcher
        yield from wrapped.batch_iter(batch_size)

    def __iter__(self) -> Iterator[Variant]:
        """Yield the items obtained by iterating the wrapped fetcher."""
        wrapped = self.variant_fetcher
        yield from wrapped
| 43 | + |
| 44 | + |
# Make sure that kipoiseq only uses the VariantFetcher API: whatever
# ``BaseVariantMatcher._read_variants`` returns is wrapped in a
# ``VariantFetcherProxy`` before the matcher gets to use it.
read_variants_fn = BaseVariantMatcher._read_variants


def proxy_fn(*args, **kwargs):
    """Call the original ``_read_variants`` and wrap its result in a proxy.

    All positional and keyword arguments are forwarded unchanged.

    Returns:
        A ``VariantFetcherProxy`` around the object returned by the
        original ``_read_variants``.
    """
    return VariantFetcherProxy(read_variants_fn(*args, **kwargs))


# Wrap at assignment time instead of decorating the module-level function:
# a bare ``staticmethod`` object is not callable before Python 3.10, so this
# keeps ``proxy_fn`` an ordinary callable while the class attribute still
# behaves as a static method.
BaseVariantMatcher._read_variants = staticmethod(proxy_fn)
| 58 | + |
| 59 | + |
29 | 60 | def test_variants_to_pyranges(): |
30 | 61 | vcf = MultiSampleVCF(vcf_file) |
31 | 62 | variants = list(vcf) |
@@ -140,6 +171,7 @@ def test_SingleVariantMatcher__iter__(): |
140 | 171 | assert (inters[2], variants[2]) in pairs |
141 | 172 | assert len(pairs) == 4 |
142 | 173 |
|
| 174 | + |
143 | 175 | def test_MultiVariantMatcher__iter__(): |
144 | 176 | matcher = MultiVariantsMatcher(vcf_file, intervals=intervals) |
145 | 177 | pairs = list(matcher) |
|
0 commit comments