Skip to content

Commit fa49b32

Browse files
authored
Merge pull request #23 from MuhammedHasan/master
Add vcf_seq extractor
2 parents f92f3aa + ee89c62 commit fa49b32

File tree

7 files changed

+431
-1
lines changed

7 files changed

+431
-1
lines changed

kipoiseq/extractors/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .base import *
2+
from .vcf_seq import *
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ def extract(self, interval):
4747
rc = self.use_strand and interval.strand == "-"
4848

4949
# pyfaidx wants a 1-based interval
50-
seq = str(self.fasta.get_seq(interval.chrom, interval.start + 1, interval.stop, rc=rc).seq)
50+
seq = str(self.fasta.get_seq(interval.chrom,
51+
interval.start + 1, interval.stop, rc=rc).seq)
5152

5253
# optionally, force upper-case letters
5354
if self.force_upper:

kipoiseq/extractors/vcf_seq.py

Lines changed: 263 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,263 @@
1+
from pybedtools import Interval
2+
from pyfaidx import Sequence, complement
3+
from kipoiseq.extractors import BaseExtractor, FastaStringExtractor
4+
try:
5+
from cyvcf2 import VCF
6+
except ImportError:
7+
VCF = object
8+
9+
__all__ = [
10+
'VariantSeqExtractor',
11+
'MultiSampleVCF',
12+
'SingleVariantVCFSeqExtractor',
13+
'SingleSeqVCFSeqExtractor'
14+
]
15+
16+
17+
class MultiSampleVCF(VCF):
    """`cyvcf2.VCF` reader with helpers for per-sample variant queries."""

    def __init__(self, *args, **kwargs):
        """Open the VCF and index its samples by name.

        All positional and keyword arguments are forwarded unchanged to the
        parent `VCF` constructor.
        """
        # Re-import locally so that a missing `cyvcf2` installation raises an
        # ImportError here, instead of failing obscurely on the module-level
        # `VCF = object` fallback.
        from cyvcf2 import VCF  # noqa: F401
        super(MultiSampleVCF, self).__init__(*args, **kwargs)
        # sample name -> index into the per-sample genotype arrays
        self.sample_mapping = {
            sample: idx for idx, sample in enumerate(self.samples)}

    def fetch_variants(self, interval, sample_id=None):
        """Yield the variants returned by a region query for `interval`.

        Args:
          interval: object with `chrom`, `start` and `end` attributes
            (0-based, half-open, e.g. `pybedtools.Interval`).
          sample_id: if given, only variants for which this sample carries a
            called non-reference genotype are yielded. If None, all variants
            are yielded.

        Raises:
          KeyError: if `sample_id` is not one of the samples of the VCF.
        """
        for variant in self(self._region(interval)):
            if sample_id is None or self._has_variant(variant, sample_id):
                yield variant

    def _region(self, interval):
        """Format `interval` as a `chrom:start-end` region string."""
        # `interval` is 0-based and half-open whereas region strings are
        # 1-based and inclusive: only the start needs to be shifted.
        return '%s:%d-%d' % (interval.chrom, interval.start + 1, interval.end)

    def _has_variant(self, variant, sample_id):
        """Tell whether `sample_id` has a called non-reference genotype."""
        gt_type = variant.gt_types[self.sample_mapping[sample_id]]
        # cyvcf2 exposes the integer codes used in `gt_types` on the reader
        # (the UNKNOWN code depends on the `gts012` option). The fallbacks
        # are cyvcf2's default codes.
        hom_ref = getattr(self, 'HOM_REF', 0)
        unknown = getattr(self, 'UNKNOWN', 2)
        # A missing genotype call is no evidence for the variant.
        return gt_type != hom_ref and gt_type != unknown
34+
35+
36+
class IntervalSeqBuilder(list):
    """
    String builder for `pyfaidx.Sequence` and `Interval` objects.

    The list holds a mix of `Interval` placeholders (stretches that still
    have to be looked up in a fetched sequence) and `Sequence` objects.
    `restore` replaces the placeholders and `concat` joins everything
    into a single string.
    """

    def restore(self, sequence):
        """
        Replace, in place, every `Interval` of the builder with the
        matching slice of `sequence`.

        Args:
          sequence: `pyfaidx.Sequence` from which the intervals are sliced.
            Its `start` is subtracted from each interval start to get
            the slice offset.
        """
        for i, item in enumerate(self):
            if not isinstance(item, Interval):
                continue

            # interval.end can be smaller than interval.start; such an
            # interval is restored as an empty slice.
            length = max(0, item.end - item.start)
            start = item.start - sequence.start
            self[i] = sequence[start:start + length]

    def _concat(self):
        """
        Yield the string of each element.

        Raises:
          TypeError: if an element is not a `Sequence`, e.g. an `Interval`
            which was not restored yet.
        """
        for item in self:
            if not isinstance(item, Sequence):
                raise TypeError('Intervals should be restored with `restore`'
                                ' method before calling concat method!')
            yield item.seq

    def concat(self):
        """
        Build the string from sequence objects.

        Returns:
          str: the final sequence.

        Raises:
          TypeError: if the builder still contains non-`Sequence` elements.
        """
        return ''.join(self._concat())
71+
72+
73+
class VariantSeqExtractor(BaseExtractor):
    """Extract the sequence of an interval with variants applied to it."""

    def __init__(self, fasta_file):
        """
        Args:
          fasta_file: path to the fasta file (can be gzipped)
        """
        self.fasta = FastaStringExtractor(fasta_file, use_strand=True)

    def extract(self, interval, variants, anchor, fixed_len=True):
        """
        Apply `variants` to the reference sequence of `interval`.

        Args:
          interval: pybedtools.Interval Region of interest from
            which to query the sequence. 0-based
          variants List[cyvcf2.Variant]: variants overlapping the `interval`.
            can also be indels. 1-based. Any iterable is accepted.
          anchor: absolute position (0-based), in the same coordinate
            system as `interval`. It is clamped to
            `[interval.start, interval.end]`. E.g. for an interval of
            `chr1:10-20` the anchor of 10 denotes the point chr1:10.
            Similarly, `anchor=15` means the anchor point is right in the
            middle of the sequence e.g. first half of the sequence (5nt)
            will be before the anchor and the second half (5nt) will be
            after the anchor.
          fixed_len: if True, the return sequence will have the same length
            as the `interval` (e.g. `interval.end - interval.start`)

        Returns:
          A single sequence (`str`) with all the variants applied. It is
          reverse-complemented if `interval.strand` is '-'.
        """
        # Preprocessing: clamp the anchor into the interval
        anchor = max(min(anchor, interval.end), interval.start)
        variant_pairs = self._variant_to_sequence(variants)

        # 1. Split variants overlapping with anchor
        variant_pairs = list(self._split_overlapping(variant_pairs, anchor))

        # 2. split the variants into upstream and downstream
        #   and sort the variants in each interval.
        #   NOTE: in this class "upstream" names the side starting at the
        #   anchor and growing towards `interval.end`, "downstream" the
        #   side before the anchor. Both are ordered from the anchor
        #   outwards.
        upstream_variants = sorted(
            (pair for pair in variant_pairs if pair[0].start >= anchor),
            key=lambda pair: pair[0].start)

        downstream_variants = sorted(
            (pair for pair in variant_pairs if pair[0].start < anchor),
            key=lambda pair: pair[0].start, reverse=True)

        # 3. Extend start and end position for deletions
        if fixed_len:
            istart, iend = self._updated_interval(
                interval, upstream_variants, downstream_variants)
        else:
            istart, iend = interval.start, interval.end

        # 4. Iterate from the anchor point outwards. At each
        # register the interval from which to take the reference sequence
        # as well as the interval for the variant
        down_sb = self._downstream_builder(
            downstream_variants, interval, anchor, istart)

        up_sb = self._upstream_builder(
            upstream_variants, interval, anchor, iend)

        # 5. fetch the sequence and restore intervals in builder
        seq = self._fetch(interval, istart, iend)
        up_sb.restore(seq)
        down_sb.restore(seq)

        # 6. Concatenate the sequences of the upstream and downstream
        # splits. Cut to fix the length.
        down_str = down_sb.concat()
        up_str = up_sb.concat()

        if fixed_len:
            down_str, up_str = self._cut_to_fix_len(
                down_str, up_str, interval, anchor)

        seq = down_str + up_str

        if interval.strand == '-':
            # reverse complement
            seq = complement(seq)[::-1]

        return seq

    def _variant_to_sequence(self, variants):
        """
        Convert `cyvcf2.Variant` objects to `pyfaidx.Sequence` objects
        for reference and variants.

        Yields:
          `(ref, alt)` pairs of `Sequence` objects, both starting at the
          variant position. Variants without any alternative allele
          are skipped since there is nothing to apply.
        """
        # NOTE(review): `v.POS` is used directly as the start of the
        # sequences, which are later compared with the 0-based interval
        # coordinates; confirm the intended coordinate convention.
        for v in variants:
            if not v.ALT:
                continue
            ref = Sequence(name=v.CHROM, seq=v.REF,
                           start=v.POS, end=v.POS + len(v.REF))
            # TODO: only the first alternative allele is considered.
            alt = Sequence(name=v.CHROM, seq=v.ALT[0],
                           start=v.POS, end=v.POS + len(v.ALT[0]))
            yield ref, alt

    def _split_overlapping(self, variant_pairs, anchor):
        """
        Split the variants hitting the anchor into two

        A pair whose `ref` or `alt` strictly contains the anchor is cut
        at the anchor, every other pair is yielded unchanged.
        """
        for ref, alt in variant_pairs:
            if ref.start < anchor < ref.end or alt.start < anchor < alt.end:
                mid = anchor - ref.start
                yield ref[:mid], alt[:mid]
                yield ref[mid:], alt[mid:]
            else:
                yield ref, alt

    def _updated_interval(self, interval, up_variants, down_variants):
        """
        Widen the interval by the number of bases removed by deletions.

        Returns:
          `(istart, iend)`: `iend` is moved further out for every variant
          of `up_variants` with `alt` shorter than `ref`, `istart` is moved
          further out for every such variant of `down_variants`.
        """
        istart = interval.start
        iend = interval.end

        for ref, alt in up_variants:
            diff_len = len(alt) - len(ref)
            if diff_len < 0:
                iend -= diff_len

        for ref, alt in down_variants:
            diff_len = len(alt) - len(ref)
            if diff_len < 0:
                istart += diff_len

        return istart, iend

    def _downstream_builder(self, down_variants, interval, anchor, istart):
        """
        Build the part of the sequence located before the anchor.

        Walks from the anchor towards `istart`, alternating reference
        intervals and `alt` sequences, then reverses the builder so that
        it reads in increasing coordinate order.
        """
        down_sb = IntervalSeqBuilder()

        prev = anchor
        for ref, alt in down_variants:
            # variants are sorted from the anchor outwards: stop at the
            # first one lying entirely before `istart`
            if ref.end <= istart:
                break
            down_sb.append(Interval(interval.chrom, ref.end, prev))
            down_sb.append(alt)
            prev = ref.start
        down_sb.append(Interval(interval.chrom, istart, prev))
        down_sb.reverse()

        return down_sb

    def _upstream_builder(self, up_variants, interval, anchor, iend):
        """
        Build the part of the sequence starting at the anchor.

        Walks from the anchor towards `iend`, alternating reference
        intervals and `alt` sequences.
        """
        up_sb = IntervalSeqBuilder()

        prev = anchor
        for ref, alt in up_variants:
            # variants are sorted from the anchor outwards: stop at the
            # first one starting after `iend`
            if ref.start > iend:
                break
            up_sb.append(Interval(interval.chrom, prev, ref.start))
            up_sb.append(alt)
            prev = ref.end
        up_sb.append(Interval(interval.chrom, prev, iend))

        return up_sb

    def _fetch(self, interval, istart, iend):
        """
        Fetch the reference sequence of `[istart, iend)` on the chromosome
        of `interval` and wrap it into a `Sequence` keeping its coordinates.
        """
        seq = self.fasta.extract(Interval(interval.chrom, istart, iend))
        seq = Sequence(name=interval.chrom, seq=seq, start=istart, end=iend)
        return seq

    def _cut_to_fix_len(self, down_str, up_str, interval, anchor):
        """
        Trim both halves so that together they have the interval length.

        The part before the anchor keeps its last `anchor - interval.start`
        characters, the part after the anchor keeps its first
        `interval.end - anchor` characters.
        """
        down_len = anchor - interval.start
        up_len = interval.end - anchor
        # `s[-0:]` would return the whole string, hence the explicit check
        down_str = down_str[-down_len:] if down_len else ''
        up_str = up_str[:up_len] if up_len else ''
        return down_str, up_str
238+
239+
240+
class BaseVCFSeqExtractor(BaseExtractor):
    """Common setup for extractors combining a fasta file and a VCF file."""

    def __init__(self, fasta_file, vcf_file):
        """
        Args:
          fasta_file: path handed to `VariantSeqExtractor`.
          vcf_file: path handed to `MultiSampleVCF`.
        """
        # keep the paths around next to the opened readers
        self.fasta_file, self.vcf_file = fasta_file, vcf_file
        self.variant_extractor = VariantSeqExtractor(fasta_file)
        self.vcf = MultiSampleVCF(vcf_file)
246+
247+
248+
class SingleVariantVCFSeqExtractor(BaseVCFSeqExtractor):
    """Produce one sequence per variant, each with a single variant applied."""

    def extract(self, interval, anchor=None, sample_id=None, fixed_len=True):
        """
        Lazily yield the sequence of `interval` once for every variant
        fetched from the VCF, with only that variant applied.

        Args:
          interval: region of interest, forwarded to the VCF query and
            to the variant extractor.
          anchor: forwarded to `VariantSeqExtractor.extract`.
            NOTE(review): the default `None` is passed through as-is and
            the variant extractor clamps the anchor with `min`/`max`
            against integer coordinates; confirm that callers always
            provide an explicit anchor.
          sample_id: optional sample used to filter the variants.
          fixed_len: forwarded to `VariantSeqExtractor.extract`.
        """
        for single in self.vcf.fetch_variants(interval, sample_id):
            yield self.variant_extractor.extract(
                interval, variants=[single],
                anchor=anchor, fixed_len=fixed_len)
256+
257+
258+
class SingleSeqVCFSeqExtractor(BaseVCFSeqExtractor):
    """Produce a single sequence with all the fetched variants applied."""

    def extract(self, interval, anchor=None, sample_id=None, fixed_len=True):
        """
        Return the sequence of `interval` with every variant fetched from
        the VCF applied at once.

        Args:
          interval: region of interest, forwarded to the VCF query and
            to the variant extractor.
          anchor: forwarded to `VariantSeqExtractor.extract`.
            NOTE(review): the default `None` is passed through as-is and
            the variant extractor clamps the anchor with `min`/`max`
            against integer coordinates; confirm that callers always
            provide an explicit anchor.
          sample_id: optional sample used to filter the variants.
          fixed_len: forwarded to `VariantSeqExtractor.extract`.
        """
        fetched = self.vcf.fetch_variants(interval, sample_id)
        return self.variant_extractor.extract(
            interval, variants=fetched, anchor=anchor, fixed_len=fixed_len)

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
"coveralls",
3333
"scikit-learn",
3434
"cython",
35+
"cyvcf2",
3536
# "genomelake",
3637
"keras",
3738
"tensorflow",

tests/data/test.vcf.gz

568 Bytes
Binary file not shown.

tests/data/test.vcf.gz.tbi

106 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)