
Commit 00ec821 (1 parent: 78fbeca)

Support Grain per_worker_buffer_size

File tree: 3 files changed, +40 −5 lines


src/MaxText/configs/base.yml

Lines changed: 3 additions & 1 deletion
@@ -572,12 +572,14 @@ hf_access_token: ''
 # For multiple patterns, use semicolon (;) to separate and colon (:) to specify weights.
 # Example: "path/to/data1.array_record*:0.3;path/to/data2.array_record*:0.7"
 # Note: When using multiple files (separated by ';'), only ArrayRecord format is supported.
-# For more details, see https://github.com/google/maxtext/blob/main/getting_started/Data_Input_Pipeline.md#grain-input-pipeline
+# For more details, see https://github.com/AI-Hypercomputer/maxtext/blob/main/docs/guides/data_input_pipeline/data_input_grain.md
 grain_train_files: ''
 grain_eval_files: ''
 grain_file_type: 'arrayrecord' # arrayrecord or parquet
 grain_worker_count: 1
+grain_per_worker_buffer_size: 1
 grain_worker_count_eval: 1
+grain_per_worker_buffer_size_eval: 1
 # for using pathways
 colocated_python_data_input: False # experimental feature, under testing
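Note: taken together, grain_worker_count and grain_per_worker_buffer_size bound how much preprocessed data the input pipeline holds ahead of the training step. A rough back-of-the-envelope sketch, assuming per_worker_buffer_size counts buffered output batches per worker process (the values below are illustrative, not defaults from this commit):

# Illustrative sketch: peak host-side buffering of the prefetch pool.
# Assumes each worker keeps up to `per_worker_buffer_size` ready batches.
grain_worker_count = 4
grain_per_worker_buffer_size = 2
peak_buffered_batches = grain_worker_count * grain_per_worker_buffer_size
print(peak_buffered_batches)  # 8 batches resident in host memory at peak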

src/MaxText/input_pipeline/_grain_data_processing.py

Lines changed: 36 additions & 4 deletions
@@ -88,7 +88,14 @@ def get_datasets(
   return dataset
 
 
-def pretrain_preprocessing_pipeline(dataset, config, data_columns, tokenize, grain_worker_count):
+def pretrain_preprocessing_pipeline(
+    dataset,
+    config,
+    data_columns,
+    tokenize,
+    grain_worker_count,
+    grain_per_worker_buffer_size,
+):
   """Use grain pipeline to pre-process the dataset and return iterators for pretrain"""
   if config.grain_file_type == "arrayrecord":
     dataset = dataset.map(_input_pipeline_utils.ParseFeatures(data_columns, tokenize))
@@ -153,11 +160,23 @@ def pretrain_preprocessing_pipeline(
           axis=1,
       )
   )
-  dataset = dataset.mp_prefetch(grain.MultiprocessingOptions(num_workers=grain_worker_count))
+  dataset = dataset.mp_prefetch(
+      grain.MultiprocessingOptions(
+          num_workers=grain_worker_count,
+          per_worker_buffer_size=grain_per_worker_buffer_size,
+      )
+  )
   return dataset
 
 
-def dpo_preprocessing_pipeline(dataset, config, data_columns, tokenize, grain_worker_count):
+def dpo_preprocessing_pipeline(
+    dataset,
+    config,
+    data_columns,
+    tokenize,
+    grain_worker_count,
+    grain_per_worker_buffer_size,
+):
   """Use grain to pre-process the dataset and return iterators for dpo fine-tuning"""
   if config.grain_file_type == "arrayrecord":
     dataset = dataset.map(_input_pipeline_utils.ParseFeatures(data_columns, tokenize))
@@ -184,7 +203,12 @@ def dpo_preprocessing_pipeline(
   batch_size = config.global_batch_size_to_load // jax.process_count()
   batch_fn = functools.partial(grain.experimental.batch_and_pad, batch_size=batch_size, pad_value=pad_id)
   dataset = dataset.batch(batch_size, batch_fn=batch_fn)
-  dataset = dataset.mp_prefetch(grain.MultiprocessingOptions(num_workers=grain_worker_count))
+  dataset = dataset.mp_prefetch(
+      grain.MultiprocessingOptions(
+          num_workers=grain_worker_count,
+          per_worker_buffer_size=grain_per_worker_buffer_size,
+      )
+  )
   return dataset
 
 
@@ -215,6 +239,7 @@ def make_grain_train_iterator(
         data_columns=config.train_data_columns,
         tokenize=config.tokenize_train_data,
         grain_worker_count=config.grain_worker_count,
+        grain_per_worker_buffer_size=config.grain_per_worker_buffer_size,
     )
   else:
     train_dataloader = pretrain_preprocessing_pipeline(
@@ -223,6 +248,7 @@ def make_grain_train_iterator(
         data_columns=config.train_data_columns,
         tokenize=config.tokenize_train_data,
         grain_worker_count=config.grain_worker_count,
+        grain_per_worker_buffer_size=config.grain_per_worker_buffer_size,
     )
   return multihost_dataloading.MultiHostDataLoadIterator(
       train_dataloader,
@@ -247,6 +273,7 @@ def make_grain_train_iterator(
         data_columns=config.train_data_columns,
         tokenize=config.tokenize_train_data,
         grain_worker_count=config.grain_worker_count,
+        grain_per_worker_buffer_size=config.grain_per_worker_buffer_size,
     )
   else:
     preprocessing_fn = functools.partial(
@@ -255,6 +282,7 @@ def make_grain_train_iterator(
        data_columns=config.train_data_columns,
        tokenize=config.tokenize_train_data,
        grain_worker_count=config.grain_worker_count,
+       grain_per_worker_buffer_size=config.grain_per_worker_buffer_size,
     )
   if config.colocated_python_data_input:
     global_shape = (config.global_batch_size_to_load, config.max_target_length)
@@ -302,6 +330,7 @@ def make_grain_eval_iterator(
         data_columns=config.eval_data_columns,
         tokenize=config.tokenize_eval_data,
         grain_worker_count=config.grain_worker_count_eval,
+        grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval,
     )
   else:
     eval_dataloader = pretrain_preprocessing_pipeline(
@@ -310,6 +339,7 @@ def make_grain_eval_iterator(
         data_columns=config.eval_data_columns,
         tokenize=config.tokenize_eval_data,
         grain_worker_count=config.grain_worker_count_eval,
+        grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval,
     )
   return multihost_dataloading.MultiHostDataLoadIterator(
       eval_dataloader, global_mesh, config.generate_padding_batch_eval
@@ -331,6 +361,7 @@ def make_grain_eval_iterator(
         data_columns=config.eval_data_columns,
         tokenize=config.tokenize_eval_data,
         grain_worker_count=config.grain_worker_count_eval,
+        grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval,
     )
   else:
     preprocessing_fn = functools.partial(
@@ -339,6 +370,7 @@ def make_grain_eval_iterator(
        data_columns=config.eval_data_columns,
        tokenize=config.tokenize_eval_data,
        grain_worker_count=config.grain_worker_count_eval,
+       grain_per_worker_buffer_size=config.grain_per_worker_buffer_size_eval,
     )
   global_shape = (config.global_batch_size_to_load, config.max_target_length)
   return multihost_dataloading.RemoteIterator(get_ds_fn, preprocessing_fn, global_mesh, global_shape)
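Both pipelines rely on the same mechanism: mp_prefetch fans preprocessing out to num_workers child processes, and per_worker_buffer_size caps how many ready output elements each worker keeps queued for the consumer. A minimal standalone sketch of that pattern, assuming the open-source grain package; the toy list source and the chosen values are illustrative, not taken from this commit:

import grain

def main():
  # Toy in-memory source standing in for the ArrayRecord/Parquet data
  # the real pipeline reads; a list of ints is enough to exercise it.
  dataset = grain.MapDataset.source(list(range(1024))).batch(32).to_iter_dataset()
  # Two workers, each buffering up to four ready batches: at peak roughly
  # eight batches sit in host memory waiting for the consumer.
  dataset = dataset.mp_prefetch(
      grain.MultiprocessingOptions(num_workers=2, per_worker_buffer_size=4)
  )
  for batch in dataset:
    _ = batch  # real code would feed this to the train step

if __name__ == "__main__":  # guard needed: workers are spawned subprocesses
  main()

Raising the buffer size smooths over variance in per-batch preprocessing time at the cost of host memory; the config default of 1 appears to match Grain's own MultiprocessingOptions default, so existing jobs keep their previous behavior.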

tests/grain_data_processing_test.py

Lines changed: 1 addition & 0 deletions
@@ -165,6 +165,7 @@ def setUp(self):
         grain_file_type="parquet",
         grain_train_files=os.path.join(temp_dir, "gcsfuse", "hf", "c4", "c4-train-00000-of-01637.parquet"),
         grain_worker_count=1,
+        grain_per_worker_buffer_size=1,
         tokenizer_path=os.path.join(MAXTEXT_ASSETS_ROOT, "tokenizer"),
         enable_checkpointing=False,
     )
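For tests that build their config elsewhere, the new key can be passed the same way as grain_worker_count. A hypothetical sketch, assuming MaxText's pyconfig.initialize(argv, **overrides) entry point and a checkout-relative config path (both are assumptions, not part of this diff):

from MaxText import pyconfig

# Hypothetical: thread the new knob through a test config.
config = pyconfig.initialize(
    ["", "src/MaxText/configs/base.yml"],
    dataset_type="grain",
    grain_file_type="parquet",
    grain_worker_count=1,
    grain_per_worker_buffer_size=1,
    enable_checkpointing=False,
)
assert config.grain_per_worker_buffer_size == 1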
