Skip to content

Commit e234baf

Browse files
MuhammedHasanM. Hasan Celik
andauthored
variant combinator (#94)
* variant combinator * variant combinator vcf * bug fix upper case * sample from vcf * format fields bug fix * sort_intervals * update on testcase Co-authored-by: M. Hasan Celik <[email protected]>
1 parent 8893483 commit e234baf

File tree

8 files changed

+263
-11
lines changed

8 files changed

+263
-11
lines changed

kipoiseq/extractors/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,4 @@
66
from .vcf_matching import *
77
from .multi_interval import *
88
from .protein import *
9+
from .variant_combinations import *
Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
from typing import Iterable
2+
from itertools import product
3+
from kipoiseq import Interval, Variant
4+
from kipoiseq.utils import alphabets
5+
from kipoiseq.extractors import FastaStringExtractor
6+
from kipoiseq.extractors.vcf_matching import pyranges_to_intervals
7+
8+
9+
class VariantCombinator:
10+
11+
def __init__(self, fasta_file: str, bed_file: str = None,
12+
variant_type='snv', alphabet='DNA'):
13+
if variant_type not in {'all', 'snv', 'in', 'del'}:
14+
raise ValueError("variant_type should be one of "
15+
"{'all', 'snv', 'in', 'del'}")
16+
17+
self.bed_file = bed_file
18+
self.fasta = fasta_file
19+
self.fasta = FastaStringExtractor(fasta_file, force_upper=True)
20+
self.variant_type = variant_type
21+
self.alphabet = alphabets[alphabet]
22+
23+
def combination_variants_snv(self, interval: Interval) -> Iterable[Variant]:
24+
"""Returns all the possible variants in the regions.
25+
26+
interval: interval of variants
27+
"""
28+
seq = self.fasta.extract(interval)
29+
for pos, ref in zip(range(interval.start, interval.end), seq):
30+
pos = pos + 1 # 0 to 1 base
31+
for alt in self.alphabet:
32+
if ref != alt:
33+
yield Variant(interval.chrom, pos, ref, alt)
34+
35+
def combination_variants_insertion(self, interval, length=2) -> Iterable[Variant]:
36+
"""Returns all the possible variants in the regions.
37+
38+
interval: interval of variants
39+
length: insertions up to length
40+
"""
41+
if length < 2:
42+
raise ValueError('length argument should be larger than 1')
43+
44+
seq = self.fasta.extract(interval)
45+
for pos, ref in zip(range(interval.start, interval.end), seq):
46+
pos = pos + 1 # 0 to 1 base
47+
for l in range(2, length + 1):
48+
for alt in product(self.alphabet, repeat=l):
49+
yield Variant(interval.chrom, pos, ref, ''.join(alt))
50+
51+
def combination_variants_deletion(self, interval, length=1) -> Iterable[Variant]:
52+
"""Returns all the possible variants in the regions.
53+
interval: interval of variants
54+
length: deletions up to length
55+
"""
56+
if length < 1 and length <= interval.width:
57+
raise ValueError('length argument should be larger than 0'
58+
' and smaller than interval witdh')
59+
60+
seq = self.fasta.extract(interval)
61+
for i, pos in enumerate(range(interval.start, interval.end)):
62+
pos = pos + 1 # 0 to 1 base
63+
for j in range(1, length + 1):
64+
if i + j <= len(seq):
65+
yield Variant(interval.chrom, pos, seq[i:i + j], '')
66+
67+
def combination_variants(self, interval, variant_type='snv',
68+
in_length=2, del_length=2) -> Iterable[Variant]:
69+
if variant_type in {'snv', 'all'}:
70+
yield from self.combination_variants_snv(interval)
71+
if variant_type in {'indel', 'in', 'all'}:
72+
yield from self.combination_variants_insertion(
73+
interval, length=in_length)
74+
if variant_type in {'indel', 'del', 'all'}:
75+
yield from self.combination_variants_deletion(
76+
interval, length=del_length)
77+
78+
def __iter__(self) -> Iterable[Variant]:
79+
import pyranges as pr
80+
81+
gr = pr.read_bed(self.bed_file)
82+
gr = gr.merge(strand=False).sort()
83+
84+
for interval in pyranges_to_intervals(gr):
85+
yield from self.combination_variants(interval, self.variant_type)
86+
87+
def to_vcf(self, path):
88+
from cyvcf2 import Writer
89+
header = '''##fileformat=VCFv4.2
90+
#CHROM POS ID REF ALT QUAL FILTER INFO
91+
'''
92+
writer = Writer.from_string(path, header)
93+
94+
for v in self:
95+
variant = writer.variant_from_string('\t'.join([
96+
v.chrom, str(v.pos), '.', v.ref, v.alt, '.', '.', '.'
97+
]))
98+
writer.write_record(variant)

kipoiseq/extractors/vcf_query.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import csv
12
import abc
23
from itertools import islice
34
from typing import Tuple, Iterable, List
@@ -241,3 +242,40 @@ def to_vcf(self, path, remove_samples=False, clean_info=False):
241242
variant = writer.variant_from_string('\t'.join(variant))
242243

243244
writer.write_record(variant)
245+
246+
def to_sample_csv(self, path, format_fields=None):
247+
"""
248+
Extract samples and FORMAT from vcf then save as csv file.
249+
"""
250+
format_fields = format_fields or list()
251+
writer = None
252+
253+
with open(path, 'w') as f:
254+
255+
for variant in self:
256+
variant_fields = str(variant.source).strip().split('\t')
257+
258+
if writer is None:
259+
# FORMAT field
260+
fieldnames = ['variant', 'sample',
261+
'genotype'] + format_fields
262+
format_fields = set(format_fields)
263+
writer = csv.DictWriter(f, fieldnames=fieldnames)
264+
writer.writeheader()
265+
samples = self.vcf.samples
266+
267+
values = dict(zip(samples, map(
268+
lambda x: x.split(':'), variant_fields[9:])))
269+
fields = variant_fields[8].split(':')
270+
271+
for sample, gt in self.vcf.get_samples(variant).items():
272+
row = dict()
273+
row['variant'] = str(variant)
274+
row['sample'] = sample
275+
row['genotype'] = gt
276+
277+
for k, v in zip(fields, values[sample]):
278+
if k in format_fields:
279+
row[k] = v
280+
281+
writer.writerow(row)

kipoiseq/transforms/functional.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def tokenize(seq, alphabet=DNA, neutral_alphabet=["N"]):
8585
neutral_alphabet = [neutral_alphabet]
8686

8787
nchar = len(alphabet[0])
88-
for l in alphabet + neutral_alphabet:
88+
for l in (*alphabet, *neutral_alphabet):
8989
assert len(l) == nchar
9090
assert len(seq) % nchar == 0 # since we are using striding
9191

kipoiseq/utils.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,11 @@
44
import numpy as np
55
from six import string_types
66

7-
# alphabets:
8-
from kipoiseq import Variant
97

10-
DNA = ["A", "C", "G", "T"]
11-
RNA = ["A", "C", "G", "U"]
12-
AMINO_ACIDS = ["A", "R", "N", "D", "B", "C", "E", "Q", "Z", "G", "H",
13-
"I", "L", "K", "M", "F", "P", "S", "T", "W", "Y", "V"]
8+
DNA = ("A", "C", "G", "T")
9+
RNA = ("A", "C", "G", "U")
10+
AMINO_ACIDS = ("A", "R", "N", "D", "B", "C", "E", "Q", "Z", "G", "H",
11+
"I", "L", "K", "M", "F", "P", "S", "T", "W", "Y", "V")
1412

1513
alphabets = {"DNA": DNA,
1614
"RNA": RNA,
@@ -38,7 +36,8 @@ def parse_dtype(dtype):
3836
try:
3937
return eval(dtype)
4038
except Exception as e:
41-
raise ValueError("Unable to parse dtype: {}. \nException: {}".format(dtype, e))
39+
raise ValueError(
40+
"Unable to parse dtype: {}. \nException: {}".format(dtype, e))
4241
else:
4342
return dtype
4443

tests/dataloaders/test_sequence.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ def fasta_file():
1515
def intervals_file():
1616
return "tests/data/sample_intervals.bed"
1717

18+
1819
@pytest.fixture
1920
def intervals_file_strand():
2021
return "tests/data/sample_interval_strand.bed"
@@ -25,8 +26,9 @@ def intervals_file_strand():
2526

2627
def test_min_props():
2728
# minimal set of properties that need to be specified on the object
28-
min_set_props = ["output_schema", "type", "defined_as", "info", "args", "dependencies", "postprocessing",
29-
"source", "source_dir"]
29+
min_set_props = ["output_schema", "type", "defined_as", "info", "args",
30+
"dependencies", "source", "source_dir"]
31+
# TODO: "postprocessing" is this part of min_set_props?
3032

3133
for Dl in [StringSeqIntervalDl, SeqIntervalDl]:
3234
props = dir(Dl)
@@ -56,11 +58,14 @@ def test_fasta_based_dataset(intervals_file, fasta_file):
5658
vals = dl.load_all()
5759
assert vals['inputs'][0] == 'GT'
5860

61+
5962
def test_use_strand(intervals_file_strand, fasta_file):
60-
dl = StringSeqIntervalDl(intervals_file_strand, fasta_file, use_strand=True)
63+
dl = StringSeqIntervalDl(intervals_file_strand,
64+
fasta_file, use_strand=True)
6165
vals = dl.load_all()
6266
assert vals['inputs'][0] == 'AC'
6367

68+
6469
def test_seq_dataset(intervals_file, fasta_file):
6570
dl = SeqIntervalDl(intervals_file, fasta_file)
6671
ret_val = dl[0]
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import pytest
2+
from conftest import example_intervals_bed, sample_5kb_fasta_file
3+
import pyranges as pr
4+
from kipoiseq import Interval
5+
from kipoiseq.extractors import VariantCombinator, MultiSampleVCF
6+
7+
8+
@pytest.fixture
9+
def variant_combinator():
10+
return VariantCombinator(sample_5kb_fasta_file, example_intervals_bed)
11+
12+
13+
def test_VariantCombinator_combination_variants(variant_combinator):
14+
interval = Interval('chr1', 20, 30)
15+
variants = list(variant_combinator.combination_variants(interval, 'snv'))
16+
assert len(variants) == 30
17+
18+
interval = Interval('chr1', 20, 22)
19+
variants = list(variant_combinator.combination_variants(interval, 'snv'))
20+
assert variants[0].chrom == 'chr1'
21+
assert variants[0].ref == 'A'
22+
assert variants[0].alt == 'C'
23+
assert variants[1].alt == 'G'
24+
assert variants[2].alt == 'T'
25+
26+
assert variants[3].ref == 'C'
27+
assert variants[3].alt == 'A'
28+
assert variants[4].alt == 'G'
29+
assert variants[5].alt == 'T'
30+
31+
interval = Interval('chr1', 20, 22)
32+
variants = list(variant_combinator.combination_variants(interval, 'in'))
33+
len(variants) == 32
34+
assert variants[0].ref == 'A'
35+
assert variants[0].alt == 'AA'
36+
assert variants[15].alt == 'TT'
37+
38+
assert variants[16].ref == 'C'
39+
assert variants[16].alt == 'AA'
40+
assert variants[31].alt == 'TT'
41+
42+
interval = Interval('chr1', 20, 22)
43+
variants = list(variant_combinator.combination_variants(
44+
interval, 'del', del_length=2))
45+
assert len(variants) == 3
46+
assert variants[0].ref == 'A'
47+
assert variants[0].alt == ''
48+
assert variants[1].ref == 'AC'
49+
assert variants[1].alt == ''
50+
assert variants[2].ref == 'C'
51+
assert variants[2].alt == ''
52+
53+
variants = list(variant_combinator.combination_variants(
54+
interval, 'all', in_length=2, del_length=2))
55+
assert len(variants) == 6 + 32 + 3
56+
57+
58+
def test_VariantCombinator_iter(variant_combinator):
59+
variants = list(variant_combinator)
60+
df = pr.read_bed(example_intervals_bed).merge(strand=False).df
61+
num_snv = (df['End'] - df['Start']).sum() * 3
62+
assert len(variants) == num_snv
63+
assert len(variants) == len(set(variants))
64+
65+
66+
def test_VariantCombinator_to_vcf(tmpdir, variant_combinator):
67+
output_vcf_file = str(tmpdir / 'output.vcf')
68+
variant_combinator.to_vcf(output_vcf_file)
69+
70+
vcf = MultiSampleVCF(output_vcf_file)
71+
72+
df = pr.read_bed(example_intervals_bed).merge(strand=False).df
73+
num_snv = (df['End'] - df['Start']).sum() * 3
74+
assert len(list(vcf)) == num_snv

tests/extractors/test_vcf_query.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import pytest
22
from conftest import vcf_file
3+
import pandas as pd
34
from kipoiseq.dataclasses import Variant, Interval
45
from kipoiseq.extractors.vcf_seq import MultiSampleVCF
56
from kipoiseq.extractors.vcf_query import *
@@ -137,3 +138,39 @@ def test_VariantQueryable_to_vcf(tmp_path):
137138

138139
vcf = MultiSampleVCF(path)
139140
assert len(vcf.samples) == 0
141+
142+
143+
def test_VariantQueryable_to_sample_csv(tmp_path):
144+
vcf = MultiSampleVCF(vcf_file)
145+
146+
variant_queryable = vcf.query_all()
147+
148+
path = str(tmp_path / 'sample.csv')
149+
variant_queryable.to_sample_csv(path)
150+
151+
df = pd.read_csv(path)
152+
df_expected = pd.DataFrame({
153+
'variant': ['chr1:4:T>C', 'chr1:25:AACG>GA'],
154+
'sample': ['NA00003', 'NA00002'],
155+
'genotype': [3, 3]
156+
})
157+
pd.testing.assert_frame_equal(df, df_expected)
158+
159+
160+
def test_VariantQueryable_to_sample_csv_fields(tmp_path):
161+
vcf = MultiSampleVCF(vcf_file)
162+
163+
variant_queryable = vcf.query_all()
164+
165+
path = str(tmp_path / 'sample.csv')
166+
variant_queryable.to_sample_csv(path, ['GT', 'HQ'])
167+
168+
df = pd.read_csv(path)
169+
df_expected = pd.DataFrame({
170+
'variant': ['chr1:4:T>C', 'chr1:25:AACG>GA'],
171+
'sample': ['NA00003', 'NA00002'],
172+
'genotype': [3, 3],
173+
'GT': ['1/1', '1/1'],
174+
'HQ': ['51,51', '10,10']
175+
})
176+
pd.testing.assert_frame_equal(df, df_expected)

0 commit comments

Comments
 (0)