1- from typing import List
1+ import os
2+ from typing import List , Union , Iterable , Iterator
23import pandas as pd
34from kipoiseq .dataclasses import Variant , Interval
4- from kipoiseq .extractors import MultiSampleVCF
5+ from kipoiseq .variant_source import VariantFetcher
56
67try :
78 from pyranges import PyRanges
@@ -82,14 +83,58 @@ def intervals_to_pyranges(intervals):
8283 )
8384
8485
class PyrangesVariantFetcher(VariantFetcher):
    """
    Variant fetcher backed by an in-memory list of variants.

    Interval queries are answered by joining the queried interval(s) with a
    PyRanges representation of the variants, which is built on first use.
    """

    def __init__(self, variants: List[Variant]):
        """
        Args:
          variants: variants served by this fetcher
        """
        self.variants = variants
        # PyRanges view of `variants`; filled in by the `variants_pr` property
        self._variants_pr = None

    @property
    def variants_pr(self):
        """
        Result of `variants_to_pyranges(self.variants)`, computed on first
        access and cached afterwards.
        """
        if self._variants_pr is None:
            self._variants_pr = variants_to_pyranges(self.variants)
        return self._variants_pr

    def fetch_variants(self, interval: Union[Interval, Iterable[Interval]]) -> Iterator[Variant]:
        """
        Yield the variants that are joined to the given interval(s).

        Args:
          interval: a single interval or an iterable of intervals

        Returns:
          Iterator over the "variant" column of the join result; nothing is
          yielded when the join result is empty.
        """
        if isinstance(interval, Interval):
            interval = [interval]
        # convert interval(s) to PyRanges object
        interval_pr: PyRanges = intervals_to_pyranges(interval)
        # join with variants
        pr_join = interval_pr.join(self.variants_pr, suffix='_variant')

        joined_df = pr_join.df
        # NOTE(review): pyranges appears to expose an empty PyRanges as a
        # DataFrame without any columns, so looking up "variant" on an empty
        # join result would raise KeyError instead of yielding nothing.
        if joined_df.empty:
            return

        yield from joined_df["variant"]

    def __iter__(self) -> Iterator[Variant]:
        """
        Iterate over all variants in the order they were passed in.
        """
        yield from self.variants
112+
class VariantFetcherProxy(VariantFetcher):
    """
    Thin wrapper that forwards `fetch_variants`, `batch_iter` and iteration
    to another variant fetcher.
    """

    def __init__(self, variant_fetcher: VariantFetcher):
        """
        Args:
          variant_fetcher: the fetcher that every call is delegated to
        """
        self.variant_fetcher = variant_fetcher

    def fetch_variants(self, interval: Union[Interval, Iterable[Interval]]) -> Iterator[Variant]:
        """
        Yield whatever the wrapped fetcher yields for `interval`.
        """
        wrapped = self.variant_fetcher
        for variant in wrapped.fetch_variants(interval):
            yield variant

    def batch_iter(self, batch_size=10000) -> Iterator[List[Variant]]:
        """
        Yield the batches produced by the wrapped fetcher's `batch_iter`.

        Args:
          batch_size: passed through unchanged to the wrapped fetcher
        """
        wrapped = self.variant_fetcher
        for batch in wrapped.batch_iter(batch_size):
            yield batch

    def __iter__(self) -> Iterator[Variant]:
        """
        Iterate over the wrapped fetcher.
        """
        for variant in self.variant_fetcher:
            yield variant
127+
85128class BaseVariantMatcher :
86129 """
87130 Base variant intervals matcher
88131 """
89132
90133 def __init__ (
91134 self ,
92- vcf_file : str ,
135+ vcf_file : str = None ,
136+ variants : List [Variant ] = None ,
137+ variant_fetcher : VariantFetcher = None ,
93138 gtf_path : str = None ,
94139 bed_path : str = None ,
95140 pranges : PyRanges = None ,
@@ -101,7 +146,8 @@ def __init__(
101146 """
102147
103148 Args:
104- vcf_file: path of vcf file
149+ vcf_file: (optional) path of vcf file
150+ variants: (optional) readily processed variants
105151 gtf_path: (optional) path of gtf file contains features
106152 bed_path: (optional) path of bed file
107153 pranges: (optional) pyranges object
@@ -110,12 +156,37 @@ def __init__(
110156 pyranges object. This argument is not valid with intervals.
111157 Currently unused
112158 """
113- self .vcf = MultiSampleVCF (vcf_file , lazy = vcf_lazy )
159+ self .variant_fetcher = self . _read_variants (vcf_file , variants , variant_fetcher , vcf_lazy )
114160 self .interval_attrs = interval_attrs
115161 self .pr = self ._read_intervals (gtf_path , bed_path , pranges ,
116162 intervals , interval_attrs , duplicate_attr = True )
117163 self .variant_batch_size = variant_batch_size
118164
@staticmethod
def _read_variants(
        vcf_file=None,
        variants=None,
        variant_fetcher=None,
        vcf_lazy: bool = True,
):
    """
    Pick the source of variants from the given arguments.

    The first argument that is not None wins, checked in this order:
    `vcf_file`, `variant_fetcher`, `variants`.

    Args:
      vcf_file: (optional) path of vcf file, opened with MultiSampleVCF
      variants: (optional) readily processed variants, wrapped in a
        PyrangesVariantFetcher
      variant_fetcher: (optional) VariantFetcher instance, returned as is
      vcf_lazy: passed as `lazy` to MultiSampleVCF; only used with `vcf_file`

    Returns:
      The object variants are fetched from.

    Raises:
      TypeError: if `variant_fetcher` is not a VariantFetcher
      ValueError: if none of the three sources is given
    """
    if vcf_file is not None:
        from kipoiseq.extractors import MultiSampleVCF
        vcf = MultiSampleVCF(vcf_file, lazy=vcf_lazy)

        if os.environ.get('PYTEST_RUNNING', '') == 'true':
            # Ensure that none of the methods actually uses MultiSampleVCF methods by accident
            vcf = VariantFetcherProxy(vcf)

        return vcf
    elif variant_fetcher is not None:
        # Explicit raise instead of `assert`: assertions are removed when
        # Python runs with -O, which would let a wrong type through unchecked.
        if not isinstance(variant_fetcher, VariantFetcher):
            raise TypeError(
                "Wrong type of variant fetcher: %s" % type(variant_fetcher)
            )
        return variant_fetcher
    elif variants is not None:
        return PyrangesVariantFetcher(variants)
    else:
        raise ValueError("No source of variants was specified!")
189+
119190 @staticmethod
120191 def _read_intervals (gtf_path = None , bed_path = None , pranges = None ,
121192 intervals = None , interval_attrs = None , duplicate_attr = False ):
@@ -172,7 +243,7 @@ def _read_vcf_pyranges(self, batch_size=10000):
172243 Args:
173244 batch_size: size of each batch.
174245 """
175- for batch in self .vcf .batch_iter (batch_size ):
246+ for batch in self .variant_fetcher .batch_iter (batch_size ):
176247 yield variants_to_pyranges (batch )
177248
178249 def iter_pyranges (self ) -> PyRanges :
@@ -210,30 +281,14 @@ def __iter__(self):
210281
class MultiVariantsMatcher(BaseVariantMatcher):
    """
    Matcher that pairs every interval with the variants fetched for it.
    """

    def __init__(self, *args, **kwargs):
        """
        Accepts the same arguments as `BaseVariantMatcher.__init__` and
        forwards them unchanged.
        """
        super().__init__(*args, **kwargs)

        # Reuse the `intervals` attribute of the pyranges object when it has
        # one; otherwise convert the pyranges object to intervals.
        self.intervals = (
            self.pr.intervals
            if hasattr(self.pr, 'intervals')
            else pyranges_to_intervals(self.pr)
        )

    def __iter__(self):
        """
        Yield `(interval, variants)` tuples, where `variants` is the result
        of `fetch_variants` for that interval.
        """
        yield from (
            (region, self.variant_fetcher.fetch_variants(region))
            for region in self.intervals
        )
0 commit comments