@@ -341,8 +341,6 @@ def __init__(
341341class ArrowWriter :
342342 """Shuffles and writes Examples to Arrow files."""
343343
344- _WRITER_CLASS = pa .RecordBatchStreamWriter
345-
346344 def __init__ (
347345 self ,
348346 schema : Optional [pa .Schema ] = None ,
@@ -430,7 +428,7 @@ def close(self):
430428 if self ._closable_stream and not self .stream .closed :
431429 self .stream .close () # This also closes self.pa_writer if it is opened
432430
433- def _build_writer (self , inferred_schema : pa .Schema ):
431+ def _build_schema (self , inferred_schema : pa .Schema ):
434432 schema = self .schema
435433 inferred_features = Features .from_arrow_schema (inferred_schema )
436434 if self ._features is not None :
@@ -441,19 +439,24 @@ def _build_writer(self, inferred_schema: pa.Schema):
441439 if name in fields :
442440 if inferred_field == fields [name ]:
443441 inferred_features [name ] = self ._features [name ]
444- self . _features = inferred_features
442+ features = inferred_features
445443 schema : pa .Schema = inferred_schema
446444 else :
447- self . _features = inferred_features
445+ features = inferred_features
448446 schema : pa .Schema = inferred_features .arrow_schema
447+
449448 if self .disable_nullable :
450449 schema = pa .schema (pa .field (field .name , field .type , nullable = False ) for field in schema )
451450 if self .with_metadata :
452- schema = schema .with_metadata (self ._build_metadata (DatasetInfo (features = self . _features ), self .fingerprint ))
451+ schema = schema .with_metadata (self ._build_metadata (DatasetInfo (features = features ), self .fingerprint ))
453452 else :
454453 schema = schema .with_metadata ({})
455- self ._schema = schema
456- self .pa_writer = self ._WRITER_CLASS (self .stream , schema )
454+
455+ return schema , features
456+
def _build_writer(self, inferred_schema: pa.Schema):
    """Resolve the final schema/features and open the Arrow stream writer.

    Args:
        inferred_schema: Arrow schema passed through to ``_build_schema``.

    The schema and features returned by ``_build_schema`` are stored on
    ``self._schema`` / ``self._features``; a ``pa.RecordBatchStreamWriter``
    bound to ``self.stream`` is stored on ``self.pa_writer``.
    """
    resolved_schema, resolved_features = self._build_schema(inferred_schema)
    self._schema = resolved_schema
    self._features = resolved_features
    self.pa_writer = pa.RecordBatchStreamWriter(self.stream, resolved_schema)
457460
458461 @property
459462 def schema (self ):
@@ -674,4 +677,11 @@ def finalize(self, close_stream=True):
674677
675678
class ParquetWriter(ArrowWriter):
    """ArrowWriter variant whose underlying writer is a ``pq.ParquetWriter``.

    Args:
        *args: Forwarded unchanged to ``ArrowWriter.__init__``.
        cdc_options: Value passed to ``pq.ParquetWriter`` as
            ``use_content_defined_chunking``. ``None`` selects
            ``config.DEFAULT_CDC_OPTIONS``.
        **kwargs: Forwarded unchanged to ``ArrowWriter.__init__``.
    """

    def __init__(self, *args, cdc_options=None, **kwargs):
        super().__init__(*args, **kwargs)
        # Fall back to the library-wide default only when the caller gave nothing.
        if cdc_options is None:
            cdc_options = config.DEFAULT_CDC_OPTIONS
        self.cdc_options = cdc_options

    def _build_writer(self, inferred_schema: pa.Schema):
        """Resolve the schema, open the Parquet writer and record the CDC options.

        Args:
            inferred_schema: Arrow schema passed through to ``_build_schema``.
        """
        resolved_schema, resolved_features = self._build_schema(inferred_schema)
        self._schema = resolved_schema
        self._features = resolved_features
        self.pa_writer = pq.ParquetWriter(
            self.stream,
            resolved_schema,
            use_content_defined_chunking=self.cdc_options,
        )
        # Store the chunking options as JSON in the file's key/value metadata.
        self.pa_writer.add_key_value_metadata({"content_defined_chunking": json.dumps(self.cdc_options)})
0 commit comments