Commit af70149

Merge pull request #2576 from AI-Hypercomputer:bernardhan/grain-worker-buffer-size
PiperOrigin-RevId: 832060301
2 parents 59b37eb + 00ec821 commit af70149

3 files changed (+40, -5 lines)

src/MaxText/configs/base.yml

Lines changed: 3 additions & 1 deletion

@@ -572,12 +572,14 @@ hf_access_token: ''
 # For multiple patterns, use semicolon (;) to separate and comma (,) to specify weights.
 # Example: "path/to/data1.array_record*,0.3;path/to/data2.array_record*,0.7"
 # Note: When using multiple files (separated by ';'), only ArrayRecord format is supported.
-# For more details, see https://github.com/google/maxtext/blob/main/getting_started/Data_Input_Pipeline.md#grain-input-pipeline
+# For more details, see https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/guides/data_input_pipeline/data_input_grain.md
 grain_train_files: ''
 grain_eval_files: ''
 grain_file_type: 'arrayrecord' # arrayrecord or parquet
 grain_worker_count: 1
+grain_per_worker_buffer_size: 1
 grain_worker_count_eval: 1
+grain_per_worker_buffer_size_eval: 1
 # for using pathways
 colocated_python_data_input: False # experimental feature, under testing
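The two new keys mirror the existing worker-count pair: grain_per_worker_buffer_size applies to the training pipeline and grain_per_worker_buffer_size_eval to evaluation. Each Grain prefetch worker keeps roughly that many preprocessed elements queued, so host memory held by prefetch grows with the product of worker count and buffer size. A rough sizing sketch (the helper and the byte figures below are illustrative assumptions, not MaxText code):

def prefetch_buffer_bytes(num_workers: int, per_worker_buffer_size: int, bytes_per_element: int) -> int:
  # Illustrative estimate: each worker keeps ~per_worker_buffer_size
  # preprocessed elements ready, so buffer memory scales multiplicatively.
  return num_workers * per_worker_buffer_size * bytes_per_element

# e.g. 4 workers x buffer of 2 x 8 MiB packed batches ~= 64 MiB of host RAM
print(prefetch_buffer_bytes(4, 2, 8 * 2**20) // 2**20, "MiB")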

src/MaxText/input_pipeline/_grain_data_processing.py

Lines changed: 36 additions & 4 deletions

@@ -94,7 +94,14 @@ def get_datasets(
   return dataset


-def pretrain_preprocessing_pipeline(dataset, config, data_columns, tokenize, grain_worker_count):
+def pretrain_preprocessing_pipeline(
+    dataset,
+    config,
+    data_columns,
+    tokenize,
+    grain_worker_count,
+    grain_per_worker_buffer_size,
+):
   """Use grain pipeline to pre-process the dataset and return iterators for pretrain"""
   if config.grain_file_type == "arrayrecord":
     dataset = dataset.map(_input_pipeline_utils.ParseFeatures(data_columns, tokenize))

@@ -159,11 +166,23 @@ def pretrain_preprocessing_pipeline(dataset, config, data_columns, tokenize, gra
           axis=1,
       )
   )
-  dataset = dataset.mp_prefetch(grain.MultiprocessingOptions(num_workers=grain_worker_count))
+  dataset = dataset.mp_prefetch(
+      grain.MultiprocessingOptions(
+          num_workers=grain_worker_count,
+          per_worker_buffer_size=grain_per_worker_buffer_size,
+      )
+  )
   return dataset


-def dpo_preprocessing_pipeline(dataset, config, data_columns, tokenize, grain_worker_count):
+def dpo_preprocessing_pipeline(
+    dataset,
+    config,
+    data_columns,
+    tokenize,
+    grain_worker_count,
+    grain_per_worker_buffer_size,
+):
   """Use grain to pre-process the dataset and return iterators for dpo fine-tuning"""
   if config.grain_file_type == "arrayrecord":
     dataset = dataset.map(_input_pipeline_utils.ParseFeatures(data_columns, tokenize))

@@ -190,7 +209,12 @@ def dpo_preprocessing_pipeline(dataset, config, data_columns, tokenize, grain_wo
   batch_size = config.global_batch_size_to_load // jax.process_count()
   batch_fn = functools.partial(grain.experimental.batch_and_pad, batch_size=batch_size, pad_value=pad_id)
   dataset = dataset.batch(batch_size, batch_fn=batch_fn)
-  dataset = dataset.mp_prefetch(grain.MultiprocessingOptions(num_workers=grain_worker_count))
+  dataset = dataset.mp_prefetch(
+      grain.MultiprocessingOptions(
+          num_workers=grain_worker_count,
+          per_worker_buffer_size=grain_per_worker_buffer_size,
+      )
+  )
   return dataset


@@ -221,6 +245,7 @@ def make_grain_train_iterator(
         data_columns=config.train_data_columns,
         tokenize=config.tokenize_train_data,
         grain_worker_count=config.grain_worker_count,
+        grain_per_worker_buffer_size=config.grain_per_worker_buffer_size,
     )
   else:
     train_dataloader = pretrain_preprocessing_pipeline(

@@ -229,6 +254,7 @@ def make_grain_train_iterator(
         data_columns=config.train_data_columns,
         tokenize=config.tokenize_train_data,
         grain_worker_count=config.grain_worker_count,
+        grain_per_worker_buffer_size=config.grain_per_worker_buffer_size,
     )
   return multihost_dataloading.MultiHostDataLoadIterator(
       train_dataloader,

@@ -253,6 +279,7 @@ def make_grain_train_iterator(
         data_columns=config.train_data_columns,
         tokenize=config.tokenize_train_data,
         grain_worker_count=config.grain_worker_count,
+        grain_per_worker_buffer_size=config.grain_per_worker_buffer_size,
     )
   else:
     preprocessing_fn = functools.partial(

@@ -261,6 +288,7 @@ def make_grain_train_iterator(
         data_columns=config.train_data_columns,
         tokenize=config.tokenize_train_data,
         grain_worker_count=config.grain_worker_count,
+        grain_per_worker_buffer_size=config.grain_per_worker_buffer_size,
     )
   if config.colocated_python_data_input:
     global_shape = (config.global_batch_size_to_load, config.max_target_length)

@@ -308,6 +336,7 @@ def make_grain_eval_iterator(
         data_columns=config.eval_data_columns,
         tokenize=config.tokenize_eval_data,
         grain_worker_count=config.grain_worker_count_eval,
+        grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval,
     )
   else:
     eval_dataloader = pretrain_preprocessing_pipeline(

@@ -316,6 +345,7 @@ def make_grain_eval_iterator(
         data_columns=config.eval_data_columns,
         tokenize=config.tokenize_eval_data,
         grain_worker_count=config.grain_worker_count_eval,
+        grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval,
     )
   return multihost_dataloading.MultiHostDataLoadIterator(
       eval_dataloader, global_mesh, config.generate_padding_batch_eval

@@ -337,6 +367,7 @@ def make_grain_eval_iterator(
         data_columns=config.eval_data_columns,
         tokenize=config.tokenize_eval_data,
         grain_worker_count=config.grain_worker_count_eval,
+        grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval,
     )
   else:
     preprocessing_fn = functools.partial(

@@ -345,6 +376,7 @@ def make_grain_eval_iterator(
         data_columns=config.eval_data_columns,
         tokenize=config.tokenize_eval_data,
         grain_worker_count=config.grain_worker_count_eval,
+        grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval,
     )
   global_shape = (config.global_batch_size_to_load, config.max_target_length)
   return multihost_dataloading.RemoteIterator(get_ds_fn, preprocessing_fn, global_mesh, global_shape)
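For context on the knob being threaded through: per_worker_buffer_size is a field of Grain's MultiprocessingOptions that bounds how many output elements each prefetch worker keeps queued for the consumer. A minimal, self-contained sketch of the same mp_prefetch pattern outside MaxText (assuming a recent grain release whose dataset API exposes MapDataset, to_iter_dataset, and mp_prefetch; the toy source and map function are placeholders):

import grain.python as grain

def _square(x):
  # Stand-in for real preprocessing such as tokenization or packing;
  # a module-level function so worker processes can pickle it.
  return x * x

ds = (
    grain.MapDataset.range(1024)  # toy in-memory source
    .map(_square)
    .to_iter_dataset()
    .mp_prefetch(
        grain.MultiprocessingOptions(
            num_workers=4,  # parallel preprocessing processes
            per_worker_buffer_size=2,  # elements each worker keeps queued
        )
    )
)

for element in ds:  # workers produce elements ahead of the consumer
  pass

A larger buffer can absorb preprocessing latency spikes at the cost of host memory, which is presumably why the commit exposes it per pipeline (train vs. eval) alongside the existing worker counts.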

tests/grain_data_processing_test.py

Lines changed: 1 addition & 0 deletions
@@ -165,6 +165,7 @@ def setUp(self):
         grain_file_type="parquet",
         grain_train_files=os.path.join(temp_dir, "gcsfuse", "hf", "c4", "c4-train-00000-of-01637.parquet"),
         grain_worker_count=1,
+        grain_per_worker_buffer_size=1,
         tokenizer_path=os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer"),
         enable_checkpointing=False,
     )
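To confirm the installed Grain version accepts the option before bumping the config, a quick standalone check can help (hypothetical snippet, not part of the test suite):

import grain.python as grain

# Older grain releases may not accept this keyword; a TypeError here
# means the installed version predates per-worker buffer control.
opts = grain.MultiprocessingOptions(num_workers=1, per_worker_buffer_size=1)
assert opts.per_worker_buffer_size == 1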
