
Commit 53d7829 (parent: 5489efd)

Add concat_then_split packing in Grain pipeline

File tree: 5 files changed (+49, -3 lines)

src/MaxText/configs/base.yml

Lines changed: 1 addition & 0 deletions
@@ -597,6 +597,7 @@ grain_train_files: ''
 grain_eval_files: ''
 grain_train_mixture_config_path: '' # Path to a JSON file specifying the mixture weights for Grain training data.
 grain_file_type: 'arrayrecord' # arrayrecord or parquet
+grain_packing_type: 'first_fit' # 'first_fit' or 'concat_then_split'. See details of the corresponding module in https://google-grain.readthedocs.io/en/latest/grain.experimental.html
 grain_worker_count: 1 # Set to -1 to enable auto-tuning: automatically determines optimal worker count. See https://google-grain.readthedocs.io/en/latest/_autosummary/grain.experimental.pick_performance_config.html
 grain_per_worker_buffer_size: 1
 # num_threads and prefetch_buffer_size are per-worker per-dataset.

src/MaxText/configs/types.py

Lines changed: 4 additions & 0 deletions
@@ -852,6 +852,10 @@ class DatasetGeneral(BaseModel):
       True,
       description="Whether to pack multiple short examples into a single sequence.",
   )
+  grain_packing_type: Literal["first_fit", "concat_then_split"] = Field(
+      "first_fit",
+      description="Packing type when using Grain pipeline. 'first_fit' or 'concat_then_split'.",
+  )
   max_segments_per_seq: int = Field(
       32,
       description="Maximum number of segments that can be packed into a single sequence.",

src/MaxText/input_pipeline/_grain_data_processing.py

Lines changed: 20 additions & 3 deletions
@@ -201,6 +201,8 @@ def pretrain_preprocessing_pipeline(
   if config.grain_file_type == "arrayrecord":
     dataset = dataset.map(_input_pipeline_utils.ParseFeatures(data_columns, tokenize))
     dataset = dataset.map(_input_pipeline_utils.NormalizeFeatures(data_columns, tokenize))
+  else:
+    dataset = dataset.map(_input_pipeline_utils.KeepFeatures(feature_names=data_columns))

   assert len(data_columns) == 1
   text_column = data_columns[0]
@@ -237,11 +239,26 @@ def pretrain_preprocessing_pipeline(
   # But when using Grain, we want to keep the batch_size consistent with that in the checkpoint.
   # We revert the batch_size expansion here, but load multiple batches per step in multihost_dataloading.py.
   batch_size = batch_size // config.expansion_factor_real_data
+
   if config.packing:
     length_struct = {col: config.max_target_length for col in data_columns}
-    dataset = grain.experimental.FirstFitPackIterDataset(
-        dataset, length_struct=length_struct, num_packing_bins=batch_size
-    )
+    if config.grain_packing_type == "first_fit":
+      dataset = grain.experimental.FirstFitPackIterDataset(
+          dataset, length_struct=length_struct, num_packing_bins=batch_size
+      )
+    elif config.grain_packing_type == "concat_then_split":
+      if config.add_bos and hasattr(tokenizer_model, "bos_id"):
+        dataset = grain.experimental.ConcatThenSplitIterDataset(
+            dataset,
+            length_struct=length_struct,
+            bos_handling=grain.experimental.BOSHandling.REPLACE_FIRST_TOKEN_WITH_BOS,
+            bos_token_id=tokenizer_model.bos_id,
+        )
+      else:
+        dataset = grain.experimental.ConcatThenSplitIterDataset(dataset, length_struct=length_struct)
+    else:
+      raise ValueError(f"Unknown packing type: {config.grain_packing_type}")
+
   rekey_dict = {
       "targets_segmentation": "targets_segment_ids",
       "inputs_segmentation": "inputs_segment_ids",

src/MaxText/input_pipeline/_input_pipeline_utils.py

Lines changed: 23 additions & 0 deletions
@@ -17,6 +17,7 @@
 import dataclasses
 import warnings
 from threading import current_thread
+from typing import Any
 import datasets
 from datasets.distributed import split_dataset_by_node
 import grain.python as grain
@@ -389,6 +390,28 @@ def map(self, element):
     return {col: element[col].numpy() for col in self.column_names}


+@dataclasses.dataclass
+class KeepFeatures(grain.MapTransform):
+  """Keep only specified features in the dataset element.
+
+  This transform filters the input dictionary, retaining only the keys
+  that are present in `feature_names`.
+  """
+
+  def __init__(self, feature_names: list[str]):
+    """Initializes the KeepFeatures transform.
+
+    Args:
+      feature_names: A list of strings, where each string is the name of a
+        feature to be kept in the dataset element.
+    """
+    self.feature_names = feature_names
+
+  def map(self, element: dict[str, Any]) -> dict[str, Any]:
+    """Applies the feature filtering to the input element."""
+    return {k: v for k, v in element.items() if k in self.feature_names}
+
+
 @dataclasses.dataclass
 class Rekey(grain.MapTransform):
   """Rename keys according to a mapping dict"""

tests/grain_data_processing_test.py

Lines changed: 1 addition & 0 deletions
@@ -198,6 +198,7 @@ def setUp(self):
         data_sharding=["data"],
         base_output_directory="gs://max-experiments/",
         dataset_type="grain",
+        grain_ram_budget_mb=512,
         grain_train_files=os.path.join(
             temp_dir, "gcsfuse", "array-record", "c4", "en", "3.0.1", "c4-train.array_record*"
         ),
