@@ -84,6 +84,16 @@ def _pandas_safe_read_parquet(path):
8484 return pd .concat (pd .read_parquet (f ) for f in files )
8585 return None
8686
@staticmethod
def _pandas_safe_read_json(path):
    """Read every ``*.json`` file directly under ``path`` into one pandas DataFrame.

    Files are discovered with ``tf.io.gfile.glob``, each one is parsed with
    ``pd.read_json`` and the resulting frames are concatenated in glob order.

    :param path: Directory (or filesystem prefix understood by ``tf.io.gfile``)
        whose ``*.json`` files are read.
    :return: The concatenated DataFrame, or ``None`` when no ``*.json`` file
        matches (mirrors the "safe" contract: no exception on a missing dataset).
    """
    # Imported locally so that importing the module does not require these packages.
    import pandas as pd
    import tensorflow as tf

    # The pattern must not contain whitespace between the path and the wildcard,
    # otherwise the glob looks for "<path> /*.json" and never matches anything.
    files = tf.io.gfile.glob(f'{path}/*.json')
    if not files:
        return None
    return pd.concat([pd.read_json(f) for f in files])
96+
8797 def metadata (self ) -> Dict [str , DataFrame ]:
8898 """Gets the metadata associated to the DNARecords dataset as a dictionary of names to pandas DataFrames.
8999
@@ -99,8 +109,10 @@ def metadata(self) -> Dict[str, DataFrame]:
99109 result = {}
100110 tree = dr .helper .DNARecordsUtils .dnarecords_tree (self ._dnarecords_path )
101111 for k , v in tree .items ():
102- if k in ['skeys' , 'vkeys' , 'swpfs' , 'vwpfs' , 'swrfs' , 'vwrfs' , 'swpsc' , 'vwpsc' , 'swrsc' , 'vwrsc' ]:
112+ if k in ['skeys' , 'vkeys' , 'swpfs' , 'vwpfs' , 'swrfs' , 'vwrfs' ]:
103113 result .update ({k : self ._pandas_safe_read_parquet (v )})
114+ if k in ['swpsc' , 'vwpsc' , 'swrsc' , 'vwrsc' ]:
115+ result .update ({k : self ._pandas_safe_read_json (v )})
104116 return result
105117
106118 def datafiles (self ) -> Dict [str , List [str ]]:
@@ -132,9 +144,8 @@ def datafiles(self) -> Dict[str, List[str]]:
132144 def _sw_decoder (dnarecords , schema , gzip ):
133145 import json
134146 import tensorflow as tf
135-
136147 one_proto = next (iter (tf .data .TFRecordDataset (dnarecords , 'GZIP' if gzip else None )))
137- swrsc_dict = {f ['name' ]: f for f in schema }
148+ swrsc_dict = {f ['fields' ][ ' name' ]: f [ 'fields' ] for _ , f in schema . iterrows () }
138149 features = {'key' : tf .io .FixedLenFeature ([], tf .int64 )}
139150 for indices_field in [field for field in swrsc_dict .keys () if field .endswith ('indices' )]:
140151 feature_name = indices_field .replace ('_indices' , '' )
@@ -152,7 +163,7 @@ def _vw_decoder(dnarecords, schema, gzip):
152163 import tensorflow as tf
153164
154165 one_proto = next (iter (tf .data .TFRecordDataset (dnarecords , 'GZIP' if gzip else None )))
155- vwrsc_dict = {f ['name' ]: f for f in schema }
166+ vwrsc_dict = {f ['fields' ][ ' name' ]: f [ 'fields' ] for _ , f in schema . iterrows () }
156167 values_type = DNARecordsReader ._types_dict ()[json .loads (vwrsc_dict ['values' ]['type' ])['elementType' ]]
157168 dense_shape = tf .io .parse_example (one_proto , {'dense_shape' : tf .io .FixedLenFeature ([], tf .int64 )})[
158169 'dense_shape' ]
@@ -193,7 +204,7 @@ def sample_wise_dataset(self, num_parallel_reads: int = -1, num_parallel_calls:
193204 schema = self .metadata ()['swrsc' ]
194205 if schema is None or not dnarecords :
195206 raise Exception (f"No DNARecords found at { self ._dnarecords_path } /..." )
196- decoder = self ._sw_decoder (dnarecords , schema . fields [ 0 ] , self ._gzip )
207+ decoder = self ._sw_decoder (dnarecords , schema , self ._gzip )
197208 return self ._dataset (dnarecords , decoder , num_parallel_reads , num_parallel_calls , deterministic , drop_remainder ,
198209 batch_size , buffer_size )
199210
@@ -219,7 +230,7 @@ def variant_wise_dataset(self, num_parallel_reads: int = -1, num_parallel_calls:
219230 schema = self .metadata ()['vwrsc' ]
220231 if schema is None or not dnarecords :
221232 raise Exception (f"No DNARecords found at { self ._dnarecords_path } /..." )
222- decoder = self ._vw_decoder (dnarecords , schema . fields [ 0 ] , self ._gzip )
233+ decoder = self ._vw_decoder (dnarecords , schema , self ._gzip )
223234 return self ._dataset (dnarecords , decoder , num_parallel_reads , num_parallel_calls , deterministic , drop_remainder ,
224235 batch_size , buffer_size )
225236
@@ -311,8 +322,10 @@ def metadata(self) -> Dict[str, 'DataFrame']:
311322 tree = dr .helper .DNARecordsUtils .dnarecords_tree (self ._dnarecords_path )
312323 spark = dr .helper .DNARecordsUtils .spark_session ()
313324 for k , v in tree .items ():
314- if k in ['skeys' , 'vkeys' , 'swpfs' , 'vwpfs' , 'swrfs' , 'vwrfs' , 'swpsc' , 'vwpsc' , 'swrsc' , 'vwrsc' ]:
325+ if k in ['skeys' , 'vkeys' , 'swpfs' , 'vwpfs' , 'swrfs' , 'vwrfs' ]:
315326 result .update ({k : self ._spark_safe_load (spark .read .format ("parquet" ), v )})
327+ if k in ['swpsc' , 'vwpsc' , 'swrsc' , 'vwrsc' ]:
328+ result .update ({k : self ._spark_safe_load (spark .read .format ("json" ), v )})
316329 return result
317330
318331 def sample_wise_dnarecords (self ) -> 'DataFrame' :
0 commit comments