
Commit 42648ba

Add concat_then_split packing in Grain pipeline
1 parent 6ce880e commit 42648ba

5 files changed: +50 -3 lines changed

src/MaxText/configs/base.yml

Lines changed: 1 addition & 0 deletions
@@ -595,6 +595,7 @@ grain_train_files: ''
 grain_eval_files: ''
 grain_train_mixture_config_path: '' # Path to a JSON file specifying the mixture weights for Grain training data.
 grain_file_type: 'arrayrecord' # arrayrecord or parquet
+grain_packing_type: 'first_fit' # 'first_fit' or 'concat_then_split'. See details of the corresponding module in https://google-grain.readthedocs.io/en/latest/grain.experimental.html
 grain_worker_count: 1 # Set to -1 to enable auto-tuning: automatically determines optimal worker count. See https://google-grain.readthedocs.io/en/latest/_autosummary/grain.experimental.pick_performance_config.html
 grain_per_worker_buffer_size: 1
 # num_threads and prefetch_buffer_size are per-worker per-dataset. Used in ReadOptions (https://google-grain.readthedocs.io/en/latest/tutorials/data_loader_tutorial.html#per-worker-readoptions)

src/MaxText/configs/types.py

Lines changed: 4 additions & 0 deletions
@@ -803,6 +803,10 @@ class DatasetGeneral(BaseModel):
       True,
       description="Whether to pack multiple short examples into a single sequence.",
   )
+  grain_packing_type: Literal["first_fit", "concat_then_split"] = Field(
+      "first_fit",
+      description="Packing type when using Grain pipeline. 'first_fit' or 'concat_then_split'.",
+  )
   max_segments_per_seq: int = Field(
       32, description="Maximum number of segments that can be packed into a single sequence."
   )
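
The new field is a Literal-constrained string on the config model, so an unsupported value fails at config-validation time. A minimal standalone sketch of the same pattern, assuming pydantic v2 (the real DatasetGeneral model carries many more fields):

# Standalone sketch of the Literal-typed field (assumes pydantic v2).
from typing import Literal

from pydantic import BaseModel, Field, ValidationError


class DatasetGeneralSketch(BaseModel):
  """Toy stand-in for the DatasetGeneral config model."""

  grain_packing_type: Literal["first_fit", "concat_then_split"] = Field(
      "first_fit", description="Packing type when using Grain pipeline."
  )


print(DatasetGeneralSketch().grain_packing_type)  # first_fit (the default)
print(DatasetGeneralSketch(grain_packing_type="concat_then_split").grain_packing_type)

try:
  DatasetGeneralSketch(grain_packing_type="best_fit")  # not an allowed literal
except ValidationError as err:
  print("rejected:", err.errors()[0]["type"])  # literal_error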

src/MaxText/input_pipeline/_grain_data_processing.py

Lines changed: 20 additions & 3 deletions
@@ -171,6 +171,8 @@ def pretrain_preprocessing_pipeline(
   if config.grain_file_type == "arrayrecord":
     dataset = dataset.map(_input_pipeline_utils.ParseFeatures(data_columns, tokenize))
     dataset = dataset.map(_input_pipeline_utils.NormalizeFeatures(data_columns, tokenize))
+  else:
+    dataset = dataset.map(_input_pipeline_utils.KeepFeatures(feature_names=data_columns))
 
   assert len(data_columns) == 1
   text_column = data_columns[0]
@@ -207,11 +209,26 @@ pretrain_preprocessing_pipeline(
   # But when using Grain, we want to keep the batch_size consistent with that in the checkpoint.
   # We revert the batch_size expansion here, but load multiple batches per step in multihost_dataloading.py.
   batch_size = batch_size // config.expansion_factor_real_data
+
   if config.packing:
     length_struct = {col: config.max_target_length for col in data_columns}
-    dataset = grain.experimental.FirstFitPackIterDataset(
-        dataset, length_struct=length_struct, num_packing_bins=batch_size
-    )
+    if config.grain_packing_type == "first_fit":
+      dataset = grain.experimental.FirstFitPackIterDataset(
+          dataset, length_struct=length_struct, num_packing_bins=batch_size
+      )
+    elif config.grain_packing_type == "concat_then_split":
+      if config.add_bos and hasattr(tokenizer_model, "bos_id"):
+        dataset = grain.experimental.ConcatThenSplitIterDataset(
+            dataset,
+            length_struct=length_struct,
+            bos_handling=grain.experimental.BOSHandling.REPLACE_FIRST_TOKEN_WITH_BOS,
+            bos_token_id=tokenizer_model.bos_id,
+        )
+      else:
+        dataset = grain.experimental.ConcatThenSplitIterDataset(dataset, length_struct=length_struct)
+    else:
+      raise ValueError(f"Unknown packing type: {config.grain_packing_type}")
+
 
   rekey_dict = {
       "targets_segmentation": "targets_segment_ids",
       "inputs_segmentation": "inputs_segment_ids",

src/MaxText/input_pipeline/_input_pipeline_utils.py

Lines changed: 24 additions & 0 deletions
@@ -17,6 +17,7 @@
 import dataclasses
 import warnings
 from threading import current_thread
+from typing import Any
 import datasets
 from datasets.distributed import split_dataset_by_node
 import grain.python as grain
@@ -389,6 +390,29 @@ def map(self, element):
     return {col: element[col].numpy() for col in self.column_names}
 
 
+@dataclasses.dataclass
+class KeepFeatures(grain.MapTransform):
+  """Keep only specified features in the dataset element.
+
+  This transform filters the input dictionary, retaining only the keys
+  that are present in `feature_names`.
+  """
+
+  def __init__(self, feature_names: list[str]):
+    """Initializes the KeepFeatures transform.
+
+    Args:
+      feature_names: A list of strings, where each string is the name of a
+        feature to be kept in the dataset element.
+    """
+    self.feature_names = feature_names
+
+  def map(self, element: dict[str, Any]) -> dict[str, Any]:
+    """Applies the feature filtering to the input element.
+    """
+    return {k: v for k, v in element.items() if k in self.feature_names}
+
+
 @dataclasses.dataclass
 class Rekey(grain.MapTransform):
   """Rename keys according to a mapping dict"""

tests/grain_data_processing_test.py

Lines changed: 1 addition & 0 deletions
@@ -198,6 +198,7 @@ def setUp(self):
         data_sharding=["data"],
         base_output_directory="gs://max-experiments/",
         dataset_type="grain",
+        grain_ram_budget_mb=512,
         grain_train_files=os.path.join(
             temp_dir, "gcsfuse", "array-record", "c4", "en", "3.0.1", "c4-train.array_record*"
         ),
