Commit 38b2940

feat(input_pipeline): Add support for chunking long sequences instead of truncation
Introduces a `use_truncation` config flag in the input pipeline to control how long sequences are handled.

- If `use_truncation=True` (default), sequences longer than `max_target_length` are truncated, preserving the original behavior.
- If `use_truncation=False`, long sequences are chunked into multiple training examples of up to `max_target_length` tokens.

This PR also includes:

- Refactoring the tokenizer transforms for better clarity and decoupling.
- Updating the grain dependency to `grain==0.2.13` and adopting its updated API.
- Clarifying documentation for the new `use_truncation` flag.
1 parent 83b3519 commit 38b2940

File tree

6 files changed: +209, -27 lines


dependencies/requirements/requirements_with_jax_ai_image.txt

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ datasets @ https://github.com/huggingface/datasets/archive/6790e138c00b87a1ddc72
 flax>=0.11.0
 google-api-python-client
 google-jetstream @ https://github.com/AI-Hypercomputer/JetStream/archive/29329e8e73820993f77cfc8efe34eb2a73f5de98.zip
-grain[parquet]>=0.2.12
+grain[parquet]>=0.2.13
 jaxtyping
 jsonlines
 mlperf-logging @ https://github.com/mlcommons/logging/archive/38ab22670527888c8eb7825a4ece176fcc36a95d.zip

dependencies/requirements/requirements_with_jax_stable_stack_0_6_1_pipreqs.txt

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ datasets==3.6.0
 etils==1.12.2
 evaluate==0.4.4
 flax==0.11.0
-grain==0.2.12
+grain==0.2.13
 grpcio==1.72.0rc1
 huggingface_hub==0.33.0
 jax==0.6.0

src/MaxText/configs/base.yml

Lines changed: 5 additions & 0 deletions
@@ -495,6 +495,11 @@ tokenize_train_data: True # False if the dataset is pre-tokenized
 tokenize_eval_data: True # False if the dataset is pre-tokenized
 add_bos: True
 add_eos: True
+# If False, use chunking for long sequences instead of truncation.
+# Note: use_truncation=False is only available in grain's pretrain preprocessing pipeline.
+# See the TokenizeAndTrim and TokenizeAndChunk classes in
+# `src/MaxText/input_pipeline/_grain_tokenizer.py` for implementation details.
+use_truncation: True

 # Dataset
 per_device_batch_size: 12.0
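
For intuition about what the flag changes, here is a minimal sketch in plain Python (the sequence length and document length below are hypothetical, not taken from the commit) of how many training examples one long document contributes under each setting:

import math

max_target_length = 2048  # example value for the configured sequence length
doc_tokens = 5000         # hypothetical length of one tokenized document

# use_truncation: True  -> exactly one example; tokens past the limit are dropped
examples_truncated = 1
tokens_dropped = max(0, doc_tokens - max_target_length)       # 2952 tokens lost

# use_truncation: False -> the document fans out into ceil(N / max_target_length) chunks
examples_chunked = math.ceil(doc_tokens / max_target_length)  # 3 examples, no tokens lost

print(examples_truncated, tokens_dropped, examples_chunked)   # 1 2952 3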

src/MaxText/input_pipeline/_grain_data_processing.py

Lines changed: 11 additions & 13 deletions
@@ -95,9 +95,7 @@ def pretrain_preprocessing_pipeline(dataset, config, data_columns, tokenize, gra
   dataset = dataset.map(_input_pipeline_utils.NormalizeFeatures(data_columns, tokenize))

   assert len(data_columns) == 1
-  rekey_dict = {"inputs": "text", "targets": "text"}
-  dataset = dataset.map(_input_pipeline_utils.Rekey(rekey_dict))
-  data_columns = ("inputs", "targets")
+  text_column = data_columns[0]

   tokenizer_model = tokenizer.build_tokenizer(
       config.tokenizer_path,
@@ -115,11 +113,15 @@ def pretrain_preprocessing_pipeline(dataset, config, data_columns, tokenize, gra
     pad_id = -1

   if tokenize:
-    dataset = dataset.map(
-        _grain_tokenizer.TokenizeAndTrim(
-            data_columns, config.max_target_length, config.add_bos, config.add_eos, tokenizer_model
-        )
-    )
+    if config.use_truncation:
+      dataset = dataset.map(_grain_tokenizer.TokenizeAndTrim(text_column, config.max_target_length, tokenizer_model))
+    else:
+      dataset = dataset.apply(_grain_tokenizer.TokenizeAndChunk(text_column, config.max_target_length, tokenizer_model))
+
+  data_columns = ("inputs", "targets")
+  rekey_dict = {col: text_column for col in data_columns}
+  dataset = dataset.map(_input_pipeline_utils.Rekey(rekey_dict))
+
   # Pack and Batch examples.
   batch_size = config.global_batch_size_to_load // jax.process_count()
   if config.expansion_factor_real_data > 1:
@@ -176,11 +178,7 @@ def dpo_preprocessing_pipeline(dataset, config, data_columns, tokenize, grain_wo
     pad_id = -1

   if tokenize:
-    dataset = dataset.map(
-        _grain_tokenizer.TokenizeAndTrim(
-            data_columns, config.max_target_length, config.add_bos, config.add_eos, tokenizer_model
-        )
-    )
+    dataset = dataset.map(_grain_tokenizer.TokenizeAndTrim(data_columns, config.max_target_length, tokenizer_model))

   dataset = dataset.map(_input_pipeline_utils.PadOrTrimToMaxLength(config.max_target_length, pad_id))
   batch_size = config.global_batch_size_to_load // jax.process_count()
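
A note on the wiring above: `map` with a `MapTransform` keeps the element count 1:1, while `apply` with a `FlatMapTransform` lets one long document fan out into several examples, which is why the chunking path goes through `apply`. The rekey that exposes the tokenized text under the `inputs` and `targets` columns now runs after tokenization in either branch; a tiny standalone sketch (plain Python, with "text" as an illustrative stand-in for the dataset's single data column) of what that rekey mapping looks like:

text_column = "text"  # stand-in for the dataset's single data column
data_columns = ("inputs", "targets")
rekey_dict = {col: text_column for col in data_columns}
# rekey_dict == {"inputs": "text", "targets": "text"}: both model columns point at the
# same tokenized sequence, matching what the old pre-tokenization rekey produced.
print(rekey_dict)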

src/MaxText/input_pipeline/_grain_tokenizer.py

Lines changed: 57 additions & 12 deletions
@@ -24,35 +24,34 @@


 @dataclasses.dataclass
-class TokenizeAndTrim(grain.MapTransform):
-  """Tokenize and trim features to sequence length."""
+class TokenizerTransformBase:
+  """Base class for tokenizer transforms with common functionality."""

   # pylint: disable=attribute-defined-outside-init
   feature_names: str | Sequence[str]
   sequence_length: int | Sequence[int]
-  add_bos: bool
-  add_eos: bool
   tokenizer: tokenizer.SentencePieceTokenizerGrain | tokenizer.HFTokenizer

   def __post_init__(self):
     self._processor = None
     self._initialize_processor_lock = threading.Lock()
+    # Convert single values to lists for consistent processing
     if isinstance(self.feature_names, str):
       self.feature_names = [self.feature_names]
     if isinstance(self.sequence_length, int):
       self.sequence_length = [self.sequence_length] * len(self.feature_names)

-  def map(self, element: dict[str, Any]) -> dict[str, Any]:
-    """Maps to each element."""
+  def _get_processor(self):
     if self._processor is None:
       with self._initialize_processor_lock:
-        if self._processor is None: # Ensures only one thread initializes SPP.
+        if self._processor is None: # Ensures only one thread initializes processor.
           self._processor = self.tokenizer
-    for feature_name, sequence_length in zip(self.feature_names, self.sequence_length, strict=True):
-      text = element[feature_name]
-      token_ids = self._processor.encode(text)[:sequence_length]
-      element[feature_name] = np.asarray(token_ids, dtype=np.int32)
-    return element
+    return self._processor
+
+  def _encode(self, text: str) -> list[int]:
+    """Common method to encode text using the tokenizer."""
+    processor = self._get_processor()
+    return processor.encode(text)

   def __getstate__(self):
     state = self.__dict__.copy()
@@ -64,3 +63,49 @@ def __setstate__(self, state):
     self.__dict__.update(state)
     self._processor = None
     self._initialize_processor_lock = threading.Lock()
+
+
+@dataclasses.dataclass
+class TokenizeAndTrim(TokenizerTransformBase, grain.MapTransform):
+  """Tokenize and trim features to sequence length."""
+
+  def map(self, element: dict[str, Any]) -> dict[str, Any]:
+    """Maps to each element."""
+    for feature_name, max_length in zip(self.feature_names, self.sequence_length, strict=True):
+      text = element[feature_name]
+      token_ids = self._encode(text)[:max_length]
+      element[feature_name] = np.asarray(token_ids, dtype=np.int32)
+    return element
+
+
+@dataclasses.dataclass
+class TokenizeAndChunk(TokenizerTransformBase, grain.experimental.FlatMapTransform):
+  """Tokenize and chunk features into multiple examples of sequence length."""
+
+  max_fan_out: int = 2048
+
+  def __post_init__(self):
+    super().__post_init__()
+    # TokenizeAndChunk only supports single feature for chunking
+    assert len(self.feature_names) == 1, "TokenizeAndChunk only supports single feature name"
+    assert len(self.sequence_length) == 1, "TokenizeAndChunk only supports single sequence length"
+    self.feature_name = self.feature_names[0] # For backward compatibility
+    self.sequence_length = self.sequence_length[0] # Convert back to int for chunking
+
+  def flat_map(self, element: dict[str, Any]) -> list[dict[str, Any]]:
+    """Tokenize and chunk text into multiple examples of sequence length."""
+    text = element[self.feature_name]
+    chunk_size = self.sequence_length
+
+    token_ids = self._encode(text)
+
+    if not token_ids:
+      return []
+
+    output_elements = []
+    for start_idx in range(0, len(token_ids), chunk_size):
+      chunk = np.asarray(token_ids[start_idx : start_idx + chunk_size], dtype=np.int32)
+      new_element = {self.feature_name: chunk}
+      output_elements.append(new_element)
+
+    return output_elements
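
The chunking itself is a plain stride over the encoded ids (the `max_fan_out` field appears to cap how many chunks a single element may emit). A standalone sketch with made-up token ids, independent of the classes above, shows the 1:N fan-out and the short final chunk:

import numpy as np

token_ids = list(range(1, 13))  # 12 hypothetical token ids for one document
chunk_size = 5                  # stands in for max_target_length

chunks = [
    np.asarray(token_ids[start : start + chunk_size], dtype=np.int32)
    for start in range(0, len(token_ids), chunk_size)
]
print([c.tolist() for c in chunks])  # [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10], [11, 12]]
# The short tail chunk is handled downstream like any other example (padded or packed).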

tests/tokenizer_transform_test.py

Lines changed: 134 additions & 0 deletions
@@ -0,0 +1,134 @@
+# Copyright 2023–2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+""" Tests for tokenizer
+"""
+
+import unittest
+
+import grain.python as grain
+import numpy as np
+from MaxText.input_pipeline import _grain_tokenizer
+from MaxText.input_pipeline import _input_pipeline_utils
+from numpy.testing import assert_array_equal
+
+
+class MockTokenizer:
+  """
+  Mocks a tokenizer by splitting on space and mapping letters to simple ints.
+  e.g., "a b c" -> [1, 2, 3]
+  """
+
+  def encode(self, text: str) -> list[int]:
+    if not text:
+      return []
+    # Simple 'a'=1, 'b'=2, ... mapping
+    return [ord(c) - ord("a") + 1 for c in text.split(" ")]
+
+
+class TokenizerTransformTest(unittest.TestCase):
+  """Tests for chunking, trimming, and padding transformations."""
+
+  def setUp(self):
+    self.max_len = 5
+    self.pad_length = 7
+    self.pad_id = 0
+    self.feature_names = "text"
+    self.mock_tokenizer = MockTokenizer()
+    self.source_data = [{"text": "a b c"}, {"text": "d e f g h i j"}, {"text": ""}, {"text": "k l m n o p q r s t"}]
+    self.base_ds = grain.MapDataset.source(self.source_data).to_iter_dataset()
+
+  def test_tokenize_and_trim(self):
+    """Tests the 1:1 MapTransform (truncation) logic."""
+    trim_op = _grain_tokenizer.TokenizeAndTrim(
+        feature_names=self.feature_names, sequence_length=self.max_len, tokenizer=self.mock_tokenizer
+    )
+    trim_ds = self.base_ds.map(trim_op)
+    results = list(trim_ds)
+    self.assertEqual(len(results), len(self.source_data))
+    expected_inputs = [
+        np.array([1, 2, 3], dtype=np.int32),
+        np.array([4, 5, 6, 7, 8], dtype=np.int32),
+        np.array([], dtype=np.int32),
+        np.array([11, 12, 13, 14, 15], dtype=np.int32),
+    ]
+    result_inputs = [r["text"] for r in results]
+    self.assertEqual(len(result_inputs), len(expected_inputs))
+    for res, exp in zip(result_inputs, expected_inputs):
+      assert_array_equal(res, exp)
+
+  def test_tokenize_and_chunk(self):
+    """Tests the 1:N FlatMapTransform (chunking) logic."""
+    chunk_op = _grain_tokenizer.TokenizeAndChunk(
+        feature_names=self.feature_names, sequence_length=self.max_len, tokenizer=self.mock_tokenizer
+    )
+    chunk_ds = self.base_ds.apply(chunk_op)
+    results = list(chunk_ds)
+    self.assertEqual(len(results), 5)
+    expected_inputs = [
+        np.array([1, 2, 3], dtype=np.int32),
+        np.array([4, 5, 6, 7, 8], dtype=np.int32),
+        np.array([9, 10], dtype=np.int32),
+        np.array([11, 12, 13, 14, 15], dtype=np.int32),
+        np.array([16, 17, 18, 19, 20], dtype=np.int32),
+    ]
+    result_inputs = [r["text"] for r in results]
+    self.assertEqual(len(result_inputs), len(expected_inputs))
+    for res, exp in zip(result_inputs, expected_inputs):
+      assert_array_equal(res, exp)
+
+  def test_trim_and_pad_chaining(self):
+    """Tests chaining TokenizeAndTrim.map() -> PadOrTrimToMaxLength.map()"""
+    trim_op = _grain_tokenizer.TokenizeAndTrim(
+        feature_names=self.feature_names, sequence_length=self.max_len, tokenizer=self.mock_tokenizer
+    )
+    pad_op = _input_pipeline_utils.PadOrTrimToMaxLength(max_length=self.pad_length, pad_id=self.pad_id)
+    chained_ds = self.base_ds.map(trim_op).map(pad_op)
+    results = list(chained_ds)
+    self.assertEqual(len(results), len(self.source_data))
+    expected_inputs = [
+        np.array([1, 2, 3, 0, 0, 0, 0], dtype=np.int32),
+        np.array([4, 5, 6, 7, 8, 0, 0], dtype=np.int32),
+        np.array([0, 0, 0, 0, 0, 0, 0], dtype=np.int32),
+        np.array([11, 12, 13, 14, 15, 0, 0], dtype=np.int32),
+    ]
+    result_inputs = [r["text"] for r in results]
+    self.assertEqual(len(result_inputs), len(expected_inputs))
+    for res, exp in zip(result_inputs, expected_inputs):
+      assert_array_equal(res, exp)
+
+  def test_chunk_and_pad_chaining(self):
+    """Tests chaining TokenizeAndChunk.apply() -> PadOrTrimToMaxLength.map()"""
+    chunk_op = _grain_tokenizer.TokenizeAndChunk(
+        feature_names=self.feature_names, sequence_length=self.max_len, tokenizer=self.mock_tokenizer
+    )
+    pad_op = _input_pipeline_utils.PadOrTrimToMaxLength(max_length=self.pad_length, pad_id=self.pad_id)
+    chained_ds = self.base_ds.apply(chunk_op).map(pad_op)
+    results = list(chained_ds)
+    self.assertEqual(len(results), 5)
+    expected_inputs = [
+        np.array([1, 2, 3, 0, 0, 0, 0], dtype=np.int32),
+        np.array([4, 5, 6, 7, 8, 0, 0], dtype=np.int32),
+        np.array([9, 10, 0, 0, 0, 0, 0], dtype=np.int32),
+        np.array([11, 12, 13, 14, 15, 0, 0], dtype=np.int32),
+        np.array([16, 17, 18, 19, 20, 0, 0], dtype=np.int32),
+    ]
+    result_inputs = [r["text"] for r in results]
+    self.assertEqual(len(result_inputs), len(expected_inputs))
+    for res, exp in zip(result_inputs, expected_inputs):
+      assert_array_equal(res, exp)
+
+
+if __name__ == "__main__":
+  unittest.main()
