@@ -88,7 +88,14 @@ def get_datasets(
   return dataset


-def pretrain_preprocessing_pipeline(dataset, config, data_columns, tokenize, grain_worker_count):
+def pretrain_preprocessing_pipeline(
+    dataset,
+    config,
+    data_columns,
+    tokenize,
+    grain_worker_count,
+    grain_per_worker_buffer_size,
+):
   """Use grain pipeline to pre-process the dataset and return iterators for pretrain"""
   if config.grain_file_type == "arrayrecord":
     dataset = dataset.map(_input_pipeline_utils.ParseFeatures(data_columns, tokenize))
@@ -153,11 +160,23 @@ def pretrain_preprocessing_pipeline(dataset, config, data_columns, tokenize, gra
           axis=1,
       )
   )
-  dataset = dataset.mp_prefetch(grain.MultiprocessingOptions(num_workers=grain_worker_count))
+  dataset = dataset.mp_prefetch(
+      grain.MultiprocessingOptions(
+          num_workers=grain_worker_count,
+          per_worker_buffer_size=grain_per_worker_buffer_size,
+      )
+  )
   return dataset


-def dpo_preprocessing_pipeline(dataset, config, data_columns, tokenize, grain_worker_count):
+def dpo_preprocessing_pipeline(
+    dataset,
+    config,
+    data_columns,
+    tokenize,
+    grain_worker_count,
+    grain_per_worker_buffer_size,
+):
   """Use grain to pre-process the dataset and return iterators for dpo fine-tuning"""
   if config.grain_file_type == "arrayrecord":
     dataset = dataset.map(_input_pipeline_utils.ParseFeatures(data_columns, tokenize))
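Note: per_worker_buffer_size is a standing field of grain's MultiprocessingOptions; the hunk above only threads it through from the MaxText config. A minimal, self-contained sketch of the knob in isolation (the toy range source and doubling map are invented for illustration; MaxText's real sources are ArrayRecord or Parquet files):

import grain

def double(x):
  # Stand-in transform; MaxText's pipeline tokenizes, packs, and shifts here.
  return x * 2

ds = (
    grain.MapDataset.range(1_024)  # toy source, not MaxText's data loader
    .map(double)
    .to_iter_dataset()
    .mp_prefetch(
        grain.MultiprocessingOptions(
            num_workers=4,  # plays the role of config.grain_worker_count
            per_worker_buffer_size=2,  # plays the role of the new config knob
        )
    )
)

for element in ds:
  pass  # each prefetch worker keeps at most 2 ready elements buffered

Raising per_worker_buffer_size trades host memory for throughput: workers can run further ahead of the training step, at the cost of roughly num_workers * per_worker_buffer_size elements held in RAM.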
@@ -184,7 +203,12 @@ def dpo_preprocessing_pipeline(dataset, config, data_columns, tokenize, grain_wo
   batch_size = config.global_batch_size_to_load // jax.process_count()
   batch_fn = functools.partial(grain.experimental.batch_and_pad, batch_size=batch_size, pad_value=pad_id)
   dataset = dataset.batch(batch_size, batch_fn=batch_fn)
-  dataset = dataset.mp_prefetch(grain.MultiprocessingOptions(num_workers=grain_worker_count))
+  dataset = dataset.mp_prefetch(
+      grain.MultiprocessingOptions(
+          num_workers=grain_worker_count,
+          per_worker_buffer_size=grain_per_worker_buffer_size,
+      )
+  )
   return dataset


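The DPO hunk batches with grain.experimental.batch_and_pad before mp_prefetch. A hedged sketch of that pattern on toy data, assuming (as the pad_value argument suggests) that a short final batch is padded out to batch_size:

import functools
import numpy as np
import grain

pad_id = 0  # illustrative; MaxText derives the pad id from its tokenizer
batch_size = 4

batch_fn = functools.partial(grain.experimental.batch_and_pad, batch_size=batch_size, pad_value=pad_id)

# Six fixed-length rows -> one full batch of 4 plus a final batch padded from 2 to 4.
ds = grain.MapDataset.source([np.full(8, i, dtype=np.int32) for i in range(6)])
ds = ds.batch(batch_size, batch_fn=batch_fn)

Keeping every batch at the full global size is what lets these batches flow through mp_prefetch and a jitted train or eval step without shape-driven recompilation.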
@@ -215,6 +239,7 @@ def make_grain_train_iterator(
         data_columns=config.train_data_columns,
         tokenize=config.tokenize_train_data,
         grain_worker_count=config.grain_worker_count,
+        grain_per_worker_buffer_size=config.grain_per_worker_buffer_size,
     )
   else:
     train_dataloader = pretrain_preprocessing_pipeline(
@@ -223,6 +248,7 @@ def make_grain_train_iterator(
         data_columns=config.train_data_columns,
         tokenize=config.tokenize_train_data,
         grain_worker_count=config.grain_worker_count,
+        grain_per_worker_buffer_size=config.grain_per_worker_buffer_size,
     )
   return multihost_dataloading.MultiHostDataLoadIterator(
       train_dataloader,
@@ -247,6 +273,7 @@ def make_grain_train_iterator(
         data_columns=config.train_data_columns,
         tokenize=config.tokenize_train_data,
         grain_worker_count=config.grain_worker_count,
+        grain_per_worker_buffer_size=config.grain_per_worker_buffer_size,
     )
   else:
     preprocessing_fn = functools.partial(
@@ -255,6 +282,7 @@ def make_grain_train_iterator(
         data_columns=config.train_data_columns,
         tokenize=config.tokenize_train_data,
         grain_worker_count=config.grain_worker_count,
+        grain_per_worker_buffer_size=config.grain_per_worker_buffer_size,
     )
   if config.colocated_python_data_input:
     global_shape = (config.global_batch_size_to_load, config.max_target_length)
@@ -302,6 +330,7 @@ def make_grain_eval_iterator(
         data_columns=config.eval_data_columns,
         tokenize=config.tokenize_eval_data,
         grain_worker_count=config.grain_worker_count_eval,
+        grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval,
     )
   else:
     eval_dataloader = pretrain_preprocessing_pipeline(
@@ -310,6 +339,7 @@ def make_grain_eval_iterator(
         data_columns=config.eval_data_columns,
         tokenize=config.tokenize_eval_data,
         grain_worker_count=config.grain_worker_count_eval,
+        grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval,
     )
   return multihost_dataloading.MultiHostDataLoadIterator(
       eval_dataloader, global_mesh, config.generate_padding_batch_eval
@@ -331,6 +361,7 @@ def make_grain_eval_iterator(
         data_columns=config.eval_data_columns,
         tokenize=config.tokenize_eval_data,
         grain_worker_count=config.grain_worker_count_eval,
+        grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval,
     )
   else:
     preprocessing_fn = functools.partial(
@@ -339,6 +370,7 @@ def make_grain_eval_iterator(
         data_columns=config.eval_data_columns,
         tokenize=config.tokenize_eval_data,
         grain_worker_count=config.grain_worker_count_eval,
+        grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval,
     )
   global_shape = (config.global_batch_size_to_load, config.max_target_length)
   return multihost_dataloading.RemoteIterator(get_ds_fn, preprocessing_fn, global_mesh, global_shape)
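Taken together, the train and eval paths now read two separate buffer-size knobs. A small sketch of the mapping (SimpleNamespace stands in for MaxText's config object; only the attribute names come from the hunks above, the values are hypothetical):

import grain
from types import SimpleNamespace

config = SimpleNamespace(
    grain_worker_count=4,
    grain_per_worker_buffer_size=2,
    grain_worker_count_eval=1,
    grain_per_worker_buffer_size_eval=1,
)

train_opts = grain.MultiprocessingOptions(
    num_workers=config.grain_worker_count,
    per_worker_buffer_size=config.grain_per_worker_buffer_size,
)
eval_opts = grain.MultiprocessingOptions(
    num_workers=config.grain_worker_count_eval,
    per_worker_buffer_size=config.grain_per_worker_buffer_size_eval,
)

Keeping the eval knob independent is a reasonable design choice: eval typically runs with fewer workers, so its prefetch buffer can stay small without starving the training input pipeline.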