Skip to content

Commit 89f621c

Browse files
authored
Merge pull request #100 from Hoeze/generalize_variant_source
Allow to directly annotate a list of variants
2 parents 6628345 + 35233e0 commit 89f621c

File tree

7 files changed

+182
-46
lines changed

7 files changed

+182
-46
lines changed

.circleci/config.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ variables:
117117
jobs:
118118
test-py36:
119119
docker:
120-
- image: continuumio/miniconda3:4.7.10
120+
- image: continuumio/miniconda3:4.10.3
121121
working_directory: ~/repo
122122
steps:
123123
- checkout
@@ -135,7 +135,7 @@ jobs:
135135

136136
build-deploy-docs:
137137
docker:
138-
- image: continuumio/miniconda3:4.7.10
138+
- image: continuumio/miniconda3:4.10.3
139139
working_directory: ~/repo
140140
steps:
141141
- add_ssh_keys:
@@ -168,7 +168,7 @@ jobs:
168168
169169
test-deploy-pypi:
170170
docker:
171-
- image: continuumio/miniconda3:4.7.10
171+
- image: continuumio/miniconda3:4.10.3
172172
working_directory: ~/repo
173173
steps:
174174
- checkout
@@ -197,7 +197,7 @@ jobs:
197197
198198
productive-deploy-pypi:
199199
docker:
200-
- image: continuumio/miniconda3:4.7.10
200+
- image: continuumio/miniconda3:4.10.3
201201
working_directory: ~/repo
202202
steps:
203203
- checkout

kipoiseq/extractors/vcf.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1+
import abc
12
import logging
2-
from typing import Tuple, Iterable, List, Union
3+
from typing import Tuple, Iterable, List, Union, Iterator
34
from itertools import islice
45
from collections import defaultdict
56
from kipoiseq.dataclasses import Variant, Interval
7+
from kipoiseq.variant_source import VariantFetcher
68
from kipoiseq.extractors.vcf_query import VariantIntervalQueryable
9+
from kipoiseq.utils import batch_iter
710

811
try:
912
from cyvcf2 import VCF
@@ -15,19 +18,11 @@
1518
]
1619

1720

18-
def _batch_iter(variants: Iterable[Variant], batch_size=10000
19-
) -> Iterable[Iterable[Variant]]:
20-
batch = list(islice(variants, batch_size))
21-
while batch:
22-
yield batch
23-
batch = list(islice(variants, batch_size))
24-
25-
26-
class MultiSampleVCF(VCF):
21+
class MultiSampleVCF(VariantFetcher, VCF):
2722

2823
def __init__(self, *args, **kwargs):
2924
from cyvcf2 import VCF
30-
super(MultiSampleVCF, self).__init__(*args, **kwargs, strict_gt=True)
25+
VCF.__init__(self, *args, **kwargs, strict_gt=True)
3126
self.sample_mapping = dict(zip(self.samples, range(len(self.samples))))
3227

3328
def fetch_variants(self, interval, sample_id=None):
@@ -79,7 +74,7 @@ def batch_iter(self, batch_size=10000) -> Iterable[Iterable[Variant]]:
7974
batch_size: size of each batch.
8075
"""
8176
variants = iter(self)
82-
yield from _batch_iter(variants, batch_size=batch_size)
77+
yield from batch_iter(variants, batch_size=batch_size)
8378

8479
def query_variants(self, intervals: List[Interval], sample_id=None,
8580
progress=False) -> VariantIntervalQueryable:

kipoiseq/extractors/vcf_matching.py

Lines changed: 60 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
from typing import List
1+
import os
2+
from typing import List, Union, Iterable, Iterator
23
import pandas as pd
34
from kipoiseq.dataclasses import Variant, Interval
4-
from kipoiseq.extractors import MultiSampleVCF
5+
from kipoiseq.variant_source import VariantFetcher
56

67
try:
78
from pyranges import PyRanges
@@ -82,14 +83,43 @@ def intervals_to_pyranges(intervals):
8283
)
8384

8485

86+
class PyrangesVariantFetcher(VariantFetcher):
    """
    Variant source backed by an in-memory list of variants.

    Interval queries are answered by joining the query interval(s) with a
    PyRanges representation of the variants, which is built on first use.
    """

    def __init__(self, variants: List[Variant]):
        """
        Args:
          variants: the variants served by this fetcher.
        """
        self.variants = variants
        # PyRanges form of `self.variants`; filled in by `variants_pr`
        self._variants_pr = None

    @property
    def variants_pr(self):
        """
        PyRanges representation of `self.variants` (created once, then reused).
        """
        # convert to PyRanges on demand
        if self._variants_pr is None:
            self._variants_pr = variants_to_pyranges(self.variants)
        return self._variants_pr

    def fetch_variants(self, interval: Union[Interval, Iterable[Interval]]) -> Iterator[Variant]:
        """
        Fetch variants that overlap the provided interval(s).

        Args:
          interval: one Interval or an iterable of Interval objects.

        Returns:
          Iterator over the `variant` column of the interval/variant join;
          nothing is yielded when no variant overlaps.
        """
        if isinstance(interval, Interval):
            interval = [interval]
        # convert interval(s) to PyRanges object
        interval_pr: PyRanges = intervals_to_pyranges(interval)
        # join with variants
        pr_join = interval_pr.join(self.variants_pr, suffix='_variant')

        joined_df = pr_join.df
        # An empty join result is converted by pyranges to a DataFrame
        # without any columns, so indexing "variant" would raise KeyError.
        if "variant" not in joined_df.columns:
            return

        yield from joined_df["variant"]

    def __iter__(self) -> Iterator[Variant]:
        """
        Iterate over all variants in the order they were passed in.
        """
        yield from self.variants
111+
112+
85113
class BaseVariantMatcher:
86114
"""
87115
Base variant intervals matcher
88116
"""
89117

90118
def __init__(
91119
self,
92-
vcf_file: str,
120+
vcf_file: str = None,
121+
variants: List[Variant] = None,
122+
variant_fetcher: VariantFetcher = None,
93123
gtf_path: str = None,
94124
bed_path: str = None,
95125
pranges: PyRanges = None,
@@ -101,7 +131,8 @@ def __init__(
101131
"""
102132
103133
Args:
104-
vcf_file: path of vcf file
134+
vcf_file: (optional) path of vcf file
135+
variants: (optional) readily processed variants
105136
gtf_path: (optional) path of gtf file contains features
106137
bed_path: (optional) path of bed file
107138
pranges: (optional) pyranges object
@@ -110,12 +141,31 @@ def __init__(
110141
pyranges object. This argument is not valid with intervals.
111142
Currently unused
112143
"""
113-
self.vcf = MultiSampleVCF(vcf_file, lazy=vcf_lazy)
144+
self.variant_fetcher = self._read_variants(vcf_file, variants, variant_fetcher, vcf_lazy)
114145
self.interval_attrs = interval_attrs
115146
self.pr = self._read_intervals(gtf_path, bed_path, pranges,
116147
intervals, interval_attrs, duplicate_attr=True)
117148
self.variant_batch_size = variant_batch_size
118149

150+
@staticmethod
def _read_variants(
        vcf_file=None,
        variants=None,
        variant_fetcher=None,
        vcf_lazy: bool = True,
) -> VariantFetcher:
    """
    Build the variant source from whichever input was provided.

    The inputs are checked in a fixed order and the first one that is not
    None wins: `vcf_file`, then `variant_fetcher`, then `variants`.

    Args:
      vcf_file: (optional) path of vcf file, opened as `MultiSampleVCF`.
      variants: (optional) readily processed variants, wrapped in a
        `PyrangesVariantFetcher`.
      variant_fetcher: (optional) existing `VariantFetcher`, returned as is.
      vcf_lazy: passed as `lazy` to `MultiSampleVCF`; only used together
        with `vcf_file`.

    Returns:
      The variant source to read variants from.

    Raises:
      TypeError: if `variant_fetcher` is not a `VariantFetcher`.
      ValueError: if none of the three sources was specified.
    """
    if vcf_file is not None:
        # imported here rather than at module level
        # (NOTE(review): presumably to avoid a circular import - confirm)
        from kipoiseq.extractors import MultiSampleVCF
        return MultiSampleVCF(vcf_file, lazy=vcf_lazy)

    if variant_fetcher is not None:
        # explicit check instead of `assert`, which is removed under `-O`
        if not isinstance(variant_fetcher, VariantFetcher):
            raise TypeError(
                "Wrong type of variant fetcher: %s" % type(variant_fetcher)
            )
        return variant_fetcher

    if variants is not None:
        return PyrangesVariantFetcher(variants)

    raise ValueError("No source of variants was specified!")
168+
119169
@staticmethod
120170
def _read_intervals(gtf_path=None, bed_path=None, pranges=None,
121171
intervals=None, interval_attrs=None, duplicate_attr=False):
@@ -172,7 +222,7 @@ def _read_vcf_pyranges(self, batch_size=10000):
172222
Args:
173223
batch_size: size of each batch.
174224
"""
175-
for batch in self.vcf.batch_iter(batch_size):
225+
for batch in self.variant_fetcher.batch_iter(batch_size):
176226
yield variants_to_pyranges(batch)
177227

178228
def iter_pyranges(self) -> PyRanges:
@@ -210,30 +260,14 @@ def __iter__(self) -> (Interval, Variant):
210260

211261
class MultiVariantsMatcher(BaseVariantMatcher):
212262

213-
def __init__(
214-
self,
215-
vcf_file,
216-
gtf_path=None,
217-
pranges=None,
218-
intervals=None,
219-
interval_attrs=None,
220-
vcf_lazy=True,
221-
variant_batch_size=10000
222-
):
223-
super().__init__(
224-
vcf_file,
225-
gtf_path=gtf_path,
226-
pranges=pranges,
227-
intervals=intervals,
228-
interval_attrs=interval_attrs,
229-
vcf_lazy=vcf_lazy,
230-
variant_batch_size=variant_batch_size
231-
)
263+
def __init__(self, *args, **kwargs):
264+
super().__init__(*args, **kwargs)
265+
232266
if hasattr(self.pr, 'intervals'):
233267
self.intervals = self.pr.intervals
234268
else:
235269
self.intervals = pyranges_to_intervals(self.pr)
236270

237271
def __iter__(self):
238272
for i in self.intervals:
239-
yield i, self.vcf.fetch_variants(i)
273+
yield i, self.variant_fetcher.fetch_variants(i)

kipoiseq/utils.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
1+
from itertools import islice
2+
from typing import Iterable, TypeVar
3+
14
import numpy as np
25
from six import string_types
36

47
# alphabets:
8+
from kipoiseq import Variant
9+
510
DNA = ["A", "C", "G", "T"]
611
RNA = ["A", "C", "G", "U"]
712
AMINO_ACIDS = ["A", "R", "N", "D", "B", "C", "E", "Q", "Z", "G", "H",
@@ -36,3 +41,18 @@ def parse_dtype(dtype):
3641
raise ValueError("Unable to parse dtype: {}. \nException: {}".format(dtype, e))
3742
else:
3843
return dtype
44+
45+
46+
T = TypeVar('T')


def batch_iter(items: Iterable[T], batch_size: int) -> Iterable[Iterable[T]]:
    """
    Split `items` into consecutive lists of at most `batch_size` elements.

    Args:
      items: any iterable; it is consumed lazily, one batch at a time.
      batch_size: maximum number of elements per yielded batch.

    Yields:
      Non-empty lists holding the items in their original order; only the
      last list may contain fewer than `batch_size` elements.
    """
    source = iter(items)

    def next_batch():
        # take up to `batch_size` further items from the shared iterator
        return list(islice(source, batch_size))

    # The two-argument form of iter() keeps calling `next_batch` until it
    # returns the sentinel, i.e. an empty list once `source` is exhausted.
    yield from iter(next_batch, [])

kipoiseq/variant_source.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import abc
2+
from typing import Iterable, Union, Iterator
3+
4+
from kipoiseq import Interval, Variant
5+
from kipoiseq.utils import batch_iter
6+
7+
8+
class VariantFetcher(Iterable, metaclass=abc.ABCMeta):
    """
    Base class for all variant-returning data sources.

    Subclasses have to implement `fetch_variants` and `__iter__`;
    `batch_iter` is provided on top of `__iter__`.
    """

    @abc.abstractmethod
    def fetch_variants(self, interval: Union[Interval, Iterable[Interval]]) -> Iterator[Variant]:
        """
        Fetch variants that intersect the provided interval(s).

        :param interval: One or multiple Interval objects
        :return: Iterator of Variant objects that intersect `interval`
        """
        raise NotImplementedError

    def batch_iter(self, batch_size=10000) -> Iterator[Iterable[Variant]]:
        """
        Fetch variants in batches.

        The batches are built from the variants obtained by iterating
        over this object.

        :param batch_size: Number of variants per batch
        :return: Iterator that yields batches of Variant objects
        """
        yield from batch_iter(self, batch_size)

    @abc.abstractmethod
    def __iter__(self) -> Iterator[Variant]:
        """
        Iterate over all variants of this source.

        :return: Iterator of Variant objects
        """
        raise NotImplementedError

setup.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,10 @@
3434
"cython",
3535
"cyvcf2",
3636
"pyranges>=0.0.71",
37-
"keras",
38-
"tensorflow",
37+
# "keras",
38+
# "tensorflow",
3939
"pybedtools",
40-
"concise"
40+
# "concise"
4141
]
4242

4343
setup(

tests/extractors/test_vcf_matching.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
1+
from typing import Union, Iterable, Iterator, List
12
import pytest
23
from conftest import vcf_file, gtf_file, example_intervals_bed
34
import pyranges
45
from kipoiseq.dataclasses import Interval, Variant
56
from kipoiseq.extractors.vcf import MultiSampleVCF
67
from kipoiseq.extractors.vcf_matching import variants_to_pyranges, \
78
pyranges_to_intervals, intervals_to_pyranges, BaseVariantMatcher, \
8-
SingleVariantMatcher, MultiVariantsMatcher
9+
SingleVariantMatcher, MultiVariantsMatcher, VariantFetcher
910

1011
intervals = [
1112
Interval('chr1', 1, 10, strand='+'),
@@ -26,6 +27,36 @@
2627
)
2728

2829

30+
class VariantFetcherProxy(VariantFetcher):
    """
    Thin wrapper that forwards every VariantFetcher call to another fetcher.

    Because it only exposes `fetch_variants`, `batch_iter` and `__iter__`,
    code under test that touches anything else on the wrapped object fails.
    """

    def __init__(self, variant_fetcher: VariantFetcher):
        # the wrapped fetcher that actually provides the variants
        self.variant_fetcher = variant_fetcher

    def fetch_variants(self, interval: Union[Interval, Iterable[Interval]]) -> Iterator[Variant]:
        """Yield the wrapped fetcher's variants for `interval`."""
        fetched = self.variant_fetcher.fetch_variants(interval)
        for variant in fetched:
            yield variant

    def batch_iter(self, batch_size=10000) -> Iterator[List[Variant]]:
        """Yield the wrapped fetcher's batches of size `batch_size`."""
        batches = self.variant_fetcher.batch_iter(batch_size)
        for batch in batches:
            yield batch

    def __iter__(self) -> Iterator[Variant]:
        """Yield every variant of the wrapped fetcher."""
        for variant in self.variant_fetcher:
            yield variant
43+
44+
45+
# make sure that kipoiseq only uses the VariantFetcher API
46+
read_variants_fn = BaseVariantMatcher._read_variants
47+
48+
49+
@staticmethod
def proxy_fn(*args, **kwargs):
    """
    Replacement for `BaseVariantMatcher._read_variants` used by the tests.

    Calls the saved original (`read_variants_fn`) with the same arguments
    and returns its result wrapped in a `VariantFetcherProxy`.
    """
    return VariantFetcherProxy(read_variants_fn(*args, **kwargs))
55+
56+
57+
BaseVariantMatcher._read_variants = proxy_fn
58+
59+
2960
def test_variants_to_pyranges():
3061
vcf = MultiSampleVCF(vcf_file)
3162
variants = list(vcf)
@@ -131,6 +162,15 @@ def test_SingleVariantMatcher__iter__():
131162
assert (inters[2], variants[2]) in pairs
132163
assert len(pairs) == 4
133164

165+
matcher = SingleVariantMatcher(variants=variants, pranges=pr)
166+
pairs = list(matcher)
167+
168+
assert (inters[0], variants[0]) in pairs
169+
assert (inters[0], variants[1]) in pairs
170+
assert (inters[1], variants[2]) in pairs
171+
assert (inters[2], variants[2]) in pairs
172+
assert len(pairs) == 4
173+
134174

135175
def test_MultiVariantMatcher__iter__():
136176
matcher = MultiVariantsMatcher(vcf_file, intervals=intervals)
@@ -148,3 +188,11 @@ def test_MultiVariantMatcher__iter__():
148188
assert list(pairs[0][1]) == [variants[0], variants[1]]
149189
assert pairs[1][0] == intervals[1]
150190
assert list(pairs[1][1]) == [variants[2]]
191+
192+
matcher = MultiVariantsMatcher(variants=variants, pranges=pr)
193+
pairs = list(matcher)
194+
195+
assert pairs[0][0] == intervals[0]
196+
assert list(pairs[0][1]) == [variants[0], variants[1]]
197+
assert pairs[1][0] == intervals[1]
198+
assert list(pairs[1][1]) == [variants[2]]

0 commit comments

Comments
 (0)