train_tokenizer.py
"""Train a byte-level BPE tokenizer (32K vocab) on SlimPajama data."""
import argparse
import os
from pathlib import Path
import pyarrow.parquet as pq
from tokenizers import Tokenizer, models, pre_tokenizers, trainers, decoders
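
# Both third-party dependencies are on PyPI: pip install tokenizers pyarrow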


def text_iterator(data_dir: str, max_files: int = 10):
    """Yield text from SlimPajama parquet files."""
    data_path = Path(data_dir)
    parquet_files = sorted(data_path.glob("train-*.parquet"))[:max_files]
    for pf in parquet_files:
        # Load only the "text" column; one parquet file is held in memory at a time.
        table = pq.read_table(pf, columns=["text"])
        for text in table["text"].to_pylist():
            if text:  # skip empty/null rows
                yield text
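

# Optional batched variant (a sketch): train_from_iterator also accepts an
# iterator over lists of strings, and feeding batches cuts per-call overhead.
# This helper is illustrative and is not wired into main() below.
def batched_text_iterator(data_dir: str, max_files: int = 10, batch_size: int = 1000):
    """Yield lists of up to batch_size texts from SlimPajama parquet files."""
    batch = []
    for text in text_iterator(data_dir, max_files=max_files):
        batch.append(text)
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:  # flush the final partial batch
        yield batch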


def main():
    parser = argparse.ArgumentParser(description="Train byte-level BPE tokenizer")
    parser.add_argument(
        "--data-dir",
        type=str,
        default="/data/share/hw3-data/SilmPajama/data",
        help="Path to SlimPajama parquet data directory",
    )
    parser.add_argument(
        "--vocab-size",
        type=int,
        default=32000,
        help="Vocabulary size (default: 32000)",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="tokenizer.json",
        help="Output tokenizer file path",
    )
    parser.add_argument(
        "--max-files",
        type=int,
        default=5,
        help="Max number of training parquet files to use",
    )
    args = parser.parse_args()
print(f"Training byte-level BPE tokenizer with vocab_size={args.vocab_size}")
print(f"Data directory: {args.data_dir}")
print(f"Using up to {args.max_files} training files")

    # Byte-level BPE: the pre-tokenizer maps every input byte to a printable
    # character, so no <unk> token is ever needed; the decoder inverts the map.
    tokenizer = Tokenizer(models.BPE())
    tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
    tokenizer.decoder = decoders.ByteLevel()
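    # For intuition (illustrative REPL output, not executed here):
    #   tokenizer.pre_tokenizer.pre_tokenize_str("hi there")
    #   -> [('hi', (0, 2)), ('Ġthere', (2, 8))]   # 'Ġ' marks a leading space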

    trainer = trainers.BpeTrainer(
        vocab_size=args.vocab_size,
        special_tokens=["<|endoftext|>", "<|padding|>"],
        # Seed the vocab with all 256 byte symbols so any input stays encodable,
        # even bytes that never appear in the training sample.
        initial_alphabet=pre_tokenizers.ByteLevel.alphabet(),
        show_progress=True,
    )
print("Starting tokenizer training...")
tokenizer.train_from_iterator(
text_iterator(args.data_dir, max_files=args.max_files),
trainer=trainer,
)

    output_path = os.path.abspath(args.output)
    tokenizer.save(output_path)
    print(f"Tokenizer saved to {output_path}")
    print(f"Vocab size: {tokenizer.get_vocab_size()}")


if __name__ == "__main__":
    main()
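
# Example invocation (paths illustrative; adjust to your environment):
#   python train_tokenizer.py --data-dir /path/to/SlimPajama/data \
#       --vocab-size 32000 --output tokenizer.json --max-files 5
#
# Quick sanity check on the saved file (a minimal sketch, assuming
# tokenizer.json was produced by this script):
#   from tokenizers import Tokenizer
#   tok = Tokenizer.from_file("tokenizer.json")
#   enc = tok.encode("Hello, world!")
#   print(enc.ids)                            # token ids
#   print(tok.decode(enc.ids))                # round-trips the input text
#   print(tok.token_to_id("<|endoftext|>"))   # id of the EOS special token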