 import os
 from itertools import chain
 from functools import lru_cache
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Literal

 import fsspec.core

@@ -104,7 +104,13 @@ def pd(self):
         return pd

     def __init__(
-        self, root, fs=None, out_root=None, cache_size=128, categorical_threshold=10
+        self,
+        root,
+        fs=None,
+        out_root=None,
+        cache_size=128,
+        categorical_threshold=10,
+        engine: Literal["fastparquet", "pyarrow"] = "fastparquet",
     ):
         """

@@ -126,16 +132,25 @@ def __init__(
             Encode urls as pandas.Categorical to reduce memory footprint if the ratio
             of the number of unique urls to total number of refs for each variable
             is greater than or equal to this number. (default 10)
+        engine: Literal["fastparquet", "pyarrow"]
+            Engine choice for reading parquet files. (default is "fastparquet")
         """
+
         self.root = root
         self.chunk_sizes = {}
         self.out_root = out_root or self.root
         self.cat_thresh = categorical_threshold
+        self.engine = engine
         self.cache_size = cache_size
         self.url = self.root + "/{field}/refs.{record}.parq"
         # TODO: derive fs from `root`
         self.fs = fsspec.filesystem("file") if fs is None else fs

+        from importlib.util import find_spec
+
+        if self.engine == "pyarrow" and find_spec("pyarrow") is None:
+            raise ImportError("engine choice `pyarrow` is not installed.")
+
     def __getattr__(self, item):
         if item in ("_items", "record_size", "zmetadata"):
             self.setup()
@@ -158,7 +173,7 @@ def open_refs(field, record):
             """cached parquet file loader"""
             path = self.url.format(field=field, record=record)
             data = io.BytesIO(self.fs.cat_file(path))
-            df = self.pd.read_parquet(data, engine="fastparquet")
+            df = self.pd.read_parquet(data, engine=self.engine)
             refs = {c: df[c].to_numpy() for c in df.columns}
             return refs

@@ -463,18 +478,28 @@ def write(self, field, record, base_url=None, storage_options=None):

         fn = f"{base_url or self.out_root}/{field}/refs.{record}.parq"
         self.fs.mkdirs(f"{base_url or self.out_root}/{field}", exist_ok=True)
+
+        if self.engine == "pyarrow":
+            df_backend_kwargs = {"write_statistics": False}
+        elif self.engine == "fastparquet":
+            df_backend_kwargs = {
+                "stats": False,
+                "object_encoding": object_encoding,
+                "has_nulls": has_nulls,
+            }
+        else:
+            raise NotImplementedError(f"{self.engine} not supported")
+
         df.to_parquet(
             fn,
-            engine="fastparquet",
+            engine=self.engine,
             storage_options=storage_options
             or getattr(self.fs, "storage_options", None),
             compression="zstd",
             index=False,
-            stats=False,
-            object_encoding=object_encoding,
-            has_nulls=has_nulls,
-            # **kwargs,
+            **df_backend_kwargs,
         )
+
         partition.clear()
         self._items.pop((field, record))

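A note on the engine-specific keyword split in the hunk above: pandas.DataFrame.to_parquet forwards unrecognised keyword arguments to the selected backend, so the fastparquet-only options (stats, object_encoding, has_nulls) would be rejected by pyarrow, which exposes write_statistics for the equivalent toggle. A minimal illustrative sketch of the two call shapes, assuming both libraries are installed and using a throwaway frame rather than the real reference data:

import pandas as pd

# Hypothetical stand-in data; the real code writes reference tables.
df = pd.DataFrame({"path": ["a.bin", "b.bin"], "offset": [0, 100], "size": [100, 50]})

# fastparquet-specific writer options are passed straight through to_parquet ...
df.to_parquet("refs.0.parq", engine="fastparquet",
              index=False, stats=False, object_encoding="utf8", has_nulls=False)

# ... whereas pyarrow spells the statistics toggle differently.
df.to_parquet("refs.0.parq", engine="pyarrow", index=False, write_statistics=False)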
@@ -486,6 +511,7 @@ def flush(self, base_url=None, storage_options=None):
         base_url: str
             Location of the output
         """
+
         # write what we have so far and clear sub chunks
         for thing in list(self._items):
             if isinstance(thing, tuple):
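For orientation, a usage sketch of the new parameter, assuming the class being patched is fsspec's LazyReferenceMapper (its name does not appear in the hunks above) and that "refs_dir" is a placeholder path to an existing parquet reference store:

from fsspec.implementations.reference import LazyReferenceMapper

# Default is unchanged: fastparquet remains the reader/writer engine.
refs = LazyReferenceMapper("refs_dir")

# Opting in to pyarrow; per the new __init__ guard this raises ImportError
# at construction time if pyarrow is not importable.
refs_arrow = LazyReferenceMapper("refs_dir", engine="pyarrow")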