@@ -94,7 +94,14 @@ def get_datasets(
   return dataset


-def pretrain_preprocessing_pipeline(dataset, config, data_columns, tokenize, grain_worker_count):
+def pretrain_preprocessing_pipeline(
+    dataset,
+    config,
+    data_columns,
+    tokenize,
+    grain_worker_count,
+    grain_per_worker_buffer_size,
+):
   """Use grain pipeline to pre-process the dataset and return iterators for pretrain"""
   if config.grain_file_type == "arrayrecord":
     dataset = dataset.map(_input_pipeline_utils.ParseFeatures(data_columns, tokenize))
@@ -159,11 +166,23 @@ def pretrain_preprocessing_pipeline(dataset, config, data_columns, tokenize, gra
           axis=1,
       )
   )
-  dataset = dataset.mp_prefetch(grain.MultiprocessingOptions(num_workers=grain_worker_count))
+  dataset = dataset.mp_prefetch(
+      grain.MultiprocessingOptions(
+          num_workers=grain_worker_count,
+          per_worker_buffer_size=grain_per_worker_buffer_size,
+      )
+  )
   return dataset


-def dpo_preprocessing_pipeline(dataset, config, data_columns, tokenize, grain_worker_count):
+def dpo_preprocessing_pipeline(
+    dataset,
+    config,
+    data_columns,
+    tokenize,
+    grain_worker_count,
+    grain_per_worker_buffer_size,
+):
   """Use grain to pre-process the dataset and return iterators for dpo fine-tuning"""
   if config.grain_file_type == "arrayrecord":
     dataset = dataset.map(_input_pipeline_utils.ParseFeatures(data_columns, tokenize))
@@ -190,7 +209,12 @@ def dpo_preprocessing_pipeline(dataset, config, data_columns, tokenize, grain_wo
   batch_size = config.global_batch_size_to_load // jax.process_count()
   batch_fn = functools.partial(grain.experimental.batch_and_pad, batch_size=batch_size, pad_value=pad_id)
   dataset = dataset.batch(batch_size, batch_fn=batch_fn)
-  dataset = dataset.mp_prefetch(grain.MultiprocessingOptions(num_workers=grain_worker_count))
+  dataset = dataset.mp_prefetch(
+      grain.MultiprocessingOptions(
+          num_workers=grain_worker_count,
+          per_worker_buffer_size=grain_per_worker_buffer_size,
+      )
+  )
   return dataset

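The two `mp_prefetch` hunks above are the heart of the change: `grain.MultiprocessingOptions` now carries `per_worker_buffer_size` in addition to `num_workers`, controlling how many ready output elements each worker process keeps buffered ahead of the consumer, at the cost of extra host memory. A minimal, self-contained sketch of the pattern, with a toy `range` source standing in for the real ArrayRecord pipeline (values are illustrative assumptions, not this repo's defaults):

```python
# Minimal sketch of the mp_prefetch pattern above; assumes `grain` is
# installed. The toy source and numeric values are illustrative only.
import grain.python as grain


def double(x):
  # Top-level function so it pickles cleanly to worker processes.
  return x * 2


if __name__ == "__main__":
  dataset = (
      grain.MapDataset.range(1024)  # stand-in for the ArrayRecord source
      .map(double)
      .batch(8)
      .to_iter_dataset()
      .mp_prefetch(
          grain.MultiprocessingOptions(
              num_workers=4,
              # Each of the 4 workers keeps up to 2 ready batches buffered.
              per_worker_buffer_size=2,
          )
      )
  )
  for batch in dataset:
    pass  # a training step would consume `batch` here
```

With the old single-argument call, `per_worker_buffer_size` was left at grain's default (1 at the time of writing); exposing it lets pipelines with uneven per-batch cost smooth out input stalls.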
@@ -221,6 +245,7 @@ def make_grain_train_iterator(
         data_columns=config.train_data_columns,
         tokenize=config.tokenize_train_data,
         grain_worker_count=config.grain_worker_count,
+        grain_per_worker_buffer_size=config.grain_per_worker_buffer_size,
     )
   else:
     train_dataloader = pretrain_preprocessing_pipeline(
@@ -229,6 +254,7 @@ def make_grain_train_iterator(
         data_columns=config.train_data_columns,
         tokenize=config.tokenize_train_data,
         grain_worker_count=config.grain_worker_count,
+        grain_per_worker_buffer_size=config.grain_per_worker_buffer_size,
     )
   return multihost_dataloading.MultiHostDataLoadIterator(
       train_dataloader,
@@ -253,6 +279,7 @@ def make_grain_train_iterator(
         data_columns=config.train_data_columns,
         tokenize=config.tokenize_train_data,
         grain_worker_count=config.grain_worker_count,
+        grain_per_worker_buffer_size=config.grain_per_worker_buffer_size,
     )
   else:
     preprocessing_fn = functools.partial(
@@ -261,6 +288,7 @@ def make_grain_train_iterator(
         data_columns=config.train_data_columns,
         tokenize=config.tokenize_train_data,
         grain_worker_count=config.grain_worker_count,
+        grain_per_worker_buffer_size=config.grain_per_worker_buffer_size,
     )
   if config.colocated_python_data_input:
     global_shape = (config.global_batch_size_to_load, config.max_target_length)
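In the colocated-Python path the pipeline is not applied eagerly: `functools.partial` freezes all the config-derived keyword arguments, including the new buffer size, so the loader can later invoke the preprocessing with only the dataset. A hedged sketch of that deferred-application pattern, with a hypothetical stub in place of the real pipeline:

```python
# Deferred-application sketch: partial() binds config-derived kwargs now;
# the loader supplies just the dataset later. The stub is hypothetical.
import functools


def preprocessing_pipeline(dataset, data_columns, tokenize, grain_worker_count, grain_per_worker_buffer_size):
  print(f"workers={grain_worker_count} buffer={grain_per_worker_buffer_size}")
  return dataset


preprocessing_fn = functools.partial(
    preprocessing_pipeline,
    data_columns=("inputs",),
    tokenize=True,
    grain_worker_count=4,
    grain_per_worker_buffer_size=2,
)
processed = preprocessing_fn(["raw", "records"])  # only the dataset remains
```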
@@ -308,6 +336,7 @@ def make_grain_eval_iterator(
         data_columns=config.eval_data_columns,
         tokenize=config.tokenize_eval_data,
         grain_worker_count=config.grain_worker_count_eval,
+        grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval,
     )
   else:
     eval_dataloader = pretrain_preprocessing_pipeline(
@@ -316,6 +345,7 @@ def make_grain_eval_iterator(
         data_columns=config.eval_data_columns,
         tokenize=config.tokenize_eval_data,
         grain_worker_count=config.grain_worker_count_eval,
+        grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval,
     )
   return multihost_dataloading.MultiHostDataLoadIterator(
       eval_dataloader, global_mesh, config.generate_padding_batch_eval
@@ -337,6 +367,7 @@ def make_grain_eval_iterator(
         data_columns=config.eval_data_columns,
         tokenize=config.tokenize_eval_data,
         grain_worker_count=config.grain_worker_count_eval,
+        grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval,
     )
   else:
     preprocessing_fn = functools.partial(
@@ -345,6 +376,7 @@ def make_grain_eval_iterator(
         data_columns=config.eval_data_columns,
         tokenize=config.tokenize_eval_data,
         grain_worker_count=config.grain_worker_count_eval,
+        grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval,
     )
   global_shape = (config.global_batch_size_to_load, config.max_target_length)
   return multihost_dataloading.RemoteIterator(get_ds_fn, preprocessing_fn, global_mesh, global_shape)
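The eval path mirrors the train path but reads its own `grain_per_worker_buffer_size_eval` (alongside the existing `grain_worker_count_eval`), so eval prefetching can be tuned independently of training. When raising either knob, keep in mind that the buffers hold fully materialized batches per worker; a back-of-envelope bound on the host memory involved (assumed shapes and int32 tokens, not measured from this change):

```python
# Rough upper bound on host memory held by prefetch buffers on one host,
# assuming int32 token arrays. All values are illustrative assumptions.
num_workers = 4
per_worker_buffer_size = 2
per_host_batch = 32       # global_batch_size_to_load // process_count()
max_target_length = 2048
bytes_per_token = 4       # int32
columns = 4               # e.g. inputs, targets, segmentation, position

batch_bytes = per_host_batch * max_target_length * bytes_per_token * columns
total_bytes = num_workers * per_worker_buffer_size * batch_bytes
print(f"{total_bytes / 2**20:.0f} MiB buffered")  # 8 MiB for these values
```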