1- from typing import List
1+ import os
2+ from typing import List , Union , Iterable , Iterator
23import pandas as pd
34from kipoiseq .dataclasses import Variant , Interval
4- from kipoiseq .extractors import MultiSampleVCF
5+ from kipoiseq .variant_source import VariantFetcher
56
67try :
78 from pyranges import PyRanges
@@ -82,14 +83,43 @@ def intervals_to_pyranges(intervals):
8283 )
8384
8485
86+ class PyrangesVariantFetcher (VariantFetcher ):
87+
88+ def __init__ (self , variants : List [Variant ]):
89+ self .variants = variants
90+ self ._variants_pr = None
91+
92+ @property
93+ def variants_pr (self ):
94+ # convert to PyRanges on demand
95+ if self ._variants_pr is None :
96+ self ._variants_pr = variants_to_pyranges (self .variants )
97+ return self ._variants_pr
98+
99+ def fetch_variants (self , interval : Union [Interval , Iterable [Interval ]]) -> Iterator [Variant ]:
100+ if isinstance (interval , Interval ):
101+ interval = [interval ]
102+ # convert interval(s) to PyRanges object
103+ interval_pr : PyRanges = intervals_to_pyranges (interval )
104+ # join with variants
105+ pr_join = interval_pr .join (self .variants_pr , suffix = '_variant' )
106+
107+ yield from pr_join .df ["variant" ]
108+
109+ def __iter__ (self ) -> Iterator [Variant ]:
110+ yield from self .variants
111+
112+
85113class BaseVariantMatcher :
86114 """
87115 Base variant intervals matcher
88116 """
89117
90118 def __init__ (
91119 self ,
92- vcf_file : str ,
120+ vcf_file : str = None ,
121+ variants : List [Variant ] = None ,
122+ variant_fetcher : VariantFetcher = None ,
93123 gtf_path : str = None ,
94124 bed_path : str = None ,
95125 pranges : PyRanges = None ,
@@ -101,7 +131,8 @@ def __init__(
101131 """
102132
103133 Args:
104- vcf_file: path of vcf file
134+ vcf_file: (optional) path of vcf file
135+ variants: (optional) readily processed variants
105136 gtf_path: (optional) path of gtf file contains features
106137 bed_path: (optional) path of bed file
107138 pranges: (optional) pyranges object
@@ -110,12 +141,31 @@ def __init__(
110141 pyranges object. This argument is not valid with intervals.
111142 Currently unused
112143 """
113- self .vcf = MultiSampleVCF (vcf_file , lazy = vcf_lazy )
144+ self .variant_fetcher = self . _read_variants (vcf_file , variants , variant_fetcher , vcf_lazy )
114145 self .interval_attrs = interval_attrs
115146 self .pr = self ._read_intervals (gtf_path , bed_path , pranges ,
116147 intervals , interval_attrs , duplicate_attr = True )
117148 self .variant_batch_size = variant_batch_size
118149
150+ @staticmethod
151+ def _read_variants (
152+ vcf_file = None ,
153+ variants = None ,
154+ variant_fetcher = None ,
155+ vcf_lazy : bool = True ,
156+ ) -> VariantFetcher :
157+ if vcf_file is not None :
158+ from kipoiseq .extractors import MultiSampleVCF
159+ return MultiSampleVCF (vcf_file , lazy = vcf_lazy )
160+ elif variant_fetcher is not None :
161+ assert isinstance (variant_fetcher , VariantFetcher ), \
162+ "Wrong type of variant fetcher: %s" % type (variant_fetcher )
163+ return variant_fetcher
164+ elif variants is not None :
165+ return PyrangesVariantFetcher (variants )
166+ else :
167+ raise ValueError ("No source of variants was specified!" )
168+
119169 @staticmethod
120170 def _read_intervals (gtf_path = None , bed_path = None , pranges = None ,
121171 intervals = None , interval_attrs = None , duplicate_attr = False ):
@@ -172,7 +222,7 @@ def _read_vcf_pyranges(self, batch_size=10000):
172222 Args:
173223 batch_size: size of each batch.
174224 """
175- for batch in self .vcf .batch_iter (batch_size ):
225+ for batch in self .variant_fetcher .batch_iter (batch_size ):
176226 yield variants_to_pyranges (batch )
177227
178228 def iter_pyranges (self ) -> PyRanges :
@@ -210,30 +260,14 @@ def __iter__(self) -> (Interval, Variant):
210260
211261class MultiVariantsMatcher (BaseVariantMatcher ):
212262
213- def __init__ (
214- self ,
215- vcf_file ,
216- gtf_path = None ,
217- pranges = None ,
218- intervals = None ,
219- interval_attrs = None ,
220- vcf_lazy = True ,
221- variant_batch_size = 10000
222- ):
223- super ().__init__ (
224- vcf_file ,
225- gtf_path = gtf_path ,
226- pranges = pranges ,
227- intervals = intervals ,
228- interval_attrs = interval_attrs ,
229- vcf_lazy = vcf_lazy ,
230- variant_batch_size = variant_batch_size
231- )
263+ def __init__ (self , * args , ** kwargs ):
264+ super ().__init__ (* args , ** kwargs )
265+
232266 if hasattr (self .pr , 'intervals' ):
233267 self .intervals = self .pr .intervals
234268 else :
235269 self .intervals = pyranges_to_intervals (self .pr )
236270
237271 def __iter__ (self ):
238272 for i in self .intervals :
239- yield i , self .vcf .fetch_variants (i )
273+ yield i , self .variant_fetcher .fetch_variants (i )
0 commit comments