|
| 1 | +from typing import Iterable |
| 2 | +from itertools import product |
| 3 | +from kipoiseq import Interval, Variant |
| 4 | +from kipoiseq.utils import alphabets |
| 5 | +from kipoiseq.extractors import FastaStringExtractor |
| 6 | +from kipoiseq.extractors.vcf_matching import pyranges_to_intervals |
| 7 | + |
| 8 | + |
| 9 | +class VariantCombinator: |
| 10 | + |
| 11 | + def __init__(self, fasta_file: str, bed_file: str = None, |
| 12 | + variant_type='snv', alphabet='DNA'): |
| 13 | + if variant_type not in {'all', 'snv', 'in', 'del'}: |
| 14 | + raise ValueError("variant_type should be one of " |
| 15 | + "{'all', 'snv', 'in', 'del'}") |
| 16 | + |
| 17 | + self.bed_file = bed_file |
| 18 | + self.fasta = fasta_file |
| 19 | + self.fasta = FastaStringExtractor(fasta_file, force_upper=True) |
| 20 | + self.variant_type = variant_type |
| 21 | + self.alphabet = alphabets[alphabet] |
| 22 | + |
| 23 | + def combination_variants_snv(self, interval: Interval) -> Iterable[Variant]: |
| 24 | + """Returns all the possible variants in the regions. |
| 25 | +
|
| 26 | + interval: interval of variants |
| 27 | + """ |
| 28 | + seq = self.fasta.extract(interval) |
| 29 | + for pos, ref in zip(range(interval.start, interval.end), seq): |
| 30 | + pos = pos + 1 # 0 to 1 base |
| 31 | + for alt in self.alphabet: |
| 32 | + if ref != alt: |
| 33 | + yield Variant(interval.chrom, pos, ref, alt) |
| 34 | + |
| 35 | + def combination_variants_insertion(self, interval, length=2) -> Iterable[Variant]: |
| 36 | + """Returns all the possible variants in the regions. |
| 37 | +
|
| 38 | + interval: interval of variants |
| 39 | + length: insertions up to length |
| 40 | + """ |
| 41 | + if length < 2: |
| 42 | + raise ValueError('length argument should be larger than 1') |
| 43 | + |
| 44 | + seq = self.fasta.extract(interval) |
| 45 | + for pos, ref in zip(range(interval.start, interval.end), seq): |
| 46 | + pos = pos + 1 # 0 to 1 base |
| 47 | + for l in range(2, length + 1): |
| 48 | + for alt in product(self.alphabet, repeat=l): |
| 49 | + yield Variant(interval.chrom, pos, ref, ''.join(alt)) |
| 50 | + |
| 51 | + def combination_variants_deletion(self, interval, length=1) -> Iterable[Variant]: |
| 52 | + """Returns all the possible variants in the regions. |
| 53 | + interval: interval of variants |
| 54 | + length: deletions up to length |
| 55 | + """ |
| 56 | + if length < 1 and length <= interval.width: |
| 57 | + raise ValueError('length argument should be larger than 0' |
| 58 | + ' and smaller than interval witdh') |
| 59 | + |
| 60 | + seq = self.fasta.extract(interval) |
| 61 | + for i, pos in enumerate(range(interval.start, interval.end)): |
| 62 | + pos = pos + 1 # 0 to 1 base |
| 63 | + for j in range(1, length + 1): |
| 64 | + if i + j <= len(seq): |
| 65 | + yield Variant(interval.chrom, pos, seq[i:i + j], '') |
| 66 | + |
| 67 | + def combination_variants(self, interval, variant_type='snv', |
| 68 | + in_length=2, del_length=2) -> Iterable[Variant]: |
| 69 | + if variant_type in {'snv', 'all'}: |
| 70 | + yield from self.combination_variants_snv(interval) |
| 71 | + if variant_type in {'indel', 'in', 'all'}: |
| 72 | + yield from self.combination_variants_insertion( |
| 73 | + interval, length=in_length) |
| 74 | + if variant_type in {'indel', 'del', 'all'}: |
| 75 | + yield from self.combination_variants_deletion( |
| 76 | + interval, length=del_length) |
| 77 | + |
| 78 | + def __iter__(self) -> Iterable[Variant]: |
| 79 | + import pyranges as pr |
| 80 | + |
| 81 | + gr = pr.read_bed(self.bed_file) |
| 82 | + gr = gr.merge(strand=False).sort() |
| 83 | + |
| 84 | + for interval in pyranges_to_intervals(gr): |
| 85 | + yield from self.combination_variants(interval, self.variant_type) |
| 86 | + |
| 87 | + def to_vcf(self, path): |
| 88 | + from cyvcf2 import Writer |
| 89 | + header = '''##fileformat=VCFv4.2 |
| 90 | +#CHROM POS ID REF ALT QUAL FILTER INFO |
| 91 | +''' |
| 92 | + writer = Writer.from_string(path, header) |
| 93 | + |
| 94 | + for v in self: |
| 95 | + variant = writer.variant_from_string('\t'.join([ |
| 96 | + v.chrom, str(v.pos), '.', v.ref, v.alt, '.', '.', '.' |
| 97 | + ])) |
| 98 | + writer.write_record(variant) |
0 commit comments