
Commit 2782f3a

Merge pull request #1 from amanas/writer-wo-pandas
fix: fixed bug related to saving variant keys as a pandas DataFrame
2 parents 8f93a59 + 5ce97bb commit 2782f3a

2 files changed: 35 additions & 18 deletions

src/dnarecords/reader.py

Lines changed: 20 additions & 7 deletions
```diff
@@ -84,6 +84,16 @@ def _pandas_safe_read_parquet(path):
             return pd.concat(pd.read_parquet(f) for f in files)
         return None
 
+    @staticmethod
+    def _pandas_safe_read_json(path):
+        import pandas as pd
+        import tensorflow as tf
+
+        files = tf.io.gfile.glob(f'{path}/*.json')
+        if files:
+            return pd.concat(pd.read_json(f) for f in files)
+        return None
+
     def metadata(self) -> Dict[str, DataFrame]:
         """Gets the metadata associated to the DNARecords dataset as a dictionary of names to pandas DataFrames.
 
```
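The new helper mirrors `_pandas_safe_read_parquet` one block above: glob the JSON part files Spark wrote under a directory, concatenate them into a single pandas DataFrame, and return `None` when the directory yields nothing. A standalone sketch of the same pattern (the path is illustrative):

```python
import pandas as pd
import tensorflow as tf

def read_json_parts(path):
    """Concatenate Spark-written JSON part files under `path` into one DataFrame."""
    # tf.io.gfile.glob keeps the lookup filesystem-agnostic (local, GCS, S3, HDFS),
    # which plain glob.glob would not be.
    files = tf.io.gfile.glob(f'{path}/*.json')
    if files:
        return pd.concat(pd.read_json(f) for f in files)
    return None  # absent or empty directory: callers treat this as "no metadata"

schema = read_json_parts('/tmp/dnarecords/metadata/swrsc')  # illustrative path
```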
```diff
@@ -99,8 +109,10 @@ def metadata(self) -> Dict[str, DataFrame]:
         result = {}
         tree = dr.helper.DNARecordsUtils.dnarecords_tree(self._dnarecords_path)
         for k, v in tree.items():
-            if k in ['skeys', 'vkeys', 'swpfs', 'vwpfs', 'swrfs', 'vwrfs', 'swpsc', 'vwpsc', 'swrsc', 'vwrsc']:
+            if k in ['skeys', 'vkeys', 'swpfs', 'vwpfs', 'swrfs', 'vwrfs']:
                 result.update({k: self._pandas_safe_read_parquet(v)})
+            if k in ['swpsc', 'vwpsc', 'swrsc', 'vwrsc']:
+                result.update({k: self._pandas_safe_read_json(v)})
         return result
 
     def datafiles(self) -> Dict[str, List[str]]:
```
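With this split, `metadata()` still reads the key and feature mappings (`skeys`, `vkeys`, `*wpfs`, `*wrfs`) from parquet, while the schema entries (`*sc`) now come from the JSON files the writer emits (see `_write_dnarecords` below). A hedged usage sketch; the constructor call is an assumption, since the diff only shows `self._dnarecords_path` being used:

```python
import dnarecords as dr

reader = dr.reader.DNARecordsReader('/tmp/mydataset/dnarecords')  # assumed ctor
meta = reader.metadata()
vkeys = meta['vkeys']   # variant keys, read from parquet
vwrsc = meta['vwrsc']   # variant-wise record schema, read from JSON (or None)
```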
```diff
@@ -132,9 +144,8 @@ def datafiles(self) -> Dict[str, List[str]]:
     def _sw_decoder(dnarecords, schema, gzip):
         import json
         import tensorflow as tf
-
         one_proto = next(iter(tf.data.TFRecordDataset(dnarecords, 'GZIP' if gzip else None)))
-        swrsc_dict = {f['name']: f for f in schema}
+        swrsc_dict = {f['fields']['name']: f['fields'] for _, f in schema.iterrows()}
         features = {'key': tf.io.FixedLenFeature([], tf.int64)}
         for indices_field in [field for field in swrsc_dict.keys() if field.endswith('indices')]:
             feature_name = indices_field.replace('_indices', '')
```
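The decoder changes follow from the new storage format: `schema` is no longer a parsed `StructType` (where `schema.fields[0]` made sense) but the DataFrame `_pandas_safe_read_json` returns, in which each row's `fields` cell holds one field descriptor from Spark's schema JSON. The comprehension flattens it into a name-to-descriptor dict. A sketch of the shape involved, with illustrative values; note that nested types survive the round trip as JSON strings, which is why `_vw_decoder` below calls `json.loads` on them:

```python
import pandas as pd

schema = pd.DataFrame({
    'type': ['struct', 'struct'],
    'fields': [
        {'name': 'key', 'type': 'long', 'nullable': True, 'metadata': {}},
        {'name': 'values',
         'type': '{"type":"array","elementType":"double","containsNull":true}',
         'nullable': True, 'metadata': {}},
    ],
})

# As in the new decoder code: one descriptor dict per field name.
swrsc_dict = {f['fields']['name']: f['fields'] for _, f in schema.iterrows()}
print(swrsc_dict['values']['type'])  # a JSON string ready for json.loads
```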
```diff
@@ -152,7 +163,7 @@ def _vw_decoder(dnarecords, schema, gzip):
         import tensorflow as tf
 
         one_proto = next(iter(tf.data.TFRecordDataset(dnarecords, 'GZIP' if gzip else None)))
-        vwrsc_dict = {f['name']: f for f in schema}
+        vwrsc_dict = {f['fields']['name']: f['fields'] for _, f in schema.iterrows()}
         values_type = DNARecordsReader._types_dict()[json.loads(vwrsc_dict['values']['type'])['elementType']]
         dense_shape = tf.io.parse_example(one_proto, {'dense_shape': tf.io.FixedLenFeature([], tf.int64)})[
             'dense_shape']
```
```diff
@@ -193,7 +204,7 @@ def sample_wise_dataset(self, num_parallel_reads: int = -1, num_parallel_calls:
         schema = self.metadata()['swrsc']
         if schema is None or not dnarecords:
             raise Exception(f"No DNARecords found at {self._dnarecords_path}/...")
-        decoder = self._sw_decoder(dnarecords, schema.fields[0], self._gzip)
+        decoder = self._sw_decoder(dnarecords, schema, self._gzip)
         return self._dataset(dnarecords, decoder, num_parallel_reads, num_parallel_calls, deterministic, drop_remainder,
                              batch_size, buffer_size)
 
```
```diff
@@ -219,7 +230,7 @@ def variant_wise_dataset(self, num_parallel_reads: int = -1, num_parallel_calls:
         schema = self.metadata()['vwrsc']
         if schema is None or not dnarecords:
             raise Exception(f"No DNARecords found at {self._dnarecords_path}/...")
-        decoder = self._vw_decoder(dnarecords, schema.fields[0], self._gzip)
+        decoder = self._vw_decoder(dnarecords, schema, self._gzip)
         return self._dataset(dnarecords, decoder, num_parallel_reads, num_parallel_calls, deterministic, drop_remainder,
                              batch_size, buffer_size)
 
```
```diff
@@ -311,8 +322,10 @@ def metadata(self) -> Dict[str, 'DataFrame']:
         tree = dr.helper.DNARecordsUtils.dnarecords_tree(self._dnarecords_path)
         spark = dr.helper.DNARecordsUtils.spark_session()
         for k, v in tree.items():
-            if k in ['skeys', 'vkeys', 'swpfs', 'vwpfs', 'swrfs', 'vwrfs', 'swpsc', 'vwpsc', 'swrsc', 'vwrsc']:
+            if k in ['skeys', 'vkeys', 'swpfs', 'vwpfs', 'swrfs', 'vwrfs']:
                 result.update({k: self._spark_safe_load(spark.read.format("parquet"), v)})
+            if k in ['swpsc', 'vwpsc', 'swrsc', 'vwrsc']:
+                result.update({k: self._spark_safe_load(spark.read.format("json"), v)})
         return result
 
     def sample_wise_dnarecords(self) -> 'DataFrame':
```
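Taken together, the reader-side changes are mechanical: schema metadata moves from parquet to JSON, and the decoders receive the whole schema DataFrame instead of `schema.fields[0]`. A hedged end-to-end sketch (constructor and keyword names are assumptions drawn from the signatures visible in this diff):

```python
import dnarecords as dr

reader = dr.reader.DNARecordsReader('/tmp/mydataset/dnarecords')  # assumed ctor
ds = reader.sample_wise_dataset(batch_size=32)  # a batched tf.data.Dataset
for batch in ds.take(1):
    print(batch['key'])  # every record carries the int64 'key' feature
```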

src/dnarecords/writer.py

Lines changed: 15 additions & 11 deletions
```diff
@@ -111,9 +111,10 @@ def _set_vkeys_skeys(self):
         self._skeys = self._mt.key_cols_by().cols().to_spark().withColumnRenamed('j', 'key').cache()
 
     def _set_chrom_ranges(self):
-        gdf = self._vkeys.toPandas()[['locus.contig', 'key']].groupby('locus.contig', as_index=False)
-        gdf = gdf.agg(start=('key', 'min'), end=('key', 'max'))
-        self._chrom_ranges = {r['locus.contig']: [r['start'], r['end']] for i, r in gdf.iterrows()}
+        from pyspark.sql import functions as F
+        gdf = self._vkeys.select('`locus.contig`', 'key').groupby('`locus.contig`')
+        gdf = gdf.agg(F.min('key').alias('start'), F.max('key').alias('end'))
+        self._chrom_ranges = {r['locus.contig']: [r['start'], r['end']] for i, r in gdf.toPandas().iterrows()}
 
     def _update_vkeys_by_chrom_ranges(self):
         from dnarecords.helper import DNARecordsUtils
```
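This is the heart of the `writer-wo-pandas` branch: the per-chromosome key ranges are now aggregated in Spark, so only the tiny per-contig min/max result is collected, instead of materializing every variant key with `toPandas()` first. The backticks matter: `locus.contig` is a flat column whose name contains a dot, and without them Spark would parse the dot as a struct-field accessor. A minimal runnable sketch of the same pattern:

```python
from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.master('local[1]').getOrCreate()
# Stand-in for self._vkeys: note the dot inside the column *name*.
vkeys = spark.createDataFrame(
    [('chr1', 0), ('chr1', 9), ('chr2', 10)],
    ['locus.contig', 'key'])

gdf = vkeys.select('`locus.contig`', 'key').groupby('`locus.contig`')
gdf = gdf.agg(F.min('key').alias('start'), F.max('key').alias('end'))
# Only len(contigs) rows cross into pandas here.
chrom_ranges = {r['locus.contig']: [r['start'], r['end']]
                for _, r in gdf.toPandas().iterrows()}
print(chrom_ranges)  # {'chr1': [0, 9], 'chr2': [10, 10]}
```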
```diff
@@ -129,11 +130,13 @@ def _select_ijv(self):
         self._mt = self._mt.select_globals().select_rows().select_cols().select_entries('v')
 
     def _filter_out_undefined_entries(self):
-        import hail as hl
+        from dnarecords.helper import DNARecordsUtils
+        hl = DNARecordsUtils.init_hail()
         self._mt = self._mt.filter_entries(hl.is_defined(self._mt.v))
 
     def _filter_out_zeroes(self):
-        import hail as hl
+        from dnarecords.helper import DNARecordsUtils
+        hl = DNARecordsUtils.init_hail()
         self._mt = self._mt.filter_entries(0 != hl.coalesce(self._mt.v, 0))
 
     def _set_max_nrows_ncols(self):
```
```diff
@@ -163,7 +166,8 @@ def _build_ij_blocks(self):
 
     def _set_ij_blocks(self):
         import re
-        import hail as hl
+        from dnarecords.helper import DNARecordsUtils
+        hl = DNARecordsUtils.init_hail()
         all_blocks = [p for p in hl.hadoop_ls(f'{self._kv_blocks_path}/*') if p['is_dir']]
         self._i_blocks = {re.search(r'ib=(\d+)', p['path']).group(1) for p in all_blocks}
         self._j_blocks = {re.search(r'jb=(\d+)', p['path']).group(1) for p in all_blocks}
```
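All three ad hoc `import hail as hl` statements are replaced with `DNARecordsUtils.init_hail()`, centralizing Hail setup in one helper. Its implementation lives in `dnarecords.helper` and is not part of this diff; a hypothetical sketch of the pattern, assuming the helper initializes Hail once and returns the module:

```python
# Hypothetical sketch only; the real DNARecordsUtils.init_hail may differ.
class DNARecordsUtils:
    @staticmethod
    def init_hail():
        """Import Hail, initialize it at most once, and return the module."""
        import hail as hl
        hl.init(idempotent=True)  # no-op when Hail is already initialized
        return hl

hl = DNARecordsUtils.init_hail()
```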
```diff
@@ -280,8 +284,8 @@ def _write_dnarecords(output, output_schema, dna_blocks, write_mode, gzip, tfrec
         if gzip:
             df_writer = df_writer.option("compression", "gzip")
         df_writer.save(output)
-        sc_writer = spark.read.json(spark.sparkContext.parallelize([df.schema.json()])).repartition(1).write
-        sc_writer.mode(write_mode).parquet(output_schema)
+        sc_writer = spark.read.json(spark.sparkContext.parallelize([df.schema.json()])).coalesce(1).write
+        sc_writer.mode(write_mode).format('json').save(output_schema)
 
     @staticmethod
     def _write_key_files(source, output, tfrecord_format, write_mode):
```
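Two things change here: `repartition(1)` becomes `coalesce(1)`, which merges existing partitions into one without the full shuffle `repartition` triggers, and the schema sidecar is written as JSON rather than parquet, matching the reader changes above. The `spark.read.json(parallelize([df.schema.json()]))` trick round-trips Spark's own `StructType` serialization through a one-row DataFrame. A minimal sketch:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.master('local[1]').getOrCreate()
df = spark.createDataFrame([(1, 0.5)], ['key', 'value'])

# df.schema.json() is a single JSON string describing the StructType;
# reading it back yields one row whose 'fields' column lists the descriptors.
schema_df = spark.read.json(spark.sparkContext.parallelize([df.schema.json()]))
schema_df.coalesce(1).write.mode('overwrite').format('json').save('/tmp/schema_json')
```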
```diff
@@ -294,7 +298,7 @@ def _write_key_files(source, output, tfrecord_format, write_mode):
         else:
             reader = spark.read.format("parquet")
         df = reader.load(source).withColumn("path", F.regexp_extract(F.input_file_name(), f"(.*){source}/(.*)", 2))
-        df.select('key', 'path').repartition(1).write.mode(write_mode).parquet(output)
+        df.select('key', 'path').write.mode(write_mode).parquet(output)
 
     # pylint: disable=too-many-arguments
     # It is reasonable in this case.
```
```diff
@@ -365,5 +369,5 @@ def write(self, output: str, sparse: bool = True, sample_wise: bool = True, vari
                                    gzip, False)
             self._write_key_files(otree['swpar'], otree['swpfs'], False, write_mode)
 
-        self._vkeys.repartition(1).write.mode(write_mode).parquet(otree['vkeys'])
-        self._skeys.repartition(1).write.mode(write_mode).parquet(otree['skeys'])
+        self._vkeys.write.mode(write_mode).parquet(otree['vkeys'])
+        self._skeys.write.mode(write_mode).parquet(otree['skeys'])
```
