
Commit 17ef77a

Add concat_then_split packing in Grain pipeline
1 parent 6ce880e commit 17ef77a

4 files changed (+35, -3 lines)


src/MaxText/configs/base.yml

Lines changed: 1 addition & 0 deletions
@@ -595,6 +595,7 @@ grain_train_files: ''
 grain_eval_files: ''
 grain_train_mixture_config_path: '' # Path to a JSON file specifying the mixture weights for Grain training data.
 grain_file_type: 'arrayrecord' # arrayrecord or parquet
+grain_packing_type: 'first_fit' # 'first_fit' or 'concat_then_split'
 grain_worker_count: 1 # Set to -1 to enable auto-tuning: automatically determines optimal worker count. See https://google-grain.readthedocs.io/en/latest/_autosummary/grain.experimental.pick_performance_config.html
 grain_per_worker_buffer_size: 1
 # num_threads and prefetch_buffer_size are per-worker per-dataset. Used in ReadOptions (https://google-grain.readthedocs.io/en/latest/tutorials/data_loader_tutorial.html#per-worker-readoptions)
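
With this key in base.yml, the packing strategy can be switched per run by overriding it on the launch command like any other config key. A hedged example; the entry-point path and the other flags shown depend on your MaxText version and install layout:

python3 -m MaxText.train src/MaxText/configs/base.yml run_name=my_run dataset_type=grain grain_packing_type=concat_then_split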

src/MaxText/configs/types.py

Lines changed: 4 additions & 0 deletions
@@ -803,6 +803,10 @@ class DatasetGeneral(BaseModel):
       True,
       description="Whether to pack multiple short examples into a single sequence.",
   )
+  grain_packing_type: Literal["first_fit", "concat_then_split"] = Field(
+      "first_fit",
+      description="Packing type when using the Grain pipeline: 'first_fit' or 'concat_then_split'.",
+  )
   max_segments_per_seq: int = Field(
       32, description="Maximum number of segments that can be packed into a single sequence."
   )
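
Because the new key is a Literal-typed pydantic field, a typo in the packing type fails fast at config-validation time instead of surfacing deep inside the Grain pipeline. A minimal sketch with a stand-in model, not MaxText's actual DatasetGeneral class:

from typing import Literal

from pydantic import BaseModel, Field, ValidationError


class PackingConfig(BaseModel):
  # Stand-in for the field added to DatasetGeneral above.
  grain_packing_type: Literal["first_fit", "concat_then_split"] = Field(
      "first_fit",
      description="Packing type when using the Grain pipeline: 'first_fit' or 'concat_then_split'.",
  )


print(PackingConfig().grain_packing_type)  # first_fit (the default)
print(PackingConfig(grain_packing_type="concat_then_split").grain_packing_type)

try:
  PackingConfig(grain_packing_type="greedy")  # not one of the allowed literals
except ValidationError as err:
  print("rejected:", err.errors()[0]["msg"])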

src/MaxText/input_pipeline/_grain_data_processing.py

Lines changed: 20 additions & 3 deletions
@@ -171,6 +171,8 @@ def pretrain_preprocessing_pipeline(
   if config.grain_file_type == "arrayrecord":
     dataset = dataset.map(_input_pipeline_utils.ParseFeatures(data_columns, tokenize))
     dataset = dataset.map(_input_pipeline_utils.NormalizeFeatures(data_columns, tokenize))
+  else:
+    dataset = dataset.map(_input_pipeline_utils.KeepFeatures(feature_names=data_columns))

   assert len(data_columns) == 1
   text_column = data_columns[0]
@@ -207,11 +209,26 @@ def pretrain_preprocessing_pipeline(
   # But when using Grain, we want to keep the batch_size consistent with that in the checkpoint.
   # We revert the batch_size expansion here, but load multiple batches per step in multihost_dataloading.py.
   batch_size = batch_size // config.expansion_factor_real_data
+
   if config.packing:
     length_struct = {col: config.max_target_length for col in data_columns}
-    dataset = grain.experimental.FirstFitPackIterDataset(
-        dataset, length_struct=length_struct, num_packing_bins=batch_size
-    )
+    if config.grain_packing_type == "first_fit":
+      dataset = grain.experimental.FirstFitPackIterDataset(
+          dataset, length_struct=length_struct, num_packing_bins=batch_size
+      )
+    elif config.grain_packing_type == "concat_then_split":
+      if config.add_bos and hasattr(tokenizer_model, "bos_id"):
+        dataset = grain.experimental.ConcatThenSplitIterDataset(
+            dataset,
+            length_struct=length_struct,
+            bos_handling=grain.experimental.BOSHandling.REPLACE_FIRST_TOKEN_WITH_BOS,
+            bos_token_id=tokenizer_model.bos_id,
+        )
+      else:
+        dataset = grain.experimental.ConcatThenSplitIterDataset(dataset, length_struct=length_struct)
+    else:
+      raise ValueError(f"Unknown grain_packing_type: {config.grain_packing_type}")
+
   rekey_dict = {
       "targets_segmentation": "targets_segment_ids",
       "inputs_segmentation": "inputs_segment_ids",

src/MaxText/input_pipeline/_input_pipeline_utils.py

Lines changed: 10 additions & 0 deletions
@@ -389,6 +389,16 @@ def map(self, element):
     return {col: element[col].numpy() for col in self.column_names}


+@dataclasses.dataclass
+class KeepFeatures(grain.MapTransform):
+  """Keep only the specified feature columns, dropping all others."""
+  def __init__(self, feature_names):
+    self.feature_names = feature_names
+
+  def map(self, element):
+    return {k: v for k, v in element.items() if k in self.feature_names}
+
+
 @dataclasses.dataclass
 class Rekey(grain.MapTransform):
   """Rename keys according to a mapping dict"""
