
Commit 030ce4f

asolergi-nv, sarahyurick, and ayushdg authored
Megatron tokenization pipeline (#1259)
* Add token_size to TokenizerStage metadata Signed-off-by: asolergi-nv <[email protected]>
* First draft Signed-off-by: asolergi-nv <[email protected]>
* bugs Signed-off-by: asolergi-nv <[email protected]>
* First working prototype Signed-off-by: asolergi-nv <[email protected]>
* Guard against missing vocab_size and eos_token_id Signed-off-by: asolergi-nv <[email protected]>
* Before fixing OOM Signed-off-by: asolergi-nv <[email protected]>
* No OOM! Signed-off-by: asolergi-nv <[email protected]>
* Undo tokenizer changes Signed-off-by: asolergi-nv <[email protected]>
* Match! Signed-off-by: asolergi-nv <[email protected]>
* A bit of cleaning Signed-off-by: asolergi-nv <[email protected]>
* move batched to utils Signed-off-by: asolergi-nv <[email protected]>
* v4: Remove document indices list Signed-off-by: asolergi-nv <[email protected]>
* v5: Larger writes Signed-off-by: asolergi-nv <[email protected]>
* Ready Signed-off-by: asolergi-nv <[email protected]>
* Add scripts checks Signed-off-by: asolergi-nv <[email protected]>
* Remove comments Signed-off-by: asolergi-nv <[email protected]>
* nits Signed-off-by: asolergi-nv <[email protected]>
* More nits Signed-off-by: asolergi-nv <[email protected]>
* disable fields arg, move load tokenizer to setup, create sequence_lenghts in process not self, add tokenizer characteristics to _metadata Signed-off-by: asolergi-nv <[email protected]>
* Add tests Signed-off-by: asolergi-nv <[email protected]>
* rename batch_size to tokenization_batch_size Signed-off-by: asolergi-nv <[email protected]>
* Create tutorial Signed-off-by: asolergi-nv <[email protected]>
* Guard bin file writes Signed-off-by: asolergi-nv <[email protected]>
* Add tutorial Signed-off-by: asolergi-nv <[email protected]>
* Removed tokenizer-test folder Signed-off-by: asolergi-nv <[email protected]>
* Update tutorials/README.md Co-authored-by: Sarah Yurick <[email protected]> Signed-off-by: Antoni-Joan Solergibert <[email protected]>
* Fix tutorials/README.md Signed-off-by: asolergi-nv <[email protected]>
* address comments Signed-off-by: asolergi-nv <[email protected]>
* Guard idx write Signed-off-by: asolergi-nv <[email protected]>
* Add back local_files_only=True Signed-off-by: asolergi-nv <[email protected]>
* Add documentation Signed-off-by: asolergi-nv <[email protected]>
* Test typo Signed-off-by: asolergi-nv <[email protected]>
* Apply suggestions from code review Signed-off-by: Sarah Yurick <[email protected]>
* Update nemo_curator/stages/text/io/writer/megatron_tokenizer.py Signed-off-by: Sarah Yurick <[email protected]>

---------

Signed-off-by: asolergi-nv <[email protected]>
Signed-off-by: Antoni-Joan Solergibert <[email protected]>
Signed-off-by: Sarah Yurick <[email protected]>
Co-authored-by: Sarah Yurick <[email protected]>
Co-authored-by: Ayush Dattagupta <[email protected]>
1 parent c94acab commit 030ce4f

File tree

9 files changed: +865, -3 lines changed

nemo_curator/stages/text/io/writer/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -13,9 +13,11 @@
 # limitations under the License.
 
 from nemo_curator.stages.text.io.writer.jsonl import JsonlWriter
+from nemo_curator.stages.text.io.writer.megatron_tokenizer import MegatronTokenizerWriter
 from nemo_curator.stages.text.io.writer.parquet import ParquetWriter
 
 __all__ = [
     "JsonlWriter",
+    "MegatronTokenizerWriter",
     "ParquetWriter",
 ]
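
For context, a hypothetical usage sketch of the newly exported writer (not part of this commit). The keyword arguments other than `path` are the dataclass fields defined in megatron_tokenizer.py below; `path` is assumed to be the output-location field inherited from BaseWriter, so its exact name here is an assumption.

    # Hypothetical usage sketch, not part of the diff.
    from nemo_curator.stages.text.io.writer import MegatronTokenizerWriter

    writer = MegatronTokenizerWriter(
        path="/data/tokenized",        # assumption: inherited BaseWriter output-path field
        model_identifier="gpt2",       # any Hugging Face tokenizer repo id
        tokenization_batch_size=1000,  # texts passed per batch_encode_plus call
        append_eod=True,               # append eos_token_id after each document
    )
    # In a pipeline, the framework calls setup_on_node()/setup() once per node/worker,
    # then process(DocumentBatch) per batch, writing a <hash>.bin/<hash>.idx pair.
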
nemo_curator/stages/text/io/writer/megatron_tokenizer.py

Lines changed: 238 additions & 0 deletions
@@ -0,0 +1,238 @@
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import struct
+import uuid
+from dataclasses import dataclass, field
+from typing import BinaryIO
+
+import numpy as np
+from huggingface_hub import snapshot_download
+from loguru import logger
+from transformers import AutoTokenizer
+
+import nemo_curator.stages.text.io.writer.utils as writer_utils
+from nemo_curator.backends.base import NodeInfo, WorkerMetadata
+from nemo_curator.tasks import DocumentBatch, FileGroupTask
+from nemo_curator.utils.file_utils import FILETYPE_TO_DEFAULT_EXTENSIONS
+
+from .base import BaseWriter
+from .utils import batched
+
+_INDEX_HEADER = b"MMIDIDX\x00\x00"
+
+
+@dataclass
+class MegatronTokenizerWriter(BaseWriter):
+    """Writer that writes a DocumentBatch to Megatron ready tokenized files."""
+
+    model_identifier: str | None = None
+    cache_dir: str | None = None
+    hf_token: str | None = None
+    text_field: str = "text"
+    tokenization_batch_size: int = 1000  # Renamed from batch_size to avoid shadowing ProcessingStage.batch_size
+    append_eod: bool = False
+
+    # Disable the inherited fields attribute
+    fields: list[str] | None = field(default=None, init=False, repr=False)
+
+    name: str = "megatron_tokenizer_writer"
+    file_extension: list[str] = field(default_factory=lambda: FILETYPE_TO_DEFAULT_EXTENSIONS["megatron"])
+
+    def __post_init__(self):
+        if self.model_identifier is None:
+            msg = "model_identifier is required and must be provided"
+            raise ValueError(msg)
+        super().__post_init__()
+
+    def setup_on_node(self, _node_info: NodeInfo | None = None, _worker_metadata: WorkerMetadata = None) -> None:
+        try:
+            snapshot_download(
+                repo_id=self.model_identifier,
+                cache_dir=self.cache_dir,
+                token=self.hf_token,
+            )
+        except Exception as e:
+            msg = f"Failed to download {self.model_identifier}"
+            raise RuntimeError(msg) from e
+
+    def setup(self, _worker_metadata: WorkerMetadata | None = None) -> None:
+        # Load the tokenizer
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            self.model_identifier,
+            cache_dir=self.cache_dir,
+            local_files_only=True,
+        )
+
+    def process(self, task: DocumentBatch) -> FileGroupTask:
+        sequence_lengths: list[int] = []
+        # Get source files from metadata for deterministic naming
+        if source_files := task._metadata.get("source_files"):
+            filename = writer_utils.get_deterministic_hash(source_files, task.task_id)
+        else:
+            logger.warning("The task does not have source_files in metadata, using UUID for base filename")
+            filename = uuid.uuid4().hex
+
+        file_prefix = self.fs.sep.join([self._fs_path, filename])
+        for file_extension in self.file_extension:
+            file_path = file_prefix + file_extension
+            if self.fs.exists(file_path):
+                logger.debug(f"File {file_path} already exists, overwriting it")
+
+        token_size = (
+            -1
+            if self.tokenizer.vocab_size is None
+            else (4 if self.tokenizer.vocab_size > np.iinfo(np.uint16).max + 1 else 2)
+        )
+        if token_size == -1:
+            logger.warning("tokenizer.vocab_size is not set, assuming 4 bytes per token (vocab_size > 65536)")
+            token_size = 4
+        token_dtype = np.int32 if token_size == 4 else np.uint16  # noqa: PLR2004
+        token_dtype_code = (
+            4 if token_size == 4 else 8  # noqa: PLR2004
+        )  # NOTE(asolergi-nv): Megatron needs this dtype code in the .idx file | https://github.com/NVIDIA/Megatron-LM/blob/64cbae55ac85cd73fbadbc3c0d715c8123c5e13b/megatron/core/datasets/indexed_dataset.py#L41
+
+        eod_token_id = self.tokenizer.eos_token_id if self.tokenizer.eos_token_id is not None else -1
+        if eod_token_id == -1:
+            logger.warning("tokenizer.eos_token_id is not set, disabling append_eod")
+            self.append_eod = False
+
+        num_docs = task.num_items
+
+        df = task.to_pandas()
+
+        try:
+            with self.fs.open(file_prefix + ".bin", "wb") as bin_file:
+                for batch in batched(df[self.text_field], self.tokenization_batch_size):
+                    tokens_batch = self.tokenizer.batch_encode_plus(
+                        batch,
+                        padding=False,
+                        truncation=False,
+                        add_special_tokens=False,
+                        return_token_type_ids=False,
+                        return_attention_mask=False,
+                    ).input_ids
+                    self.write_data(bin_file, token_dtype, eod_token_id, tokens_batch, sequence_lengths)
+        except Exception as e:
+            logger.error(f"Error while writing tokens to {file_prefix}: {e}")
+            if self.fs.exists(file_prefix + ".bin"):
+                self.fs.remove(file_prefix + ".bin")
+            raise
+
+        self.write_idx_data(file_prefix, token_size, token_dtype_code, sequence_lengths)
+
+        logger.debug(f"Written batch to {file_prefix} with {num_docs} documents ({sum(sequence_lengths)} tokens)")
+
+        return FileGroupTask(
+            task_id=task.task_id,
+            dataset_name=task.dataset_name,
+            data=[file_prefix + file_extension for file_extension in self.file_extension],
+            _metadata={
+                **task._metadata,
+                "format": "megatron",
+                "file_prefix": file_prefix,
+                "num_tokens": sum(sequence_lengths),
+                "token_size": token_size,
+                "eod_token_id": eod_token_id,
+            },
+            _stage_perf=task._stage_perf,
+        )
+
+    def write_data(
+        self,
+        bin_file: BinaryIO,
+        token_dtype: np.dtype,
+        eod_token_id: int,
+        tokens_batch: list[list[int]],
+        sequence_lengths: list[int],
+    ) -> None:
+        """Write tokens to the .bin file
+        Args:
+            tokens_batch (list[list[int]]): The batch of tokens to write
+        """
+        if self.append_eod:
+            tokens_batch = [[*tokens, eod_token_id] for tokens in tokens_batch]
+        sequence_lengths.extend([len(tokens) for tokens in tokens_batch])
+        tokens_batch = np.concatenate([np.array(tokens, dtype=token_dtype) for tokens in tokens_batch])
+        bin_file.write(tokens_batch.tobytes(order="C"))
+
+    def write_idx_data(
+        self, file_prefix: str, token_size: int, token_dtype_code: int, sequence_lengths: list[int]
+    ) -> None:
+        """Write the .idx file data"""
+
+        # Save .idx file
+        # This file has:
+        ## 9 Bytes from the _INDEX_HEADER
+        ## 8 Bytes from the version (Just a "1")
+        ## 1 Byte from the token_dtype_code
+        ## 8 Bytes from the number of sequences
+        ## 8 Bytes from the number of documents
+        ## 8 Bytes from the initial document index
+        ## 20 Bytes for every sequence/document:
+        ### - 4 Bytes from the sequence length
+        ### - 8 bytes from the sequence offset
+        ### - 8 Bytes from the document index
+        # So, if the .bin contains tokens from 35000 text sequences/documents, the .idx will have
+        # 9+8+1+8+8+8+20*35000 = 700042 Bytes
+        try:
+            with self.fs.open(file_prefix + ".idx", "wb") as idx_file:
+                # Index Header
+                idx_file.write(_INDEX_HEADER)
+                # Version
+                idx_file.write(struct.pack("<Q", 1))
+                # Numeric code for the DType
+                idx_file.write(struct.pack("<B", token_dtype_code))
+
+                # Number of sequences in the dataset
+                sequence_count = len(sequence_lengths)
+                idx_file.write(struct.pack("<Q", sequence_count))
+
+                document_indices = np.arange(len(sequence_lengths) + 1, dtype=np.int64)
+                # Number of documents in the dataset
+                document_count = len(document_indices)
+                idx_file.write(struct.pack("<Q", document_count))
+
+                # Number of tokens per sequence
+                sequence_lengths = np.array(sequence_lengths, dtype=np.int32)
+                idx_file.write(sequence_lengths.tobytes(order="C"))
+
+                # Byte offsets for all sequences
+                sequence_pointers = np.array(self._sequence_pointers(sequence_lengths, token_size), dtype=np.int64)
+                idx_file.write(sequence_pointers.tobytes(order="C"))
+
+                # Sequence indices marking the end of each document
+                idx_file.write(document_indices.tobytes(order="C"))
+        except Exception as e:
+            logger.error(f"Error while writing idx data to {file_prefix}: {e}")
+            if self.fs.exists(file_prefix + ".idx"):
+                self.fs.remove(file_prefix + ".idx")
+            raise
+
+    @staticmethod
+    def _sequence_pointers(sequence_lengths: list[int], token_size: int) -> list[int]:
+        """Build the sequence pointers per the sequence lengths and dtype size
+
+        Args:
+            sequence_lengths (list[int]): The length of each sequence
+            token_size (int): The size of each token in bytes
+        Returns:
+            list[int]: The pointer to the beginning of each sequence
+        """
+        curr_ptr = 0
+        list_ptr = []
+        for length in sequence_lengths:
+            list_ptr.append(curr_ptr)
+            curr_ptr += length * token_size
+        return list_ptr
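
The byte layout documented in write_idx_data can be sanity-checked with a small standalone reader. This is an illustrative sketch, not part of the commit (the Megatron-LM indexed_dataset.py module linked in the NOTE above is the canonical consumer); it assumes a local <prefix>.bin/<prefix>.idx pair written by MegatronTokenizerWriter and that the file names are whatever the writer produced.

    # Illustrative read-back sketch, not part of the diff.
    import struct
    import numpy as np

    _INDEX_HEADER = b"MMIDIDX\x00\x00"
    _DTYPE_FROM_CODE = {4: np.int32, 8: np.uint16}  # mirrors token_dtype_code above

    def read_megatron_pair(prefix: str) -> list[np.ndarray]:
        """Return one token array per document from <prefix>.idx / <prefix>.bin."""
        with open(prefix + ".idx", "rb") as idx_file:
            assert idx_file.read(9) == _INDEX_HEADER                   # 9-byte magic
            (version,) = struct.unpack("<Q", idx_file.read(8))         # always 1
            (dtype_code,) = struct.unpack("<B", idx_file.read(1))      # 4 or 8
            (sequence_count,) = struct.unpack("<Q", idx_file.read(8))
            (document_count,) = struct.unpack("<Q", idx_file.read(8))  # sequence_count + 1
            sequence_lengths = np.frombuffer(idx_file.read(4 * sequence_count), dtype=np.int32)
            sequence_pointers = np.frombuffer(idx_file.read(8 * sequence_count), dtype=np.int64)
            document_indices = np.frombuffer(idx_file.read(8 * document_count), dtype=np.int64)
        token_dtype = _DTYPE_FROM_CODE[dtype_code]
        tokens = np.fromfile(prefix + ".bin", dtype=token_dtype)
        # sequence_pointers are byte offsets; convert to element offsets before slicing.
        starts = sequence_pointers // np.dtype(token_dtype).itemsize
        return [tokens[s : s + n] for s, n in zip(starts, sequence_lengths)]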

nemo_curator/stages/text/io/writer/utils.py

Lines changed: 22 additions & 0 deletions
@@ -13,6 +13,28 @@
 # limitations under the License.
 
 import hashlib
+from collections.abc import Iterable, Iterator
+from itertools import islice
+from typing import Any
+
+
+def batched(iterable: Iterable[Any], n: int) -> Iterator[tuple[Any, ...]]:
+    """
+    Batch an iterable into tuples of size n.
+
+    Args:
+        iterable (Iterable[Any]): The iterable to batch
+        n (int): The size of the batch
+
+    Returns:
+        Iterator[tuple[Any, ...]]: An iterator of tuples, each containing up to n elements from the iterable
+    """
+    if n < 1:
+        msg = "n must be at least one"
+        raise ValueError(msg)
+    it = iter(iterable)
+    while batch := tuple(islice(it, n)):
+        yield batch
 
 
 def get_deterministic_hash(inputs: list[str], seed: str = "") -> str:
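
For illustration only (not part of the diff): the final tuple yielded by batched may be shorter than n, which is why the writer above tracks per-sequence lengths rather than assuming a fixed batch size.

    # Usage sketch, not part of the diff.
    from nemo_curator.stages.text.io.writer.utils import batched

    print(list(batched(range(7), 3)))  # [(0, 1, 2), (3, 4, 5), (6,)]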

nemo_curator/utils/file_utils.py

Lines changed: 1 addition & 0 deletions
@@ -35,6 +35,7 @@
 FILETYPE_TO_DEFAULT_EXTENSIONS = {
     "parquet": [".parquet"],
     "jsonl": [".jsonl", ".json"],
+    "megatron": [".bin", ".idx"],
 }