16 changes: 9 additions & 7 deletions slurm/train_tokenizer.slurm
@@ -2,26 +2,27 @@
#SBATCH --job-name=train_tokenizer
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1 # crucial - only 1 task per dist per node!
#SBATCH --cpus-per-task=40 # number of cores per tasks
#SBATCH --cpus-per-task=40 # number of cores per tasks
#SBATCH --hint=nomultithread # we get physical cores not logical
#SBATCH --partition=cpu_p1
#SBATCH --time 20:00:00 # maximum execution time (HH:MM:SS)
#SBATCH --time 12:00:00 # maximum execution time (HH:MM:SS)
#SBATCH --output=logs/train_tokenizer/%x-%j.out # output file name
# #SBATCH --qos=qos_cpu-t4
#SBATCH --account=six@cpu

set -x -e

source $six_ALL_CCFRWORK/start-prod
conda activate thomas_data_tooling # Debug deepspeed temporarily

TOKENIZATION_REPO=$WORK/code/big_science/tokenization
TOKENIZATION_REPO=$WORK/tokenization

pushd $TOKENIZATION_REPO

echo "Sharding and compressing seed id ${SEED_ID}"

DATASET_PATH=$six_ALL_CCFRSCRATCH/tokenizer/dataset/tokenization_dataset # TODO: define where is concatenated dataset
SAVE_TOKENIZER_PATH=$six_ALL_CCFRSCRATCH/tokenizer/tokenizer
SAVE_TOKENIZER_PATH=$six_ALL_CCFRSCRATCH/tokenizer/tokenizer_equal_nfkc_24M_sentences

mkdir -p $SAVE_TOKENIZER_PATH

@@ -39,7 +40,8 @@ python train_convert_tokenizer_simple.py \
--data_name ${DATASET_PATH} \
--output_folder ${SAVE_TOKENIZER_PATH} \
--load_batch_size 1000 \
--input_sentence_size 12000000 \
--max_sequence_length 65536 \
--num_threads 80
--max_sequence_length 4096 \
--num_threads 1 \
--input_sentence_size 24_000_000 \
--normalizer nfkc

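For context, the flags passed to train_convert_tokenizer_simple.py above map onto the script's argparse interface. A minimal sketch of that interface, inferred from the job script and the argument hunks below; only --normalizer and --remove-extra-whitespaces are introduced by this diff, and the types/requiredness of the arguments outside the visible hunks (--data_name, --output_folder, --vocab_size, --num_threads) are assumptions:

    import argparse
    from pathlib import Path

    def get_args():
        parser = argparse.ArgumentParser()
        parser.add_argument("--data_name", type=str, required=True)            # dataset directory or hub name (assumed)
        parser.add_argument("--output_folder", type=Path, required=True)       # used as a Path by the script (assumed)
        parser.add_argument("--vocab_size", type=int, required=True)           # assumed
        parser.add_argument("--load_batch_size", type=int, default=1)
        parser.add_argument("--max_sequence_length", type=int, required=True)  # byte budget per yielded sentence
        parser.add_argument("--input_sentence_size", type=int, required=True)  # 24_000_000 in the job script
        parser.add_argument("--num_threads", type=int, default=1)              # assumed; only visible in the slurm file
        parser.add_argument("--normalizer", type=str, default="nmt_nfkc")      # new in this diff
        parser.add_argument("--remove-extra-whitespaces", action="store_true") # new in this diff
        return parser.parse_args()
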
41 changes: 28 additions & 13 deletions train_convert_tokenizer_simple.py
@@ -1,6 +1,7 @@
import logging
from pathlib import Path
from typing import List
import math

import sentencepiece as spm
from datasets import load_dataset, utils
@@ -21,12 +22,12 @@ def get_args():
parser.add_argument("--load_batch_size", type=int, default=1)
parser.add_argument("--max_sequence_length", type=int, required=True)
parser.add_argument("--input_sentence_size", type=int, required=True)
parser.add_argument("--normalizer", type=str, default="nmt_nfkc")
parser.add_argument("--remove-extra-whitespaces", action="store_true")

return parser.parse_args()

def dataset_iterator(dataset, batch_size: int, sequence_length_in_byte: int):
# FIXME: we use an approximation of byte length vs byte sequence
sequence_length = sequence_length_in_byte // 2

slices = [(start, min(len(dataset), start + batch_size)) for start in range(0, len(dataset), batch_size)]
for start, end in utils.tqdm(
@@ -38,12 +39,12 @@ def dataset_iterator(dataset, batch_size: int, sequence_length_in_byte: int):
):
# Load things by batch.
batch = dataset[start: end]
batch_results = preprocess_text(batch, sequence_length)
batch_results = preprocess_text(batch, sequence_length_in_byte)
for row_results in batch_results:
for text in row_results:
yield text

def preprocess_text(batch, sequence_length: int) -> List[List[str]]:
def preprocess_text(batch, sequence_length_in_byte: int) -> List[List[str]]:
batch_results = []
for text in batch["text"]:
row_results = []
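dataset_iterator relies on the fact that slicing a datasets.Dataset returns a dict of columns rather than a list of rows, which is why preprocess_text iterates over batch["text"]. A tiny illustration with an in-memory dataset (stand-in strings, not the jsonl corpus the script actually loads):

    from datasets import Dataset

    ds = Dataset.from_dict({"text": ["première ligne", "second line", "tercera línea"]})

    batch = ds[0:2]            # -> {"text": ["première ligne", "second line"]}
    for text in batch["text"]:
        print(text)
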
@@ -54,12 +55,21 @@ def preprocess_text(batch, sequence_length: int) -> List[List[str]]:

text = text.strip()

if len(text) == 0:
continue

# Compute an average of the number of bytes needed to encode a character for that sequence
# Needed since it will vary a lot depending on the language.
avg_bytes_per_character = math.ceil(len(text.encode('utf8')) / len(text))

sequence_length = sequence_length_in_byte // avg_bytes_per_character

# shard text to be into substrings of size < sequence length
start = 0
end = sequence_length
while end - start != 0:
if end - start <= sequence_length:
# Sort sequence: we fit everything in size one line
if end - start < sequence_length or len(text) < sequence_length:
# Short sequence: we fit everything in size one line
row_results.append(text[start: end])
start = end
else:
@@ -71,8 +81,8 @@ def preprocess_text(batch, sequence_length: int) -> List[List[str]]:
else:
substring = matches[0]

start = len(substring)
end = start + min(sequence_length, len(text))
start += len(substring)
end = min(start + sequence_length, len(text))
row_results.append(substring)

batch_results.append(row_results)
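The core change in preprocess_text is that the fixed bytes-per-character guess (// 2) is replaced by a per-text average, so the byte budget from --max_sequence_length becomes a character budget that adapts to each document's script. A simplified, self-contained sketch of that idea; the real function additionally cuts at a regex-found boundary (the matches logic sits in lines the diff view collapses), and chunk_text below is a hypothetical helper, not the PR's code:

    import math
    from typing import List

    def chunk_text(text: str, max_bytes: int) -> List[str]:
        """Split text into chunks whose estimated UTF-8 size stays under max_bytes."""
        text = text.strip()
        if not text:
            return []
        # Average bytes per character for this text; varies a lot across languages/scripts.
        avg_bytes_per_char = math.ceil(len(text.encode("utf8")) / len(text))
        # Convert the byte budget into a character budget.
        budget = max(1, max_bytes // avg_bytes_per_char)
        chunks = []
        start = 0
        while start < len(text):
            end = min(start + budget, len(text))
            # Prefer to cut at the last whitespace inside the window, if any.
            cut = text.rfind(" ", start, end)
            if cut <= start or end == len(text):
                cut = end
            chunks.append(text[start:cut])
            start = cut
        return chunks

    # ASCII is 1 byte/char, so a 4-byte budget gives 4-character chunks here.
    print(chunk_text("abcdabcdabcd", max_bytes=4))
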
@@ -99,7 +109,7 @@ def main():
)
tokenizer_path = args.output_folder / "tokenizer"

dataset = load_dataset(args.data_name, data_files="**.jsonl.gz", split="train")
dataset = load_dataset(args.data_name, data_files="**.jsonl", split="train")

logger.info(f"Dataset length: {len(dataset)}")
# max_length = 0
@@ -129,7 +139,7 @@ def main():
sequence_length_in_byte=args.max_sequence_length
),
input_sentence_size=args.input_sentence_size,
shuffle_input_sentence=True,
shuffle_input_sentence=args.input_sentence_size > 0,
model_prefix=str(tokenizer_path.absolute()),
vocab_size=args.vocab_size,
model_type="bpe",
@@ -140,14 +150,17 @@
eos_id=2,
pad_id=3,
byte_fallback=True,
train_extremely_large_corpus=True
train_extremely_large_corpus=True,
normalization_rule_name=args.normalizer,
remove_extra_whitespaces=args.remove_extra_whitespaces
)

spm_model_path = tokenizer_path / f"tokenizer.model"
logger.info("Done training the tokenizer. Starting tokenizer conversion")
spm_model_path = tokenizer_path.with_suffix(".model")
original_tokenizer = SPMTokenizer(str(spm_model_path.absolute()))
converter = SpmConverter(original_tokenizer)
hf_tokenizer = converter.converted()
tokenizer_json = tokenizer_path / f"tokenizer.json"
tokenizer_json = tokenizer_path.with_suffix(".json")
hf_tokenizer.save(str(tokenizer_json.absolute()))

# WIP:
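Stripped of the project-specific paths, the training call above amounts to handing SentencePiece a Python iterator instead of input files, plus the two new normalization options. A minimal, runnable sketch with toy values (not the PR's configuration; the real call additionally pins the special-token ids and sets train_extremely_large_corpus):

    import sentencepiece as spm

    def sentences():
        # Stand-in for dataset_iterator(...): any iterator of strings works.
        corpus = [
            "Le chat dort sur le canapé.",
            "The quick brown fox jumps over the lazy dog.",
            "El gato duerme en el sofá.",
        ]
        for _ in range(200):
            yield from corpus

    spm.SentencePieceTrainer.train(
        sentence_iterator=sentences(),
        model_prefix="toy_tokenizer",        # writes toy_tokenizer.model / toy_tokenizer.vocab
        vocab_size=300,
        model_type="bpe",
        byte_fallback=True,                  # unseen characters decompose into byte pieces
        normalization_rule_name="nfkc",      # where --normalizer ends up
        remove_extra_whitespaces=False,      # where --remove-extra-whitespaces ends up
        input_sentence_size=0,               # 0 keeps everything; >0 samples that many sentences
    )
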
@@ -162,5 +175,7 @@ def main():
tokenizer_path / f"tokenizer_hf"
)

logger.info("Done converting and saving the tokenizer.")

if __name__ == "__main__":
main()
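
Once the job has produced tokenizer.model and tokenizer.json, a quick sanity check of the converted tokenizer can be done with the tokenizers/transformers packages already used by the conversion step; the paths below are illustrative (the real files live under the --output_folder the script was given):

    from tokenizers import Tokenizer
    from transformers import PreTrainedTokenizerFast

    # The fast-tokenizer JSON written by hf_tokenizer.save(...)
    tok = Tokenizer.from_file("tokenizer.json")
    print(tok.encode("Bonjour le monde").tokens)

    # Or wrap it in the usual transformers interface.
    hf = PreTrainedTokenizerFast(tokenizer_file="tokenizer.json")
    print(hf.tokenize("Bonjour le monde"))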