|
| 1 | +from pybedtools import Interval |
| 2 | +from pyfaidx import Sequence, complement |
| 3 | +from kipoiseq.extractors import BaseExtractor, FastaStringExtractor |
| 4 | +try: |
| 5 | + from cyvcf2 import VCF |
| 6 | +except ImportError: |
| 7 | + VCF = object |
| 8 | + |
| 9 | +__all__ = [ |
| 10 | + 'VariantSeqExtractor', |
| 11 | + 'MultiSampleVCF', |
| 12 | + 'SingleVariantVCFSeqExtractor', |
| 13 | + 'SingleSeqVCFSeqExtractor' |
| 14 | +] |
| 15 | + |
| 16 | + |
| 17 | +class MultiSampleVCF(VCF): |
| 18 | + |
| 19 | + def __init__(self, *args, **kwargs): |
| 20 | + from cyvcf2 import VCF |
| 21 | + super(MultiSampleVCF, self).__init__(*args, **kwargs) |
| 22 | + self.sample_mapping = dict(zip(self.samples, range(len(self.samples)))) |
| 23 | + |
| 24 | + def fetch_variants(self, interval, sample_id=None): |
| 25 | + for v in self(self._region(interval)): |
| 26 | + if sample_id is None or self._has_variant(v, sample_id): |
| 27 | + yield v |
| 28 | + |
| 29 | + def _region(self, interval): |
| 30 | + return '%s:%d-%d' % (interval.chrom, interval.start, interval.end) |
| 31 | + |
| 32 | + def _has_variant(self, variant, sample_id): |
| 33 | + return variant.gt_types[self.sample_mapping[sample_id]] != 0 |
| 34 | + |
| 35 | + |
| 36 | +class IntervalSeqBuilder(list): |
| 37 | + """ |
| 38 | + String builder for `pyfaidx.Sequence` and `Interval` objects. |
| 39 | + """ |
| 40 | + |
| 41 | + def restore(self, sequence): |
| 42 | + """ |
| 43 | + Args: |
| 44 | + seq: `pyfaidx.Sequence` which convert all interval inside |
| 45 | + to `Seqeunce` objects. |
| 46 | + """ |
| 47 | + for i, interval in enumerate(self): |
| 48 | + # interval.end can be bigger than interval.start |
| 49 | + interval_len = max(0, interval.end - interval.start) |
| 50 | + |
| 51 | + if type(self[i]) == Interval: |
| 52 | + start = interval.start - sequence.start |
| 53 | + end = start + interval_len |
| 54 | + self[i] = sequence[start: end] |
| 55 | + |
| 56 | + def _concat(self): |
| 57 | + for sequence in self: |
| 58 | + if type(sequence) != Sequence: |
| 59 | + raise TypeError('Intervals should be restored with `restore`' |
| 60 | + ' method before calling concat method!') |
| 61 | + yield sequence.seq |
| 62 | + |
| 63 | + def concat(self): |
| 64 | + """ |
| 65 | + Build the string from sequence objects. |
| 66 | +
|
| 67 | + Returns: |
| 68 | + str: the final sequence. |
| 69 | + """ |
| 70 | + return ''.join(self._concat()) |
| 71 | + |
| 72 | + |
| 73 | +class VariantSeqExtractor(BaseExtractor): |
| 74 | + |
| 75 | + def __init__(self, fasta_file): |
| 76 | + """ |
| 77 | + Args: |
| 78 | + fasta_file: path to the fasta file (can be gzipped) |
| 79 | + """ |
| 80 | + self.fasta = FastaStringExtractor(fasta_file, use_strand=True) |
| 81 | + |
| 82 | + def extract(self, interval, variants, anchor, fixed_len=True): |
| 83 | + """ |
| 84 | +
|
| 85 | + Args: |
| 86 | + interval: pybedtools.Interval Region of interest from |
| 87 | + which to query the sequence. 0-based |
| 88 | + variants List[cyvcf2.Variant]: variants overlapping the `interval`. |
| 89 | + can also be indels. 1-based |
| 90 | + anchor: position w.r.t. the interval start. (0-based). E.g. |
| 91 | + for an interval of `chr1:10-20` the anchor of 0 denotes |
| 92 | + the point chr1:10 in the 0-based coordinate system. Similarly, |
| 93 | + `anchor=5` means the anchor point is right in the middle |
| 94 | + of the sequence e.g. first half of the sequence (5nt) will be |
| 95 | + upstream of the anchor and the second half (5nt) will be |
| 96 | + downstream of the anchor. |
| 97 | + fixed_len: if True, the return sequence will have the same length |
| 98 | + as the `interval` (e.g. `interval.end - interval.start`) |
| 99 | +
|
| 100 | + Returns: |
| 101 | + A single sequence (`str`) with all the variants applied. |
| 102 | + """ |
| 103 | + # Preprocessing |
| 104 | + anchor = max(min(anchor, interval.end), interval.start) |
| 105 | + variant_pairs = self._variant_to_sequence(variants) |
| 106 | + |
| 107 | + # 1. Split variants overlapping with anchor |
| 108 | + variant_pairs = list(self._split_overlapping(variant_pairs, anchor)) |
| 109 | + |
| 110 | + # 2. split the variants into upstream and downstream |
| 111 | + # and sort the variants in each interval |
| 112 | + upstream_variants = sorted( |
| 113 | + filter(lambda x: x[0].start >= anchor, variant_pairs), |
| 114 | + key=lambda x: x[0].start) |
| 115 | + |
| 116 | + downstream_variants = sorted( |
| 117 | + filter(lambda x: x[0].start < anchor, variant_pairs), |
| 118 | + key=lambda x: x[0].start, reverse=True) |
| 119 | + |
| 120 | + # 3. Extend start and end position for deletions |
| 121 | + if fixed_len: |
| 122 | + istart, iend = self._updated_interval( |
| 123 | + interval, upstream_variants, downstream_variants) |
| 124 | + else: |
| 125 | + istart, iend = interval.start, interval.end |
| 126 | + |
| 127 | + # 4. Iterate from the anchor point outwards. At each |
| 128 | + # register the interval from which to take the reference sequence |
| 129 | + # as well as the interval for the variant |
| 130 | + down_sb = self._downstream_builder( |
| 131 | + downstream_variants, interval, anchor, istart) |
| 132 | + |
| 133 | + up_sb = self._upstream_builder( |
| 134 | + upstream_variants, interval, anchor, iend) |
| 135 | + |
| 136 | + # 5. fetch the sequence and restore intervals in builder |
| 137 | + seq = self._fetch(interval, istart, iend) |
| 138 | + up_sb.restore(seq) |
| 139 | + down_sb.restore(seq) |
| 140 | + |
| 141 | + # 6. Concate sequences from the upstream and downstream splits. Concat |
| 142 | + # upstream and downstream sequence. Cut to fix the length. |
| 143 | + down_str = down_sb.concat() |
| 144 | + up_str = up_sb.concat() |
| 145 | + |
| 146 | + if fixed_len: |
| 147 | + down_str, up_str = self._cut_to_fix_len( |
| 148 | + down_str, up_str, interval, anchor) |
| 149 | + |
| 150 | + seq = down_str + up_str |
| 151 | + |
| 152 | + if interval.strand == '-': |
| 153 | + seq = complement(seq)[::-1] |
| 154 | + |
| 155 | + return seq |
| 156 | + |
| 157 | + def _variant_to_sequence(self, variants): |
| 158 | + """ |
| 159 | + Convert `cyvcf2.Variant` objects to `pyfaidx.Seqeunce` objects |
| 160 | + for reference and variants. |
| 161 | + """ |
| 162 | + for v in variants: |
| 163 | + ref = Sequence(name=v.CHROM, seq=v.REF, |
| 164 | + start=v.POS, end=v.POS + len(v.REF)) |
| 165 | + # TO DO: consider alternative alleles. |
| 166 | + alt = Sequence(name=v.CHROM, seq=v.ALT[0], |
| 167 | + start=v.POS, end=v.POS + len(v.ALT[0])) |
| 168 | + yield ref, alt |
| 169 | + |
| 170 | + def _split_overlapping(self, variant_pairs, anchor): |
| 171 | + """ |
| 172 | + Split the variants hitting the anchor into two |
| 173 | + """ |
| 174 | + for ref, alt in variant_pairs: |
| 175 | + if ref.start < anchor < ref.end or alt.start < anchor < alt.end: |
| 176 | + mid = anchor - ref.start |
| 177 | + yield ref[:mid], alt[:mid] |
| 178 | + yield ref[mid:], alt[mid:] |
| 179 | + else: |
| 180 | + yield ref, alt |
| 181 | + |
| 182 | + def _updated_interval(self, interval, up_variants, down_variants): |
| 183 | + istart = interval.start |
| 184 | + iend = interval.end |
| 185 | + |
| 186 | + for ref, alt in up_variants: |
| 187 | + diff_len = len(alt) - len(ref) |
| 188 | + if diff_len < 0: |
| 189 | + iend -= diff_len |
| 190 | + |
| 191 | + for ref, alt in down_variants: |
| 192 | + diff_len = len(alt) - len(ref) |
| 193 | + if diff_len < 0: |
| 194 | + istart += diff_len |
| 195 | + |
| 196 | + return istart, iend |
| 197 | + |
| 198 | + def _downstream_builder(self, down_variants, interval, anchor, istart): |
| 199 | + down_sb = IntervalSeqBuilder() |
| 200 | + |
| 201 | + prev = anchor |
| 202 | + for ref, alt in down_variants: |
| 203 | + if ref.end <= istart: |
| 204 | + break |
| 205 | + down_sb.append(Interval(interval.chrom, ref.end, prev)) |
| 206 | + down_sb.append(alt) |
| 207 | + prev = ref.start |
| 208 | + down_sb.append(Interval(interval.chrom, istart, prev)) |
| 209 | + down_sb.reverse() |
| 210 | + |
| 211 | + return down_sb |
| 212 | + |
| 213 | + def _upstream_builder(self, up_variants, interval, anchor, iend): |
| 214 | + up_sb = IntervalSeqBuilder() |
| 215 | + |
| 216 | + prev = anchor |
| 217 | + for ref, alt in up_variants: |
| 218 | + if ref.start > iend: |
| 219 | + break |
| 220 | + up_sb.append(Interval(interval.chrom, prev, ref.start)) |
| 221 | + up_sb.append(alt) |
| 222 | + prev = ref.end |
| 223 | + up_sb.append(Interval(interval.chrom, prev, iend)) |
| 224 | + |
| 225 | + return up_sb |
| 226 | + |
| 227 | + def _fetch(self, interval, istart, iend): |
| 228 | + seq = self.fasta.extract(Interval(interval.chrom, istart, iend)) |
| 229 | + seq = Sequence(name=interval.chrom, seq=seq, start=istart, end=iend) |
| 230 | + return seq |
| 231 | + |
| 232 | + def _cut_to_fix_len(self, down_str, up_str, interval, anchor): |
| 233 | + down_len = anchor - interval.start |
| 234 | + up_len = interval.end - anchor |
| 235 | + down_str = down_str[-down_len:] if down_len else '' |
| 236 | + up_str = up_str[:up_len] if up_len else '' |
| 237 | + return down_str, up_str |
| 238 | + |
| 239 | + |
| 240 | +class BaseVCFSeqExtractor(BaseExtractor): |
| 241 | + def __init__(self, fasta_file, vcf_file): |
| 242 | + self.fasta_file = fasta_file |
| 243 | + self.vcf_file = vcf_file |
| 244 | + self.variant_extractor = VariantSeqExtractor(fasta_file) |
| 245 | + self.vcf = MultiSampleVCF(vcf_file) |
| 246 | + |
| 247 | + |
| 248 | +class SingleVariantVCFSeqExtractor(BaseVCFSeqExtractor): |
| 249 | + |
| 250 | + def extract(self, interval, anchor=None, sample_id=None, fixed_len=True): |
| 251 | + for variant in self.vcf.fetch_variants(interval, sample_id): |
| 252 | + yield self.variant_extractor.extract(interval, |
| 253 | + variants=[variant], |
| 254 | + anchor=anchor, |
| 255 | + fixed_len=fixed_len) |
| 256 | + |
| 257 | + |
| 258 | +class SingleSeqVCFSeqExtractor(BaseVCFSeqExtractor): |
| 259 | + |
| 260 | + def extract(self, interval, anchor=None, sample_id=None, fixed_len=True): |
| 261 | + return self.variant_extractor.extract( |
| 262 | + interval, variants=self.vcf.fetch_variants(interval, sample_id), |
| 263 | + anchor=anchor, fixed_len=fixed_len) |
0 commit comments