@@ -49,12 +49,16 @@ def __init__(
4949 ):
5050 self .config = data_pipeline_config
5151 self .buffer_config = buffer_config
52+ # init input buffers
5253 input_buffer_configs = self .config .input_buffers
5354 if len (input_buffer_configs ) == 0 :
5455 raise ValueError ("input_buffers is empty in data pipeline config" )
55- self .buffers = []
56+ self .input_buffers = []
5657 for input_buffer_config in input_buffer_configs :
57- self .buffers .append (get_buffer_reader (input_buffer_config , self .buffer_config ))
58+ self .input_buffers .append (get_buffer_reader (input_buffer_config , self .buffer_config ))
59+ # init output buffer
60+ self .output_buffer = get_buffer_writer (self .config .output_buffer , self .buffer_config )
61+
5862 self .data = Dataset .from_list ([])
5963 self .original_dataclass = None
6064
@@ -79,28 +83,23 @@ def sort_by(self, key: str, reverse: bool = False, top_k: int = -1):
7983
def read_from_buffer(self):
    """Load the experiences from every input buffer into ``self.data``.

    Each reader in ``self.input_buffers`` is read once. The experiences it
    returns are converted to plain dicts with ``dataclasses.asdict`` and
    wrapped in a ``Dataset``; the per-buffer datasets are then concatenated
    into ``self.data``.

    The class of the first experience seen is stored in
    ``self.original_dataclass`` (only when it is still ``None``) so that
    ``write_to_buffer`` can rebuild instances of it later.

    A buffer that yields no experiences is skipped, and if no buffer yields
    anything ``self.data`` is set to an empty dataset.
    """
    datasets = []
    for buffer in self.input_buffers:
        exp_list = buffer.read()
        if not exp_list:
            # Nothing to convert; indexing exp_list[0] below would raise
            # IndexError on an empty read.
            continue
        if self.original_dataclass is None:
            self.original_dataclass = exp_list[0].__class__
        datasets.append(Dataset.from_list([asdict(exp) for exp in exp_list]))
    # concatenate_datasets refuses an empty list, so fall back to the same
    # empty dataset that __init__ starts with.
    self.data = concatenate_datasets(datasets) if datasets else Dataset.from_list([])
    logger.info(f"Read {len(self.data)} samples from input buffers")
8993
def write_to_buffer(self):
    """Write the samples held in ``self.data`` to the output buffer.

    Every row of ``self.data`` is turned back into an instance of
    ``self.original_dataclass`` via ``dict_to_dataclass`` and the resulting
    list is passed to ``self.output_buffer.write``. Afterwards ``self.data``
    is reset to an empty dataset.

    The output buffer is not released here; see ``release_output_buffer``.
    """
    rows = self.data.to_list()
    experiences = [dict_to_dataclass(self.original_dataclass, row) for row in rows]
    self.output_buffer.write(experiences)
    logger.info(f"Wrote {len(self.data)} samples to output buffer")
    # Drop the written samples so they are not written a second time.
    self.data = Dataset.from_list([])
10399
def release_output_buffer(self):
    """Release the output buffer writer held in ``self.output_buffer``."""
    writer = self.output_buffer
    writer.release()
102+
def to_parquet(self, path: str):
    """Save the current contents of ``self.data`` as a Parquet file at ``path``."""
    dataset = self.data
    dataset.to_parquet(path)
106105
0 commit comments