diff --git a/src/MaxText/input_pipeline/_hf_data_processing.py b/src/MaxText/input_pipeline/_hf_data_processing.py index e056cd972..5936c88af 100644 --- a/src/MaxText/input_pipeline/_hf_data_processing.py +++ b/src/MaxText/input_pipeline/_hf_data_processing.py @@ -192,6 +192,7 @@ def preprocessing_pipeline( use_sft=None, sft_train_on_completion_only=True, grain_worker_count=1, # only support 0 or 1 + max_segments_per_seq=1, ): """pipeline for preprocessing HF dataset""" @@ -301,6 +302,7 @@ def lists2array(x): grain.experimental.PackAndBatchOperation( batch_size=global_batch_size // jax.process_count(), length_struct=length_struct, + max_sequences_per_bin=max_segments_per_seq, ) ) operations.append(_input_pipeline_utils.ReformatPacking(data_column_names)) @@ -386,6 +388,7 @@ def make_hf_train_iterator( use_sft=config.use_sft, sft_train_on_completion_only=config.sft_train_on_completion_only, chat_template_path=config.chat_template_path, + max_sequences_per_bin=config.max_segments_per_seq, ) return train_iter @@ -437,5 +440,6 @@ def make_hf_eval_iterator( use_sft=config.use_sft, sft_train_on_completion_only=config.sft_train_on_completion_only, chat_template_path=config.chat_template_path, + max_sequences_per_bin=config.max_segments_per_seq, ) return eval_iter