diff --git a/src/MaxText/configs/base.yml b/src/MaxText/configs/base.yml
index de05461d9..776b81644 100644
--- a/src/MaxText/configs/base.yml
+++ b/src/MaxText/configs/base.yml
@@ -597,6 +597,7 @@ grain_train_files: ''
 grain_eval_files: ''
 grain_train_mixture_config_path: '' # Path to a JSON file specifying the mixture weights for Grain training data.
 grain_file_type: 'arrayrecord' # arrayrecord or parquet
+grain_packing_type: 'first_fit' # 'first_fit' or 'concat_then_split'. See the corresponding packing classes in https://google-grain.readthedocs.io/en/latest/grain.experimental.html
 grain_worker_count: 1 # Set to -1 to enable auto-tuning: automatically determines optimal worker count. See https://google-grain.readthedocs.io/en/latest/_autosummary/grain.experimental.pick_performance_config.html
 grain_per_worker_buffer_size: 1
 # num_threads and prefetch_buffer_size are per-worker per-dataset.
diff --git a/src/MaxText/configs/types.py b/src/MaxText/configs/types.py
index 2a9f26098..50e1633fc 100644
--- a/src/MaxText/configs/types.py
+++ b/src/MaxText/configs/types.py
@@ -852,6 +852,10 @@ class DatasetGeneral(BaseModel):
       True,
       description="Whether to pack multiple short examples into a single sequence.",
   )
+  grain_packing_type: Literal["first_fit", "concat_then_split"] = Field(
+      "first_fit",
+      description="Packing strategy when using the Grain pipeline: 'first_fit' or 'concat_then_split'.",
+  )
   max_segments_per_seq: int = Field(
       32,
       description="Maximum number of segments that can be packed into a single sequence.",
diff --git a/src/MaxText/input_pipeline/_grain_data_processing.py b/src/MaxText/input_pipeline/_grain_data_processing.py
index 33ac30396..50c840603 100644
--- a/src/MaxText/input_pipeline/_grain_data_processing.py
+++ b/src/MaxText/input_pipeline/_grain_data_processing.py
@@ -201,6 +201,8 @@ def pretrain_preprocessing_pipeline(
   if config.grain_file_type == "arrayrecord":
     dataset = dataset.map(_input_pipeline_utils.ParseFeatures(data_columns, tokenize))
     dataset = dataset.map(_input_pipeline_utils.NormalizeFeatures(data_columns, tokenize))
+  else:
+    dataset = dataset.map(_input_pipeline_utils.KeepFeatures(feature_names=data_columns))
 
   assert len(data_columns) == 1
   text_column = data_columns[0]
@@ -237,11 +239,26 @@
   # But when using Grain, we want to keep the batch_size consistent with that in the checkpoint.
   # We revert the batch_size expansion here, but load multiple batches per step in multihost_dataloading.py.
   batch_size = batch_size // config.expansion_factor_real_data
+
   if config.packing:
     length_struct = {col: config.max_target_length for col in data_columns}
-    dataset = grain.experimental.FirstFitPackIterDataset(
-        dataset, length_struct=length_struct, num_packing_bins=batch_size
-    )
+    if config.grain_packing_type == "first_fit":
+      dataset = grain.experimental.FirstFitPackIterDataset(
+          dataset, length_struct=length_struct, num_packing_bins=batch_size
+      )
+    elif config.grain_packing_type == "concat_then_split":
+      if config.add_bos and hasattr(tokenizer_model, "bos_id"):
+        dataset = grain.experimental.ConcatThenSplitIterDataset(
+            dataset,
+            length_struct=length_struct,
+            bos_handling=grain.experimental.BOSHandling.REPLACE_FIRST_TOKEN_WITH_BOS,
+            bos_token_id=tokenizer_model.bos_id,
+        )
+      else:
+        dataset = grain.experimental.ConcatThenSplitIterDataset(dataset, length_struct=length_struct)
+    else:
+      raise ValueError(f"Unknown grain_packing_type: {config.grain_packing_type}")
+
   rekey_dict = {
       "targets_segmentation": "targets_segment_ids",
       "inputs_segmentation": "inputs_segment_ids",
diff --git a/src/MaxText/input_pipeline/_input_pipeline_utils.py b/src/MaxText/input_pipeline/_input_pipeline_utils.py
index 195a56b0a..3fad3e1a7 100644
--- a/src/MaxText/input_pipeline/_input_pipeline_utils.py
+++ b/src/MaxText/input_pipeline/_input_pipeline_utils.py
@@ -17,6 +17,7 @@
 import dataclasses
 import warnings
 from threading import current_thread
+from typing import Any
 import datasets
 from datasets.distributed import split_dataset_by_node
 import grain.python as grain
@@ -389,6 +390,28 @@ def map(self, element):
     return {col: element[col].numpy() for col in self.column_names}
 
 
+@dataclasses.dataclass
+class KeepFeatures(grain.MapTransform):
+  """Keep only specified features in the dataset element.
+
+  This transform filters the input dictionary, retaining only the keys
+  that are present in `feature_names`.
+  """
+
+  def __init__(self, feature_names: list[str]):
+    """Initializes the KeepFeatures transform.
+
+    Args:
+      feature_names: A list of strings, where each string is the name of a
+        feature to be kept in the dataset element.
+    """
+    self.feature_names = feature_names
+
+  def map(self, element: dict[str, Any]) -> dict[str, Any]:
+    """Applies the feature filtering to the input element."""
+    return {k: v for k, v in element.items() if k in self.feature_names}
+
+
 @dataclasses.dataclass
 class Rekey(grain.MapTransform):
   """Rename keys according to a mapping dict"""
diff --git a/tests/grain_data_processing_test.py b/tests/grain_data_processing_test.py
index 1ae759851..6d5c44b50 100644
--- a/tests/grain_data_processing_test.py
+++ b/tests/grain_data_processing_test.py
@@ -198,6 +198,7 @@ def setUp(self):
         data_sharding=["data"],
         base_output_directory="gs://max-experiments/",
         dataset_type="grain",
+        grain_ram_budget_mb=512,
         grain_train_files=os.path.join(
            temp_dir, "gcsfuse", "array-record", "c4", "en", "3.0.1", "c4-train.array_record*"
        ),
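
The sketch below is a standalone illustration, not part of the patch, of how the new `grain_packing_type` switch chooses between Grain's two packing transformations. The helper name `make_packed_dataset` and its parameters (`packing_type`, `bos_id`, etc.) are hypothetical simplifications of the config fields used in the hunk above; the Grain classes and keyword arguments mirror the ones in the patch.

import grain.python as grain  # same alias the MaxText input pipeline uses


def make_packed_dataset(dataset, data_columns, max_target_length, packing_type, batch_size, bos_id=None):
  """Wraps an IterDataset of tokenized examples in the packing transform selected by `packing_type`."""
  length_struct = {col: max_target_length for col in data_columns}
  if packing_type == "first_fit":
    # First-fit packing of whole examples into up to `batch_size` concurrent bins.
    return grain.experimental.FirstFitPackIterDataset(
        dataset, length_struct=length_struct, num_packing_bins=batch_size
    )
  if packing_type == "concat_then_split":
    # Concatenate the example stream and split it into fixed-length sequences; the BOS
    # handling stamps `bos_id` onto the first token of every packed sequence.
    if bos_id is not None:
      return grain.experimental.ConcatThenSplitIterDataset(
          dataset,
          length_struct=length_struct,
          bos_handling=grain.experimental.BOSHandling.REPLACE_FIRST_TOKEN_WITH_BOS,
          bos_token_id=bos_id,
      )
    return grain.experimental.ConcatThenSplitIterDataset(dataset, length_struct=length_struct)
  raise ValueError(f"Unknown packing type: {packing_type}")

With this change the strategy can be selected like any other config key, for example by passing `grain_packing_type=concat_then_split` as a MaxText command-line override, while `first_fit` remains the default behavior.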
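
A second toy, again not MaxText code, shows the element-level effect of the new `KeepFeatures` transform on non-ArrayRecord inputs such as Parquet: columns outside `data_columns` are dropped so that the downstream packing `length_struct`, which only covers `data_columns`, sees a consistent element structure. The sample element and its extra columns are invented; the snippet assumes the MaxText package is importable.

from typing import Any

from MaxText.input_pipeline import _input_pipeline_utils  # module patched above

data_columns = ["text"]  # the single text column the pipeline tokenizes
element: dict[str, Any] = {
    "text": "the quick brown fox",
    "url": "https://example.com/a",  # extra Parquet column, unused downstream
    "timestamp": "2019-04-22",       # extra Parquet column, unused downstream
}

# KeepFeatures drops everything except the requested feature names.
kept = _input_pipeline_utils.KeepFeatures(feature_names=data_columns).map(element)
print(kept)  # {'text': 'the quick brown fox'}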