Skip to content

Commit 81bfb23

Browse files
committed
queryable-vcf-files
1 parent 19cce93 commit 81bfb23

File tree

5 files changed

+323
-3
lines changed

5 files changed

+323
-3
lines changed

kipoiseq/extractors/vcf_query.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
from typing import Tuple, Iterable, List
2+
from tqdm import tqdm
3+
from kipoiseq.dataclasses import Variant, Interval
4+
5+
6+
class VariantQuery:
7+
8+
def __init__(self, func):
9+
self.func = func
10+
11+
def __call__(self, variant: Variant):
12+
return self.func(variant)
13+
14+
def __or__(self, other):
15+
return VariantQuery(lambda variant: self(variant) or other(variant))
16+
17+
def __and__(self, other):
18+
return VariantQuery(lambda variant: self(variant) and other(variant))
19+
20+
21+
class FilterVariantQuery(VariantQuery):
22+
23+
def __init__(self, filter='PASS'):
24+
self.filter = filter
25+
26+
def __call__(self, variant):
27+
return variant.filter == self.filter
28+
29+
30+
class VariantIntervalQuery:
31+
32+
def __init__(self, func):
33+
self.func = func
34+
35+
def __call__(self, variants: List[Variant], interval: Interval):
36+
return self.func(variants, interval)
37+
38+
def __or__(self, other):
39+
return VariantIntervalQuery(
40+
lambda variants, interval: (
41+
i or j for i, j in zip(self(variants, interval),
42+
other(variants, interval))))
43+
44+
def __and__(self, other):
45+
return VariantIntervalQuery(
46+
lambda variants, interval: (
47+
i and j for i, j in zip(self(variants, interval),
48+
other(variants, interval))))
49+
50+
51+
class NumberVariantQuery(VariantIntervalQuery):
52+
"""
53+
Closure for variant query. Filter variants for interval
54+
if number of variants in given limits.
55+
"""
56+
57+
def __init__(self, max_num=float('inf'), min_num=0):
58+
# TODO: sample speficity
59+
self.max_num = max_num
60+
self.min_num = min_num
61+
62+
def __call__(self, variants, interval):
63+
if self.max_num >= len(variants) >= self.min_num:
64+
return [True] * len(variants)
65+
else:
66+
return [False] * len(variants)
67+
68+
69+
_VariantIntervalType = List[Tuple[Iterable[Variant], Interval]]
70+
71+
72+
class VariantIntervalQueryable:
73+
74+
def __init__(self, vcf, variant_intervals: _VariantIntervalType,
75+
progress=False):
76+
"""
77+
Query object of variants.
78+
79+
Args:
80+
vcf: cyvcf2.VCF objects.
81+
variants: iter of (variant, interval) tuples.
82+
"""
83+
self.vcf = vcf
84+
85+
if progress:
86+
self.variant_intervals = tqdm(variant_intervals)
87+
else:
88+
self.variant_intervals = variant_intervals
89+
90+
def __iter__(self):
91+
for variants, interval in self.variant_intervals:
92+
yield from variants
93+
94+
def filter(self, query: VariantQuery):
95+
"""
96+
Filters variant given conduction.
97+
98+
Args:
99+
query: function which get a variant as input and filtered iter of
100+
variants.
101+
"""
102+
self.variant_intervals = [
103+
(filter(query, variants), Interval)
104+
for variants, interval in self.variant_intervals
105+
]
106+
return self
107+
108+
def filter_range(self, query: VariantIntervalQuery):
109+
"""
110+
Filters variant given conduction.
111+
112+
Args:
113+
query: function which get variants and an interval as input
114+
and filtered iter of variants.
115+
"""
116+
self.variant_intervals = list(self._filter_range(query))
117+
return self
118+
119+
def _filter_range(self, query: VariantIntervalQuery):
120+
for variants, interval in self.variant_intervals:
121+
variants = list(variants)
122+
yield (
123+
v
124+
for v, cond in zip(variants, query(variants, interval))
125+
if cond
126+
), interval
127+
128+
def to_vcf(self, path):
129+
"""
130+
Parse query result as vcf file.
131+
132+
Args:
133+
path: path of the file.
134+
"""
135+
from cyvcf2 import Writer
136+
writer = Writer(path, self.vcf)
137+
for v in self:
138+
writer.write_record(v.source)

kipoiseq/extractors/vcf_seq.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from pyfaidx import Sequence, complement
22
from kipoiseq.extractors import BaseExtractor, FastaStringExtractor
33
from kipoiseq.dataclasses import Variant, Interval
4+
from kipoiseq.extractors.vcf_query import VariantIntervalQueryable
5+
46

57
try:
68
from cyvcf2 import VCF
@@ -33,8 +35,74 @@ def _region(self, interval):
3335

3436
def has_variant(self, variant, sample_id):
3537
gt_type = variant.source.gt_types[self.sample_mapping[sample_id]]
38+
return self._has_variant_gt(gt_type)
39+
40+
def _has_variant_gt(self, gt_type):
3641
return gt_type != 0 and gt_type != 2
3742

43+
def query_variants(self, intervals, sample_id=None, progress=False):
44+
"""
45+
Fetch variants for given multi-intervals from vcf file
46+
for sample if sample id is given.
47+
48+
Args:
49+
intervals (List[pybedtools.Interval]): list of Interval objects
50+
sample_id (str, optional): sample id in vcf file.
51+
52+
Returns:
53+
VCFQueryable: queryable object whihc allow you to query the
54+
fetched variatns.
55+
56+
Examples:
57+
To fetch variants if only single variant present in interval.
58+
59+
>>> MultiSampleVCF(vcf_path) \
60+
.query_variants(intervals) \
61+
.filter(lambda variant: variant.qual > 10) \
62+
.filter_range(NumberVariantQuery(max_num=1))
63+
.to_vcf(output_path)
64+
"""
65+
pairs = ((self.fetch_variants(i, sample_id=sample_id), i)
66+
for i in intervals)
67+
return VariantIntervalQueryable(self, pairs, progress=progress)
68+
69+
def get_variant(self, variant):
70+
"""
71+
Returns variant from vcf file. Let you use vcf file as dict.
72+
73+
Args:
74+
vcf: cyvcf2.VCF file
75+
variant: variant object or variant id as string.
76+
77+
Returns:
78+
Variant object.
79+
80+
Examples:
81+
>>> MultiSampleVCF(vcf_path).get_variant("chr1:4:T:['C']")
82+
"""
83+
if type(variant) == str:
84+
variant = Variant.from_str(variant)
85+
86+
variants = self.fetch_variants(
87+
Interval(variant.chrom, variant.pos, variant.pos))
88+
for v in variants:
89+
if v.ref == variant.ref and v.alt == variant.alt:
90+
return v
91+
raise KeyError('Variant %s not found in vcf file.' % str(variant))
92+
93+
def get_samples(self, variant):
94+
"""
95+
Fetchs sample names which have given variants
96+
97+
Args:
98+
variant: variant object.
99+
100+
Returns:
101+
Dict[str, int]: Dict of sample which have variant and gt as value.
102+
"""
103+
return dict(filter(lambda x: self._has_variant_gt(x[1]),
104+
zip(self.samples, variant.gt_types)))
105+
38106

39107
class IntervalSeqBuilder(list):
40108
"""

tests/conftest.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
import pytest
2+
3+
vcf_file = 'tests/data/test.vcf.gz'
4+
sample_5kb_fasta_file = 'tests/data/sample.5kb.fa'

tests/extractors/test_vcf_query.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
import pytest
2+
from conftest import vcf_file
3+
from kipoiseq.dataclasses import Variant, Interval
4+
from kipoiseq.extractors.vcf_seq import MultiSampleVCF
5+
from kipoiseq.extractors.vcf_query import *
6+
7+
8+
@pytest.fixture
9+
def query_true():
10+
return VariantQuery(lambda v: True)
11+
12+
13+
@pytest.fixture
14+
def query_false():
15+
return VariantQuery(lambda v: False)
16+
17+
18+
def test_base_query__and__(query_false, query_true):
19+
assert not (query_false & query_true)(None)
20+
21+
22+
def test_base_query__or__(query_false, query_true):
23+
assert (query_false | query_true)(None)
24+
25+
26+
@pytest.fixture
27+
def variant_queryable():
28+
vcf = MultiSampleVCF(vcf_file)
29+
return VariantIntervalQueryable(vcf, [
30+
(
31+
[
32+
Variant('chr1', 12, 'A', 'T'),
33+
Variant('chr1', 18, 'A', 'C', filter='q10'),
34+
],
35+
Interval('chr1', 10, 20)
36+
),
37+
(
38+
[
39+
Variant('chr2', 120, 'AT', 'AAAT'),
40+
],
41+
Interval('chr2', 110, 200)
42+
)
43+
])
44+
45+
46+
def test_variant_queryable__iter__(variant_queryable):
47+
variants = list(variant_queryable)
48+
assert len(variants) == 3
49+
assert variants[0].ref == 'A'
50+
assert variants[0].alt == 'T'
51+
52+
53+
def test_variant_queryable_filter_1(variant_queryable):
54+
assert len(list(variant_queryable.filter(lambda v: v.alt == 'T'))) == 1
55+
56+
57+
def test_variant_queryable_filter_2(variant_queryable):
58+
assert len(list(variant_queryable.filter(lambda v: v.ref == 'A'))) == 2
59+
60+
61+
def test_variant_filter_range(variant_queryable):
62+
assert 2 == len(list(variant_queryable.filter_range(
63+
lambda variants, interval: (v.ref == 'A' for v in variants))))
64+
65+
66+
def test_VariantQueryable_filter_by_num_max(variant_queryable):
67+
assert 1 == len(list(variant_queryable.filter_range(
68+
NumberVariantQuery(max_num=1))))
69+
70+
71+
def test_VariantQueryable_filter_by_num_min(variant_queryable):
72+
assert 2 == len(list(variant_queryable.filter_range(
73+
NumberVariantQuery(min_num=2))))
74+
75+
76+
def test_VariantQueryable_filter_variant_query_2(variant_queryable):
77+
assert 2 == len(list(variant_queryable.filter(FilterVariantQuery())))
78+
79+
80+
def test_VariantQueryable_filter_variant_query_3(variant_queryable):
81+
assert 3 == len(list(variant_queryable.filter(
82+
FilterVariantQuery() | FilterVariantQuery(filter='q10'))))

tests/extractors/test_vcf_seq_extractor.py

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,26 @@
11
import pytest
2+
from conftest import vcf_file, sample_5kb_fasta_file
23
from cyvcf2 import VCF
34
from pyfaidx import Sequence
45
from kipoiseq.extractors.vcf_seq import IntervalSeqBuilder
56
from kipoiseq.dataclasses import Variant, Interval
67
from kipoiseq.extractors import *
78

8-
fasta_file = 'tests/data/sample.5kb.fa'
9-
vcf_file = 'tests/data/test.vcf.gz'
9+
fasta_file = sample_5kb_fasta_file
10+
11+
intervals = [
12+
Interval('chr1', 4, 10),
13+
Interval('chr1', 5, 30),
14+
Interval('chr1', 20, 30)
15+
]
1016

1117

1218
@pytest.fixture
1319
def multi_sample_vcf():
1420
return MultiSampleVCF(vcf_file)
1521

1622

17-
def test_multi_sample_vcf_fetch_variant(multi_sample_vcf):
23+
def test_MultiSampleVCF_fetch_variant(multi_sample_vcf):
1824
interval = Interval('chr1', 3, 5)
1925
assert len(list(multi_sample_vcf.fetch_variants(interval))) == 2
2026
assert len(list(multi_sample_vcf.fetch_variants(interval, 'NA00003'))) == 1
@@ -25,6 +31,28 @@ def test_multi_sample_vcf_fetch_variant(multi_sample_vcf):
2531
assert len(list(multi_sample_vcf.fetch_variants(interval, 'NA00003'))) == 0
2632

2733

34+
def test_MultiSampleVCF_query_variants(multi_sample_vcf):
35+
vq = multi_sample_vcf.query_variants(intervals)
36+
variants = list(vq)
37+
assert len(variants) == 5
38+
assert variants[0].pos == 4
39+
assert variants[1].pos == 5
40+
41+
42+
def test_MultiSampleVCF_get_samples(multi_sample_vcf):
43+
variants = list(multi_sample_vcf)
44+
samples = multi_sample_vcf.get_samples(variants[0])
45+
assert samples == {'NA00003': 3}
46+
47+
48+
def test_MultiSampleVCF_get_variant(multi_sample_vcf):
49+
variant = multi_sample_vcf.get_variant("chr1:4:T>C")
50+
assert variant.chrom == 'chr1'
51+
assert variant.pos == 4
52+
assert variant.ref == 'T'
53+
assert variant.alt == 'C'
54+
55+
2856
@pytest.fixture
2957
def interval_seq_builder():
3058
return IntervalSeqBuilder([

0 commit comments

Comments
 (0)