Skip to content

Commit 3585b92

Browse files
committed
Generalize variant source in VCF matching
1 parent b5c9b5f commit 3585b92

File tree

5 files changed

+166
-41
lines changed

5 files changed

+166
-41
lines changed

kipoiseq/extractors/vcf.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
import abc
12
import logging
2-
from typing import Tuple, Iterable, List, Union
3+
from typing import Tuple, Iterable, List, Union, Iterator
34
from itertools import islice
45
from collections import defaultdict
56
from kipoiseq.dataclasses import Variant, Interval
7+
from kipoiseq.variant_source import VariantFetcher
68
from kipoiseq.extractors.vcf_query import VariantIntervalQueryable
9+
from kipoiseq.utils import batch_iter
710

811
try:
912
from cyvcf2 import VCF
@@ -15,19 +18,11 @@
1518
]
1619

1720

18-
def _batch_iter(variants: Iterable[Variant], batch_size=10000
19-
) -> Iterable[Iterable[Variant]]:
20-
batch = list(islice(variants, batch_size))
21-
while batch:
22-
yield batch
23-
batch = list(islice(variants, batch_size))
24-
25-
26-
class MultiSampleVCF(VCF):
21+
class MultiSampleVCF(VariantFetcher, VCF):
2722

2823
def __init__(self, *args, **kwargs):
2924
from cyvcf2 import VCF
30-
super(MultiSampleVCF, self).__init__(*args, **kwargs, strict_gt=True)
25+
VCF.__init__(self, *args, **kwargs, strict_gt=True)
3126
self.sample_mapping = dict(zip(self.samples, range(len(self.samples))))
3227

3328
def fetch_variants(self, interval, sample_id=None):
@@ -79,7 +74,7 @@ def batch_iter(self, batch_size=10000) -> Iterable[Iterable[Variant]]:
7974
batch_size: size of each batch.
8075
"""
8176
variants = iter(self)
82-
yield from _batch_iter(variants, batch_size=batch_size)
77+
yield from batch_iter(variants, batch_size=batch_size)
8378

8479
def query_variants(self, intervals: List[Interval], sample_id=None,
8580
progress=False) -> VariantIntervalQueryable:

kipoiseq/extractors/vcf_matching.py

Lines changed: 81 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
from typing import List
1+
import os
2+
from typing import List, Union, Iterable, Iterator
23
import pandas as pd
34
from kipoiseq.dataclasses import Variant, Interval
4-
from kipoiseq.extractors import MultiSampleVCF
5+
from kipoiseq.variant_source import VariantFetcher
56

67
try:
78
from pyranges import PyRanges
@@ -82,14 +83,58 @@ def intervals_to_pyranges(intervals):
8283
)
8384

8485

86+
class PyrangesVariantFetcher(VariantFetcher):
    """
    Variant source backed by an in-memory list of variants.

    Interval lookups are answered by joining the queried interval(s) with a
    PyRanges representation of the variants. That representation is built on
    first use and cached on the instance afterwards.
    """

    def __init__(self, variants: List[Variant]):
        """
        Args:
          variants: the variants served by this fetcher
        """
        self.variants = variants
        # PyRanges view of `self.variants`; created on first access
        self._variants_pr = None

    @property
    def variants_pr(self):
        """
        PyRanges representation of `self.variants`, converted on demand.
        """
        if self._variants_pr is None:
            self._variants_pr = variants_to_pyranges(self.variants)
        return self._variants_pr

    def fetch_variants(self, interval: Union[Interval, Iterable[Interval]]) -> Iterator[Variant]:
        """
        Fetch the variants that are joined to the provided interval(s).

        Args:
          interval: a single Interval or an iterable of Interval objects

        Returns:
          Iterator over the `variant` column of the join result; yields
          nothing when the join result is empty.
        """
        if isinstance(interval, Interval):
            interval = [interval]
        # convert interval(s) to PyRanges object
        interval_pr: PyRanges = intervals_to_pyranges(interval)
        # join with variants
        pr_join = interval_pr.join(self.variants_pr, suffix='_variant')

        # pyranges may convert an empty join result to a DataFrame without
        # any columns; indexing "variant" on it would raise a KeyError, so
        # an empty result is reported as "no variants" instead.
        joined_df = pr_join.df
        if joined_df.empty:
            return

        yield from joined_df["variant"]

    def __iter__(self) -> Iterator[Variant]:
        """
        Iterate over the variants in the order they were provided.
        """
        yield from self.variants
111+
112+
113+
class VariantFetcherProxy(VariantFetcher):
    """
    Thin wrapper that forwards the VariantFetcher interface to another fetcher.

    Only `fetch_variants`, `batch_iter` and iteration are forwarded; no other
    attribute of the wrapped object is exposed by the proxy.
    """

    def __init__(self, variant_fetcher: VariantFetcher):
        """
        Args:
          variant_fetcher: the fetcher that all calls are delegated to
        """
        self.variant_fetcher = variant_fetcher

    def fetch_variants(self, interval: Union[Interval, Iterable[Interval]]) -> Iterator[Variant]:
        """
        Yield the variants the wrapped fetcher returns for `interval`.
        """
        for variant in self.variant_fetcher.fetch_variants(interval):
            yield variant

    def batch_iter(self, batch_size=10000) -> Iterator[List[Variant]]:
        """
        Yield the batches produced by the wrapped fetcher's `batch_iter`.
        """
        for batch in self.variant_fetcher.batch_iter(batch_size):
            yield batch

    def __iter__(self) -> Iterator[Variant]:
        """
        Iterate over the variants of the wrapped fetcher.
        """
        for variant in self.variant_fetcher:
            yield variant
126+
127+
85128
class BaseVariantMatcher:
86129
"""
87130
Base variant intervals matcher
88131
"""
89132

90133
def __init__(
91134
self,
92-
vcf_file: str,
135+
vcf_file: str = None,
136+
variants: List[Variant] = None,
137+
variant_fetcher: VariantFetcher = None,
93138
gtf_path: str = None,
94139
bed_path: str = None,
95140
pranges: PyRanges = None,
@@ -101,7 +146,8 @@ def __init__(
101146
"""
102147
103148
Args:
104-
vcf_file: path of vcf file
149+
vcf_file: (optional) path of vcf file
150+
variants: (optional) readily processed variants
105151
gtf_path: (optional) path of a gtf file containing features
106152
bed_path: (optional) path of bed file
107153
pranges: (optional) pyranges object
@@ -110,12 +156,37 @@ def __init__(
110156
pyranges object. This argument is not valid with intervals.
111157
Currently unused
112158
"""
113-
self.vcf = MultiSampleVCF(vcf_file, lazy=vcf_lazy)
159+
self.variant_fetcher = self._read_variants(vcf_file, variants, variant_fetcher, vcf_lazy)
114160
self.interval_attrs = interval_attrs
115161
self.pr = self._read_intervals(gtf_path, bed_path, pranges,
116162
intervals, interval_attrs, duplicate_attr=True)
117163
self.variant_batch_size = variant_batch_size
118164

165+
@staticmethod
def _read_variants(
        vcf_file=None,
        variants=None,
        variant_fetcher=None,
        vcf_lazy: bool = True,
):
    """
    Create the variant source used by the matcher.

    Only one source is used. When several are given, the first one that is
    not None wins, checked in the order `vcf_file`, `variant_fetcher`,
    `variants`.

    Args:
      vcf_file: (optional) path of vcf file, opened as MultiSampleVCF
      variants: (optional) readily processed variants, wrapped in a
        PyrangesVariantFetcher
      variant_fetcher: (optional) VariantFetcher instance, returned as is
      vcf_lazy: passed as `lazy` to MultiSampleVCF; only used together
        with `vcf_file`

    Returns:
      The object that provides the variants.

    Raises:
      TypeError: if `variant_fetcher` is not a VariantFetcher
      ValueError: if no source of variants was specified
    """
    if vcf_file is not None:
        # imported inside the function instead of at module level
        from kipoiseq.extractors import MultiSampleVCF
        vcf = MultiSampleVCF(vcf_file, lazy=vcf_lazy)

        if os.environ.get('PYTEST_RUNNING', '') == 'true':
            # Ensure that none of the methods actually uses MultiSampleVCF methods by accident
            vcf = VariantFetcherProxy(vcf)

        return vcf

    if variant_fetcher is not None:
        # explicit check instead of `assert`, which is removed when Python
        # runs with optimizations enabled (-O)
        if not isinstance(variant_fetcher, VariantFetcher):
            raise TypeError(
                "Wrong type of variant fetcher: %s" % type(variant_fetcher))
        return variant_fetcher

    if variants is not None:
        return PyrangesVariantFetcher(variants)

    raise ValueError("No source of variants was specified!")
189+
119190
@staticmethod
120191
def _read_intervals(gtf_path=None, bed_path=None, pranges=None,
121192
intervals=None, interval_attrs=None, duplicate_attr=False):
@@ -172,7 +243,7 @@ def _read_vcf_pyranges(self, batch_size=10000):
172243
Args:
173244
batch_size: size of each batch.
174245
"""
175-
for batch in self.vcf.batch_iter(batch_size):
246+
for batch in self.variant_fetcher.batch_iter(batch_size):
176247
yield variants_to_pyranges(batch)
177248

178249
def iter_pyranges(self) -> PyRanges:
@@ -210,30 +281,14 @@ def __iter__(self):
210281

211282
class MultiVariantsMatcher(BaseVariantMatcher):
212283

213-
def __init__(
214-
self,
215-
vcf_file,
216-
gtf_path=None,
217-
pranges=None,
218-
intervals=None,
219-
interval_attrs=None,
220-
vcf_lazy=True,
221-
variant_batch_size=10000
222-
):
223-
super().__init__(
224-
vcf_file,
225-
gtf_path=gtf_path,
226-
pranges=pranges,
227-
intervals=intervals,
228-
interval_attrs=interval_attrs,
229-
vcf_lazy=vcf_lazy,
230-
variant_batch_size=variant_batch_size
231-
)
284+
def __init__(self, *args, **kwargs):
285+
super().__init__(*args, **kwargs)
286+
232287
if hasattr(self.pr, 'intervals'):
233288
self.intervals = self.pr.intervals
234289
else:
235290
self.intervals = pyranges_to_intervals(self.pr)
236291

237292
def __iter__(self):
238293
for i in self.intervals:
239-
yield i, self.vcf.fetch_variants(i)
294+
yield i, self.variant_fetcher.fetch_variants(i)

kipoiseq/utils.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
1+
from itertools import islice
2+
from typing import Iterable, TypeVar
3+
14
import numpy as np
25
from six import string_types
36

47
# alphabets:
8+
from kipoiseq import Variant
9+
510
DNA = ["A", "C", "G", "T"]
611
RNA = ["A", "C", "G", "U"]
712
AMINO_ACIDS = ["A", "R", "N", "D", "B", "C", "E", "Q", "Z", "G", "H",
@@ -36,3 +41,18 @@ def parse_dtype(dtype):
3641
raise ValueError("Unable to parse dtype: {}. \nException: {}".format(dtype, e))
3742
else:
3843
return dtype
44+
45+
46+
T = TypeVar('T')


def batch_iter(items: Iterable[T], batch_size: int) -> Iterable[Iterable[T]]:
    """
    Split `items` into consecutive batches of at most `batch_size` elements.

    Every batch is a list; the last one may be shorter. An empty `items`
    yields no batch at all.

    Args:
      items: iterable (or iterator) to split into batches
      batch_size: maximum number of items per batch; must be at least 1

    Returns:
      Generator that yields the batches in order.

    Raises:
      ValueError: if `batch_size` is smaller than 1. As this is a generator,
        the error is raised when iteration starts.
    """
    # A size of 0 would end the loop below immediately and silently drop
    # every item, so reject it explicitly. `None` is left untouched: islice
    # treats it as "no limit".
    if batch_size is not None and batch_size < 1:
        raise ValueError(
            "batch_size must be a positive integer, got: %r" % (batch_size,))

    # ensure this is an iterator, so successive islice calls continue where
    # the previous batch stopped
    item_iter = iter(items)
    while True:
        batch = list(islice(item_iter, batch_size))
        if not batch:
            break

        yield batch

kipoiseq/variant_source.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import abc
2+
from typing import Iterable, Union, Iterator
3+
4+
from kipoiseq import Interval, Variant
5+
from kipoiseq.utils import batch_iter
6+
7+
8+
class VariantFetcher(Iterable, metaclass=abc.ABCMeta):
    """
    Abstract interface shared by all data sources that return variants.

    Subclasses have to implement `fetch_variants` and `__iter__`;
    `batch_iter` is provided on top of `__iter__`.
    """

    @abc.abstractmethod
    def fetch_variants(self, interval: Union[Interval, Iterable[Interval]]) -> Iterator[Variant]:
        """
        Fetch variants that intersect the provided interval(s).

        :param interval: One or multiple Interval objects
        :return: Iterator of Variant objects that intersect `interval`
        """
        raise NotImplementedError

    def batch_iter(self, batch_size=10000) -> Iterator[Iterable[Variant]]:
        """
        Iterate over the variants of this source in batches.

        :param batch_size: Number of variants per batch
        :return: Iterator that yields batches of Variant objects
        """
        # `batch_iter` here is the module-level helper, not this method
        for batch in batch_iter(self, batch_size):
            yield batch

    @abc.abstractmethod
    def __iter__(self) -> Iterator[Variant]:
        """
        Iterate over the variants provided by this source.

        :return: Iterator of Variant objects
        """
        raise NotImplementedError

tests/extractors/test_vcf_matching.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from kipoiseq.dataclasses import Interval, Variant
55
from kipoiseq.extractors.vcf import MultiSampleVCF
66
from kipoiseq.extractors.vcf_matching import variants_to_pyranges, \
7-
pyranges_to_intervals, intervals_to_pyranges, BaseVariantMatcher, \
7+
pyranges_to_intervals, intervals_to_pyranges, BaseVariantMatcher, \
88
SingleVariantMatcher, MultiVariantsMatcher
99

1010
intervals = [
@@ -112,8 +112,7 @@ def test_BaseVariantMatcher__read_intervals():
112112
def test_SingleVariantMatcher__iter__():
113113
inters = intervals + [Interval('chr1', 5, 50)]
114114

115-
matcher = SingleVariantMatcher(
116-
vcf_file, intervals=inters)
115+
matcher = SingleVariantMatcher(vcf_file, intervals=inters)
117116
pairs = list(matcher)
118117

119118
assert (inters[0], variants[0]) in pairs
@@ -134,6 +133,15 @@ def test_SingleVariantMatcher__iter__():
134133
assert (inters[2], variants[2]) in pairs
135134
assert len(pairs) == 5
136135

136+
matcher = SingleVariantMatcher(variants=variants, pranges=pr)
137+
pairs = list(matcher)
138+
139+
assert (inters[0], variants[0]) in pairs
140+
assert (inters[0], variants[1]) in pairs
141+
assert (inters[1], variants[2]) in pairs
142+
assert (inters[2], variants[1]) in pairs
143+
assert (inters[2], variants[2]) in pairs
144+
assert len(pairs) == 5
137145

138146
def test_MultiVariantMatcher__iter__():
139147
matcher = MultiVariantsMatcher(vcf_file, intervals=intervals)
@@ -151,3 +159,11 @@ def test_MultiVariantMatcher__iter__():
151159
assert list(pairs[0][1]) == [variants[0], variants[1]]
152160
assert pairs[1][0] == intervals[1]
153161
assert list(pairs[1][1]) == [variants[2]]
162+
163+
matcher = MultiVariantsMatcher(variants=variants, pranges=pr)
164+
pairs = list(matcher)
165+
166+
assert pairs[0][0] == intervals[0]
167+
assert list(pairs[0][1]) == [variants[0], variants[1]]
168+
assert pairs[1][0] == intervals[1]
169+
assert list(pairs[1][1]) == [variants[2]]

0 commit comments

Comments
 (0)