diff --git a/4B8-en-CD-FLM.sh b/4B8-en-CD-FLM.sh new file mode 100644 index 000000000..17079579d --- /dev/null +++ b/4B8-en-CD-FLM.sh @@ -0,0 +1,157 @@ +#!/bin/bash + +EXPERIMENT_NAME=4B8-en-CD-FLM +REPO_PATH=experiments/$EXPERIMENT_NAME +CHECKPOINT_PATH=$REPO_PATH/checkpoints +TENSORBOARD_PATH=$REPO_PATH/tensorboard +CODECARBON_PATH=$REPO_PATH/codecarbon +LOGS_PATH=$REPO_PATH/logs + +DATA_PATH=data/meg-gpt2-oscar-en-10k_text_document + +# XXX: edit me +GPUS_PER_NODE=8 +NNODES=1 +PP_SIZE=2 # NLAYERS must be a multiple of PP_SIZE here +TP_SIZE=1 # always fixed to the size of a single node +DP_SIZE=$((NNODES*GPUS_PER_NODE/(PP_SIZE*TP_SIZE))) # will get derived automatically by trainer + +MICRO_BATCH_SIZE=32 +GLOBAL_BATCH_SIZE=2048 +TRAIN_ITER=131_072 +SEQ_LEN=626 + + +NLAYERS=24 +NHIDDEN=4096 +NHEADS=64 +FFN_HIDDEN_SIZE=10240 +MAX_POSITION_EMBEDDING=1280 + +SAVE_INTERVAL=1500 + +OPTIMIZER_ARGS=" \ + --optimizer adam \ + --adam-beta1 0.9 \ + --adam-beta2 0.999 \ + --adam-eps 1e-8 \ + --lr 2e-4 \ + --min-lr 1e-5 \ + --lr-decay-style cosine \ + --clip-grad 1.0 \ + --weight-decay 1e-1 \ + " + +EXIT_OPTS=" \ + --exit-duration-in-mins 1190 \ + " + +GPT_ARGS=" \ + --num-layers $NLAYERS \ + --hidden-size $NHIDDEN \ + --num-attention-heads $NHEADS \ + --ffn-hidden-size $FFN_HIDDEN_SIZE \ + --max-position-embeddings $SEQ_LEN \ + --position-embedding-type alibi \ + --seq-length $SEQ_LEN \ + --micro-batch-size $MICRO_BATCH_SIZE \ + --global-batch-size $GLOBAL_BATCH_SIZE \ + --train-iters $TRAIN_ITER \ + --tokenizer-type PretrainedFromHF \ + --tokenizer-name-or-path bigscience/tokenizer \ + --loss-scale 12 \ + --clip-grad 1.0 \ + --fp16 \ + --checkpoint-activations \ + $OPTIMIZER_ARGS \ + $EXIT_OPTS \ + " + +OUTPUT_ARGS=" \ + --log-interval 1 \ + --save-interval $SAVE_INTERVAL \ + --eval-interval $TRAIN_ITER \ + --eval-iters 1 \ + --tensorboard-dir $TENSORBOARD_PATH \ + --tensorboard-queue-size 5 \ + --log-timers-to-tensorboard \ + --log-batch-size-to-tensorboard \ + --log-validation-ppl-to-tensorboard \ + " + +ZERO_STAGE=1 + +config_json="./ds_config.json" + +# Deepspeed figures out GAS dynamically from dynamic GBS via set_train_batch_size() +cat < $config_json +{ + "train_micro_batch_size_per_gpu": $MICRO_BATCH_SIZE, + "train_batch_size": $GLOBAL_BATCH_SIZE, + "gradient_clipping": 1.0, + "zero_optimization": { + "stage": $ZERO_STAGE + }, + "fp16": { + "enabled": true, + "loss_scale": 0, + "loss_scale_window": 500, + "hysteresis": 2, + "min_loss_scale": 1, + "initial_scale_power": 12 + }, + "steps_per_print": 2000, + "wall_clock_breakdown": false +} +EOT + + +DEEPSPEED_ARGS=" \ + --deepspeed \ + --deepspeed_config ${config_json} \ + --zero-stage ${ZERO_STAGE} \ + --deepspeed-activation-checkpointing \ + " + +# export LAUNCHER="python -u -m torch.distributed.launch \ +# --nproc_per_node $GPUS_PER_NODE \ +# " +# # --nnodes $NNODES \ +# # --master_addr $MASTER_ADDR \ +# # --master_port $MASTER_PORT \ + +export CMD=" \ + `pwd`/pretrain_gpt.py \ + --tensor-model-parallel-size $TP_SIZE \ + --pipeline-model-parallel-size $PP_SIZE \ + $GPT_ARGS \ + $OUTPUT_ARGS \ + --save $CHECKPOINT_PATH \ + --load $CHECKPOINT_PATH \ + --data-path $DATA_PATH \ + --data-impl mmap \ + --split 949,50,1 \ + --distributed-backend nccl \ + $DEEPSPEED_ARGS \ + " + + +# # clear old checkpoint as it'd mismatch while we sort things out +# rm -rf $SAVE_CHECKPOINT_PATH + + +echo $CMD + +# We create the folder where the logs and codecarbon will be stored. 
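The CD-FLM script above only writes the micro and global batch sizes into `ds_config.json`; as its comment notes, DeepSpeed derives gradient accumulation via `set_train_batch_size()`. A back-of-the-envelope sketch (not part of the patch, assuming the usual GBS = MBS x GAS x DP relation) of the arithmetic this config relies on:

```python
# Batch-size arithmetic for 4B8-en-CD-FLM: 8 GPUs split as PP=2, TP=1 leaves 4
# data-parallel replicas, so DeepSpeed has to run 16 accumulation steps per optimizer step.
GPUS_PER_NODE, NNODES, PP_SIZE, TP_SIZE = 8, 1, 2, 1
MICRO_BATCH_SIZE, GLOBAL_BATCH_SIZE = 32, 2048

dp_size = (NNODES * GPUS_PER_NODE) // (PP_SIZE * TP_SIZE)     # 4 data-parallel replicas
gas = GLOBAL_BATCH_SIZE // (MICRO_BATCH_SIZE * dp_size)       # 16 accumulation steps
assert MICRO_BATCH_SIZE * dp_size * gas == GLOBAL_BATCH_SIZE
print(dp_size, gas)                                           # 4 16
```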
+mkdir -p $REPO_PATH +mkdir -p $LOGS_PATH +# to debug - add echo (it exits and prints what it would have launched) + +# python -u -m torch.distributed.launch \ +# --nproc_per_node $GPUS_PER_NODE \ +# $CMD + +deepspeed --num_gpus $GPUS_PER_NODE \ + $CMD + +# srun '$LAUNCHER --node_rank $SLURM_PROCID $CMD' 2>&1 | tee -a $LOGS_PATH/main_log.txt \ No newline at end of file diff --git a/4B8-en-ND-MLM.sh b/4B8-en-ND-MLM.sh new file mode 100644 index 000000000..75fc3e89d --- /dev/null +++ b/4B8-en-ND-MLM.sh @@ -0,0 +1,156 @@ +#!/bin/bash + +EXPERIMENT_NAME=4B8-en-ND-MLM +REPO_PATH=experiments/$EXPERIMENT_NAME +CHECKPOINT_PATH=$REPO_PATH/checkpoints +TENSORBOARD_PATH=$REPO_PATH/tensorboard +CODECARBON_PATH=$REPO_PATH/codecarbon +LOGS_PATH=$REPO_PATH/logs + +DATA_PATH=data/meg-gpt2-oscar-en-10k_text_document +TOKENIZER_PATH=bigscience-tokenizer-padded + +# XXX: edit me +GPUS_PER_NODE=8 +NNODES=1 +PP_SIZE=2 # NLAYERS must be a multiple of PP_SIZE here +TP_SIZE=1 # always fixed to the size of a single node +DP_SIZE=$((NNODES*GPUS_PER_NODE/(PP_SIZE*TP_SIZE))) # will get derived automatically by trainer + +MICRO_BATCH_SIZE=1 +GLOBAL_BATCH_SIZE=512 +TRAIN_ITER=48_562 +INPUT_LEN=1675 +TARGET_LEN=373 +SEQ_LEN=$((INPUT_LEN+TARGET_LEN)) + +NLAYERS=24 +NHIDDEN=4096 +NHEADS=64 +FFN_HIDDEN_SIZE=10240 + + +SAVE_INTERVAL=1500 + +OPTIMIZER_ARGS=" \ + --optimizer adam \ + --adam-beta1 0.9 \ + --adam-beta2 0.999 \ + --adam-eps 1e-8 \ + --lr 2e-4 \ + --min-lr 1e-5 \ + --lr-decay-style cosine \ + --clip-grad 1.0 \ + --weight-decay 1e-1 \ + " + +EXIT_OPTS=" \ + --exit-duration-in-mins 1190 \ + " + +GPT_ARGS=" \ + --num-layers $NLAYERS \ + --hidden-size $NHIDDEN \ + --num-attention-heads $NHEADS \ + --ffn-hidden-size $FFN_HIDDEN_SIZE \ + --max-position-embeddings $SEQ_LEN \ + --position-embedding-type alibi \ + --seq-length $SEQ_LEN \ + --input-length $INPUT_LEN \ + --micro-batch-size $MICRO_BATCH_SIZE \ + --global-batch-size $GLOBAL_BATCH_SIZE \ + --train-iters $TRAIN_ITER \ + --tokenizer-type PretrainedFromHF \ + --tokenizer-name-or-path $TOKENIZER_PATH \ + --loss-scale 12 \ + --clip-grad 1.0 \ + --fp16 \ + --checkpoint-activations \ + $OPTIMIZER_ARGS \ + $EXIT_OPTS \ + " + +OUTPUT_ARGS=" \ + --log-interval 1 \ + --save-interval $SAVE_INTERVAL \ + --eval-interval $TRAIN_ITER \ + --eval-iters 1 \ + --tensorboard-dir $TENSORBOARD_PATH \ + --tensorboard-queue-size 5 \ + --log-timers-to-tensorboard \ + --log-batch-size-to-tensorboard \ + --log-validation-ppl-to-tensorboard \ + " + +ZERO_STAGE=1 + +config_json="./ds_config.json" + +# Deepspeed figures out GAS dynamically from dynamic GBS via set_train_batch_size() +cat < $config_json +{ + "train_micro_batch_size_per_gpu": $MICRO_BATCH_SIZE, + "train_batch_size": $GLOBAL_BATCH_SIZE, + "gradient_clipping": 1.0, + "zero_optimization": { + "stage": $ZERO_STAGE + }, + "fp16": { + "enabled": true, + "loss_scale": 0, + "loss_scale_window": 500, + "hysteresis": 2, + "min_loss_scale": 1, + "initial_scale_power": 12 + }, + "steps_per_print": 2000, + "wall_clock_breakdown": false +} +EOT + + +DEEPSPEED_ARGS=" \ + --deepspeed \ + --deepspeed_config ${config_json} \ + --zero-stage ${ZERO_STAGE} \ + --deepspeed-activation-checkpointing \ + " + +# export LAUNCHER="python -u -m torch.distributed.launch \ +# --nproc_per_node $GPUS_PER_NODE \ +# " +# # --nnodes $NNODES \ +# # --master_addr $MASTER_ADDR \ +# # --master_port $MASTER_PORT \ + +export CMD=" \ + `pwd`/train_ND_MLM_gpt.py \ + --tensor-model-parallel-size $TP_SIZE \ + --pipeline-model-parallel-size $PP_SIZE \ + $GPT_ARGS \ + 
$OUTPUT_ARGS \ + --save $CHECKPOINT_PATH \ + --load $CHECKPOINT_PATH \ + --data-path $DATA_PATH \ + --data-impl mmap \ + --split 949,50,1 \ + --distributed-backend nccl \ + $DEEPSPEED_ARGS \ + " + + +# # clear old checkpoint as it'd mismatch while we sort things out +# rm -rf $SAVE_CHECKPOINT_PATH + + +echo $CMD + +# We create the folder where the logs and codecarbon will be stored. +mkdir -p $REPO_PATH +mkdir -p $LOGS_PATH +# to debug - add echo (it exits and prints what it would have launched) + +deepspeed --num_gpus $GPUS_PER_NODE \ + $CMD + +# srun '$LAUNCHER --node_rank $SLURM_PROCID $CMD' 2>&1 | tee -a $LOGS_PATH/main_log.txt \ No newline at end of file diff --git a/4B8-en-ND-MTF.sh b/4B8-en-ND-MTF.sh new file mode 100644 index 000000000..209732ad3 --- /dev/null +++ b/4B8-en-ND-MTF.sh @@ -0,0 +1,155 @@ +#!/bin/bash + +EXPERIMENT_NAME=4B8-en-ND-MTF +REPO_PATH=experiments/$EXPERIMENT_NAME +CHECKPOINT_PATH=$REPO_PATH/checkpoints +TENSORBOARD_PATH=$REPO_PATH/tensorboard +CODECARBON_PATH=$REPO_PATH/codecarbon +LOGS_PATH=$REPO_PATH/logs + +DATA_PATH=data/mc4-id_text_document + +# XXX: edit me +GPUS_PER_NODE=8 +NNODES=1 +PP_SIZE=2 # NLAYERS must be a multiple of PP_SIZE here +TP_SIZE=1 # always fixed to the size of a single node +DP_SIZE=$((NNODES*GPUS_PER_NODE/(PP_SIZE*TP_SIZE))) # will get derived automatically by trainer + +MICRO_BATCH_SIZE=1 +GLOBAL_BATCH_SIZE=1024 +TRAIN_ITER=10_000 +INPUT_LEN=1024 +TARGET_LEN=256 + +NLAYERS=24 +NHIDDEN=4096 +NHEADS=64 +FFN_HIDDEN_SIZE=10240 +MAX_POSITION_EMBEDDING=1280 + +SAVE_INTERVAL=1500 + +OPTIMIZER_ARGS=" \ + --optimizer adam \ + --adam-beta1 0.9 \ + --adam-beta2 0.999 \ + --adam-eps 1e-8 \ + --lr 2e-4 \ + --min-lr 1e-5 \ + --lr-decay-style cosine \ + --clip-grad 1.0 \ + --weight-decay 1e-1 \ + " + +EXIT_OPTS=" \ + --exit-duration-in-mins 1190 \ + " + +GPT_ARGS=" \ + --num-layers $NLAYERS \ + --hidden-size $NHIDDEN \ + --num-attention-heads $NHEADS \ + --ffn-hidden-size $FFN_HIDDEN_SIZE \ + --max-position-embeddings $SEQ_LEN \ + --position-embedding-type alibi \ + --encoder-seq-length $INPUT_LEN \ + --decoder-seq-length $TARGET_LEN \ + --micro-batch-size $MICRO_BATCH_SIZE \ + --global-batch-size $GLOBAL_BATCH_SIZE \ + --train-iters $TRAIN_ITER \ + --tokenizer-type PretrainedFromHF \ + --tokenizer-name-or-path bigscience/tokenizer \ + --loss-scale 12 \ + --clip-grad 1.0 \ + --fp16 \ + --checkpoint-activations \ + $OPTIMIZER_ARGS \ + $EXIT_OPTS \ + " + +OUTPUT_ARGS=" \ + --log-interval 200 \ + --save-interval $SAVE_INTERVAL \ + --eval-interval $TRAIN_ITER \ + --eval-iters 1 \ + --tensorboard-dir $TENSORBOARD_PATH \ + --tensorboard-queue-size 5 \ + --log-timers-to-tensorboard \ + --log-batch-size-to-tensorboard \ + --log-validation-ppl-to-tensorboard \ + " + +ZERO_STAGE=1 + +config_json="./ds_config.json" + +# Deepspeed figures out GAS dynamically from dynamic GBS via set_train_batch_size() +cat < $config_json +{ + "train_micro_batch_size_per_gpu": $MICRO_BATCH_SIZE, + "train_batch_size": $GLOBAL_BATCH_SIZE, + "gradient_clipping": 1.0, + "zero_optimization": { + "stage": $ZERO_STAGE + }, + "fp16": { + "enabled": true, + "loss_scale": 0, + "loss_scale_window": 500, + "hysteresis": 2, + "min_loss_scale": 1, + "initial_scale_power": 12 + }, + "steps_per_print": 2000, + "wall_clock_breakdown": false +} +EOT + + +DEEPSPEED_ARGS=" \ + --deepspeed \ + --deepspeed_config ${config_json} \ + --zero-stage ${ZERO_STAGE} \ + --deepspeed-activation-checkpointing \ + " + +# export LAUNCHER="python -u -m torch.distributed.launch \ +# --nproc_per_node $GPUS_PER_NODE 
\ +# " +# # --nnodes $NNODES \ +# # --master_addr $MASTER_ADDR \ +# # --master_port $MASTER_PORT \ + +export CMD=" \ + `pwd`/train_ND_MTF_gpt.py \ + --tensor-model-parallel-size $TP_SIZE \ + --pipeline-model-parallel-size $PP_SIZE \ + $GPT_ARGS \ + $OUTPUT_ARGS \ + --save $CHECKPOINT_PATH \ + --load $CHECKPOINT_PATH \ + --data-path $DATA_PATH \ + --data-impl mmap \ + --split 949,50,1 \ + --distributed-backend nccl \ + $DEEPSPEED_ARGS \ + " + + +# # clear old checkpoint as it'd mismatch while we sort things out +# rm -rf $SAVE_CHECKPOINT_PATH + + +echo $CMD + +# We create the folder where the logs and codecarbon will be stored. +mkdir -p $REPO_PATH +mkdir -p $LOGS_PATH +# to debug - add echo (it exits and prints what it would have launched) + +python -u -m torch.distributed.launch \ + --nproc_per_node $GPUS_PER_NODE \ + $CMD + +# srun '$LAUNCHER --node_rank $SLURM_PROCID $CMD' 2>&1 | tee -a $LOGS_PATH/main_log.txt \ No newline at end of file diff --git a/megatron/arguments.py b/megatron/arguments.py index e1a973b05..4c8a17f04 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -878,6 +878,8 @@ def __call__(self, parser, args, values, option_string=None): 'They are used for span masking in the T5 model') group.add_argument('--seq-length', type=int, default=None, help='Maximum sequence length to process.') + group.add_argument('--input-length', type=int, default=None, + help='Maximum input length to process for MLM adaptation.') group.add_argument('--encoder-seq-length', type=int, default=None, help='Maximum encoder sequence length to process.' 'This should be exclusive of --seq-length') diff --git a/megatron/data/non_causal_mlm_dataset.py b/megatron/data/non_causal_mlm_dataset.py new file mode 100644 index 000000000..aa4a45a9f --- /dev/null +++ b/megatron/data/non_causal_mlm_dataset.py @@ -0,0 +1,499 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GPT Non-Causal Mask Language Model Finetune Style dataset.""" + +import os +import time +import random +import collections + +import numpy as np +import torch + +from megatron import mpu, print_rank_0, get_tokenizer +from megatron.data.blendable_dataset import BlendableDataset +from megatron.data.dataset_utils import get_datasets_weights_and_num_samples, create_masked_lm_predictions +from megatron.data.dataset_utils import get_train_valid_test_split_, get_split_by_range_, get_indexed_dataset_ +from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset + + +def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, + train_valid_test_num_samples, + max_seq_length, + masked_lm_prob, seed, + skip_warmup + ): + if len(data_prefix) == 1: + return _build_train_valid_test_datasets(data_prefix[0], + data_impl, splits_string, + train_valid_test_num_samples, + max_seq_length, masked_lm_prob, + seed, skip_warmup + ) + # Blending dataset. + # Parse the values. 
+ output = get_datasets_weights_and_num_samples(data_prefix, + train_valid_test_num_samples) + prefixes, weights, datasets_train_valid_test_num_samples = output + + # Build individual datasets. + train_datasets = [] + valid_datasets = [] + test_datasets = [] + for i in range(len(prefixes)): + train_ds, valid_ds, test_ds = _build_train_valid_test_datasets( + prefixes[i], data_impl, splits_string, + datasets_train_valid_test_num_samples[i], + max_seq_length, masked_lm_prob, + seed, skip_warmup) + if train_ds: + train_datasets.append(train_ds) + if valid_ds: + valid_datasets.append(valid_ds) + if test_ds: + test_datasets.append(test_ds) + + # Blend. + blending_train_dataset = None + if train_datasets: + blending_train_dataset = BlendableDataset(train_datasets, weights) + blending_valid_dataset = None + if valid_datasets: + blending_valid_dataset = BlendableDataset(valid_datasets, weights) + blending_test_dataset = None + if test_datasets: + blending_test_dataset = BlendableDataset(test_datasets, weights) + + return (blending_train_dataset, blending_valid_dataset, + blending_test_dataset) + + +def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, + train_valid_test_num_samples, + max_seq_length, masked_lm_prob, seed, + skip_warmup): + """Build train, valid, and test datasets.""" + + + # Indexed dataset. + indexed_dataset = get_indexed_dataset_(data_prefix, + data_impl, + skip_warmup) + + total_num_of_documents = indexed_dataset.sizes.shape[0] - 1 + splits = get_train_valid_test_split_(splits_string, total_num_of_documents) + # Print stats about the splits. + print_rank_0(' > dataset split:') + + def print_split_stats(name, index): + print_rank_0(' {}:'.format(name)) + print_rank_0(' document indices in [{}, {}) total of {} ' + 'documents'.format(splits[index], splits[index + 1], + splits[index + 1] - splits[index])) + start_index = indexed_dataset.doc_idx[splits[index]] + end_index = indexed_dataset.doc_idx[splits[index + 1]] + print_rank_0(' sentence indices in [{}, {}) total of {} ' + 'sentences'.format(start_index, end_index, + end_index - start_index)) + print_split_stats('train', 0) + print_split_stats('validation', 1) + print_split_stats('test', 2) + + def build_dataset(index, name): + dataset = None + if splits[index + 1] > splits[index]: + # Get the pointer to the original doc-idx so we can set it later. + doc_idx_ptr = indexed_dataset.get_doc_idx() + # Slice the doc-idx + start_index = splits[index] + # Add +1 so we can index into the dataset to get the upper bound. + end_index = splits[index + 1] + 1 + # New doc_idx view. + indexed_dataset.set_doc_idx(doc_idx_ptr[start_index:end_index]) + # Build the dataset accordingly. + kwargs = dict( + name=name, + data_prefix=data_prefix, + max_seq_length=max_seq_length, + seed=seed, + ) + dataset = NonCausalMLMDataset( + indexed_dataset=indexed_dataset, + masked_lm_prob=masked_lm_prob, + **kwargs + ) + indexed_dataset.set_doc_idx(doc_idx_ptr) + # Checks. + # assert indexed_dataset.doc_idx[0] == 0 + # assert indexed_dataset.doc_idx.shape[0] == \ + # (total_num_of_documents + 1) + return dataset + + train_dataset = build_dataset(0, 'train') + valid_dataset = build_dataset(1, 'valid') + test_dataset = build_dataset(2, 'test') + + return (train_dataset, valid_dataset, test_dataset) + + +class NonCausalMLMDataset(torch.utils.data.Dataset): + + def __init__(self, name, indexed_dataset, data_prefix, + masked_lm_prob, + max_seq_length, + seed, + max_ngrams = 3): + + # Params to store. 
+ self.name = name + self.seed = seed + self.masked_lm_prob = masked_lm_prob + self.max_seq_length = max_seq_length + + # Dataset. + self.indexed_dataset = indexed_dataset + + self.max_ngrams = max_ngrams + # T5-like span masked language modeling will fuse consecutively masked tokens to a single sentinel token. + # To ensure that the input length is `max_seq_length`, we need to increase the maximum length + # according to `masked_lm_prob` and `max_ngrams`. We can also define the label length accordingly. + expanded_inputs_length, targets_length = compute_input_and_target_lengths( + self.max_seq_length, + self.masked_lm_prob, + self.max_ngrams + ) + self.expanded_inputs_length = expanded_inputs_length + self.targets_length = targets_length + + # Build the samples mapping. + self.samples_mapping = get_samples_mapping( + self.indexed_dataset, + data_prefix, + self.name, + max_len=expanded_inputs_length + ) + + # Vocab stuff. + tokenizer = get_tokenizer() + self.vocab_id_list = list(tokenizer.inv_vocab.keys()) + self.vocab_id_to_token_dict = tokenizer.inv_vocab + self.cls_id = tokenizer.cls + self.sep_id = tokenizer.sep + self.mask_id = tokenizer.mask + self.pad_id = tokenizer.pad + self.bos_id = tokenizer.bos_token_id + self.eos_id = tokenizer.eos_token_id + self.sentinel_tokens = tokenizer.additional_special_tokens_ids + assert len(self.sentinel_tokens) > 0, "Provide the argument --vocab-extra-ids 100 to the script" + + def __len__(self): + return len(self.samples_mapping) + + def __getitem__(self, idx): + + indices = self.samples_mapping[idx] + sample = [] + for doc_idx, start_index, end_index in indices: + sample.append(self.indexed_dataset.get(doc_idx)[start_index:end_index]) + + return build_training_sample( + sample, self.expanded_inputs_length, self.vocab_id_list, + self.cls_id, self.sep_id, self.mask_id, self.pad_id, self.bos_id, self.eos_id, + self.sentinel_tokens + ) + + +def build_training_sample( + sample, expanded_inputs_length, vocab_id_list, + cls_id, sep_id, mask_id, pad_id, bos_id=None, eos_id=None, sentinel_tokens=None + ): + """Build training sample. 
+ + Arguments: + TODO: Add description + """ + + # flatten sentences into one list + tokens = [token for sentence in sample for token in sentence] + + mask_indices = np.asarray([random_spans_noise_mask( + expanded_inputs_length, + noise_density=0.15, + mean_noise_span_length=3 + )]) + labels_mask = ~mask_indices + + input_ids_sentinel = create_sentinel_ids(mask_indices.astype(np.int8), vocab_len=len(vocab_id_list)) + labels_sentinel = create_sentinel_ids(labels_mask.astype(np.int8), vocab_len=len(vocab_id_list)) + + tokens = np.asarray([tokens]) + input_tokens_ids = filter_input_ids(tokens, input_ids_sentinel, eos_id)[0] + output_tokens_ids = filter_input_ids(tokens, labels_sentinel, eos_id)[0] + + text_tokens_ids = np.concatenate((input_tokens_ids, output_tokens_ids)) + + prefix_len = len(input_tokens_ids) + + return { + 'text': text_tokens_ids, + 'prefix_len': prefix_len + } + + +def get_samples_mapping(indexed_dataset, data_prefix, name, max_len): + + def breakdown(sample_len, idx_offset=None, idx_list=None, max_len=None): + + if idx_list is None: + idx_list = [] + + if idx_offset is None: + idx_offset = 0 + + if sample_len < max_len: + idx_list.append(idx_offset+sample_len) + else: + sample_len = sample_len - max_len + idx_list.append(idx_offset+max_len) + idx_offset += max_len + + breakdown(sample_len, idx_offset=idx_offset, idx_list=idx_list) + + idx_list = [0]+idx_list + return list(zip(idx_list[:-1], idx_list[1:])) + + + # Filename of the index mapping + indexmap_filename = data_prefix + indexmap_filename += '_{}_indexmap'.format(name) + indexmap_filename += '.npy' + + # Build the indexed mapping if not exist. + if torch.distributed.get_rank() == 0 and \ + not os.path.isfile(indexmap_filename): + print(' > WARNING: could not find index map file {}, building ' + 'the indices on rank 0 ...'.format(indexmap_filename)) + + # Make sure the types match the helpers input types. 
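`build_training_sample` above chains `random_spans_noise_mask` with the `create_sentinel_ids` / `filter_input_ids` helpers defined just below. A self-contained toy sketch of that transformation (hypothetical token ids and a hand-picked mask in place of the sampled one; the two small helpers are re-implemented inline so the expected shapes are visible):

```python
# Toy span-corruption example: masked spans collapse to one sentinel in the input,
# and the target lists each sentinel followed by the tokens it replaced, plus EOS.
import numpy as np

vocab_len = 1000                     # stand-in for len(vocab_id_list)
eos_id = 1                           # stand-in for tokenizer.eos_token_id
tokens = np.array([[10, 11, 12, 13, 14, 15, 16, 17, 18, 19]])
mask = np.array([[0, 0, 0, 1, 1, 0, 0, 0, 1, 0]], dtype=np.int8)   # two noise spans

def sentinel_ids(mask, vocab_len):
    # Same logic as create_sentinel_ids: each span start gets a decreasing sentinel id,
    # the rest of the span gets -1 (to be dropped), unmasked positions stay 0.
    starts = mask - np.roll(mask, 1, axis=-1) * mask
    starts[:, 0] = mask[:, 0]
    ids = np.where(starts != 0, np.cumsum(starts, axis=-1), starts)
    ids = np.where(ids != 0, vocab_len - ids, 0)
    return ids - (mask - starts)

def filter_ids(tokens, sent, eos_id):
    # Same logic as filter_input_ids: substitute sentinels, drop the -1 positions, append EOS.
    full = np.where(sent != 0, sent, tokens)
    kept = full[full >= 0].reshape((tokens.shape[0], -1))
    return np.concatenate([kept, np.full((tokens.shape[0], 1), eos_id)], axis=-1)

inputs  = filter_ids(tokens, sentinel_ids(mask, vocab_len), eos_id)
targets = filter_ids(tokens, sentinel_ids((1 - mask).astype(np.int8), vocab_len), eos_id)
print(inputs)    # [[ 10  11  12 999  15  16  17 998  19   1]]
print(targets)   # [[999  13  14 998  18 997   1]]
```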
+ assert indexed_dataset.doc_idx.dtype == np.int64 + assert indexed_dataset.sizes.dtype == np.int32 + + # Build samples mapping + verbose = torch.distributed.get_rank() == 0 + start_time = time.time() + print_rank_0(' > building sapmles index mapping for {} ...'.format( + name)) + samples_mapping = [] + sample_indices = [] + doc_idx = 0 + current_len = 0 + _idx = 0 + for doc_idx, sample_len in zip(indexed_dataset.doc_idx, indexed_dataset.sizes): + _idx = 0 + + if current_len + sample_len > max_len: + end_idx = max_len - current_len + sample_indices.append([doc_idx, 0, end_idx]) + samples_mapping.append(sample_indices) + sample_indices = [] + current_len = 0 + sample_len -= end_idx + _idx = end_idx + + break_len = current_len + sample_len + + indices = breakdown(sample_len, max_len=max_len) + for _start_idx, _end_idx in indices: + _len = _end_idx - _start_idx + if _len == max_len: + samples_mapping.append([[doc_idx, _start_idx+_idx, _end_idx+_idx]]) + else: + sample_indices.append([doc_idx, _start_idx+_idx, _end_idx+_idx]) + current_len += _len + + print_rank_0(' > done building sapmles index maping') + np.save(indexmap_filename, samples_mapping, allow_pickle=True) + print_rank_0(' > saved the index mapping in {}'.format( + indexmap_filename)) + # Make sure all the ranks have built the mapping + print_rank_0(' > elasped time to build and save samples mapping ' + '(seconds): {:4f}'.format( + time.time() - start_time)) + # This should be a barrier but nccl barrier assumes + # device_index=rank which is not the case for model + # parallel case + counts = torch.cuda.LongTensor([1]) + torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group()) + torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group()) + assert counts[0].item() == ( + torch.distributed.get_world_size() // + torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group())) + + # Load indexed dataset. + print_rank_0(' > loading indexed mapping from {}'.format( + indexmap_filename)) + start_time = time.time() + samples_mapping = np.load(indexmap_filename, allow_pickle=True) + print_rank_0(' loaded indexed file in {:3.3f} seconds'.format( + time.time() - start_time)) + print_rank_0(' total number of samples: {}'.format( + len(samples_mapping))) + + return samples_mapping + + +def create_sentinel_ids(mask_indices, vocab_len): + """ + Sentinel ids creation given the indices that should be masked. + The start indices of each mask are replaced by the sentinel ids in increasing + order. Consecutive mask indices to be deleted are replaced with `-1`. + """ + start_indices = mask_indices - np.roll(mask_indices, 1, axis=-1) * mask_indices + start_indices[:, 0] = mask_indices[:, 0] + + sentinel_ids = np.where(start_indices != 0, np.cumsum(start_indices, axis=-1), start_indices) + sentinel_ids = np.where(sentinel_ids != 0, (vocab_len - sentinel_ids), 0) + sentinel_ids -= mask_indices - start_indices + + return sentinel_ids + + +def filter_input_ids(input_ids, sentinel_ids, eos_id): + """ + Puts sentinel mask on `input_ids` and fuse consecutive mask tokens into a single mask token by deleting. + This will reduce the sequence length from `expanded_inputs_length` to `input_length`. 
+ """ + batch_size = input_ids.shape[0] + + input_ids_full = np.where(sentinel_ids != 0, sentinel_ids, input_ids) + # input_ids tokens and sentinel tokens are >= 0, tokens < 0 are + # masked tokens coming after sentinel tokens and should be removed + input_ids = input_ids_full[input_ids_full >= 0].reshape((batch_size, -1)) + input_ids = np.concatenate( + [input_ids, np.full((batch_size, 1), eos_id, dtype=np.int32)], axis=-1 + ) + return input_ids + + +def compute_input_and_target_lengths(inputs_length, noise_density, mean_noise_span_length): + """This function is copy of `random_spans_helper `__ . + Training parameters to avoid padding with random_spans_noise_mask. + When training a model with random_spans_noise_mask, we would like to set the other + training hyperparmeters in a way that avoids padding. + This function helps us compute these hyperparameters. + We assume that each noise span in the input is replaced by extra_tokens_per_span_inputs sentinel tokens, + and each non-noise span in the targets is replaced by extra_tokens_per_span_targets sentinel tokens. + This function tells us the required number of tokens in the raw example (for split_tokens()) + as well as the length of the encoded targets. Note that this function assumes + the inputs and targets will have EOS appended and includes that in the reported length. + Args: + inputs_length: an integer - desired length of the tokenized inputs sequence + noise_density: a float + mean_noise_span_length: a float + Returns: + tokens_length: length of original text in tokens + targets_length: an integer - length in tokens of encoded targets sequence + """ + + def _tokens_length_to_inputs_length_targets_length(tokens_length): + num_noise_tokens = int(round(tokens_length * noise_density)) + num_nonnoise_tokens = tokens_length - num_noise_tokens + num_noise_spans = int(round(num_noise_tokens / mean_noise_span_length)) + # inputs contain all nonnoise tokens, sentinels for all noise spans + # and one EOS token. + _input_length = num_nonnoise_tokens + num_noise_spans + 1 + _output_length = num_noise_tokens + num_noise_spans + 1 + return _input_length, _output_length + + tokens_length = inputs_length + + while _tokens_length_to_inputs_length_targets_length(tokens_length + 1)[0] <= inputs_length: + tokens_length += 1 + + inputs_length, targets_length = _tokens_length_to_inputs_length_targets_length(tokens_length) + + # minor hack to get the targets length to be equal to inputs length + # which is more likely to have been set to a nice round number. + if noise_density == 0.5 and targets_length > inputs_length: + tokens_length -= 1 + targets_length -= 1 + return tokens_length, targets_length + + +def random_spans_noise_mask( + length, + noise_density=0.15, + mean_noise_span_length=3 + ): + + """This function is copy of `random_spans_helper `__ . + Noise mask consisting of random spans of noise tokens. + The number of noise tokens and the number of noise spans and non-noise spans + are determined deterministically as follows: + num_noise_tokens = round(length * noise_density) + num_nonnoise_spans = num_noise_spans = round(num_noise_tokens / mean_noise_span_length) + Spans alternate between non-noise and noise, beginning with non-noise. + Subject to the above restrictions, all masks are equally likely. 
+ Args: + length: an int32 scalar (length of the incoming token sequence) + noise_density: a float - approximate density of output mask + mean_noise_span_length: a number + Returns: + a boolean tensor with shape [length] + """ + + orig_length = length + + num_noise_tokens = int(np.round(length * noise_density)) + # avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens. + num_noise_tokens = min(max(num_noise_tokens, 1), length - 1) + num_noise_spans = int(np.round(num_noise_tokens / mean_noise_span_length)) + + # avoid degeneracy by ensuring positive number of noise spans + num_noise_spans = max(num_noise_spans, 1) + num_nonnoise_tokens = length - num_noise_tokens + + # pick the lengths of the noise spans and the non-noise spans + def _random_segmentation(num_items, num_segments): + """Partition a sequence of items randomly into non-empty segments. + Args: + num_items: an integer scalar > 0 + num_segments: an integer scalar in [1, num_items] + Returns: + a Tensor with shape [num_segments] containing positive integers that add + up to num_items + """ + mask_indices = np.arange(num_items - 1) < (num_segments - 1) + np.random.shuffle(mask_indices) + first_in_segment = np.pad(mask_indices, [[1, 0]]) + segment_id = np.cumsum(first_in_segment) + # count length of sub segments assuming that list is sorted + _, segment_length = np.unique(segment_id, return_counts=True) + return segment_length + + noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans) + nonnoise_span_lengths = _random_segmentation(num_nonnoise_tokens, num_noise_spans) + + interleaved_span_lengths = np.reshape( + np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1), [num_noise_spans * 2] + ) + span_starts = np.cumsum(interleaved_span_lengths)[:-1] + span_start_indicator = np.zeros((length,), dtype=np.int8) + span_start_indicator[span_starts] = True + span_num = np.cumsum(span_start_indicator) + is_noise = np.equal(span_num % 2, 1) + + return is_noise[:orig_length] \ No newline at end of file diff --git a/megatron/data/non_causal_mtf_dataset.py b/megatron/data/non_causal_mtf_dataset.py new file mode 100644 index 000000000..6bce2c4ef --- /dev/null +++ b/megatron/data/non_causal_mtf_dataset.py @@ -0,0 +1,493 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
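For concreteness, a quick numeric sketch (not part of the patch, assuming the default `--mask-prob` of 0.15 and the `max_ngrams=3` default used by `NonCausalMLMDataset` above) of what `compute_input_and_target_lengths` does for the 4B8-en-ND-MLM settings: an `INPUT_LEN` of 1675 means each sample is built from 1860 raw tokens, the targets come out at 373 tokens, and input + target fills the 2048-token `SEQ_LEN` exactly.

```python
# Span-corruption length bookkeeping for INPUT_LEN=1675, noise_density=0.15,
# mean_noise_span_length=3 (the 4B8-en-ND-MLM configuration).
def lengths(tokens_length, noise_density=0.15, mean_noise_span_length=3):
    num_noise = int(round(tokens_length * noise_density))
    num_spans = int(round(num_noise / mean_noise_span_length))
    inputs = (tokens_length - num_noise) + num_spans + 1    # non-noise tokens + sentinels + EOS
    targets = num_noise + num_spans + 1                     # noise tokens + sentinels + EOS
    return inputs, targets

print(lengths(1860))     # (1675, 373) -> 1675 + 373 == 2048 == SEQ_LEN
```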
+ +"""GPT Non-Causal Multitask Finetune style dataset.""" + +import os +import time + +import numpy as np +import torch + +from megatron import mpu, print_rank_0, get_tokenizer +from megatron.data.blendable_dataset import BlendableDataset +from megatron.data.dataset_utils import get_datasets_weights_and_num_samples +from megatron.data.dataset_utils import get_train_valid_test_split_, get_split_by_range_ +from megatron.data.indexed_dataset import make_dataset as make_indexed_dataset + + +def build_train_valid_test_datasets(data_prefix, data_impl, splits_string, + train_valid_test_num_samples, + seq_length, seed, skip_warmup): + """Build train, valid, and test datasets.""" + + # Single dataset. + if len(data_prefix) == 1: + all_train_datasets, all_valid_datasets, all_test_datasets = _build_train_valid_test_datasets(data_prefix[0], + data_impl, splits_string, + train_valid_test_num_samples, + seq_length, seed, skip_warmup) + # Blending dataset. + else: + + output = get_datasets_weights_and_num_samples(data_prefix, + train_valid_test_num_samples) + prefixes, weights, datasets_train_valid_test_num_samples = output + + # Build individual datasets. + train_datasets = [] + valid_datasets = [] + test_datasets = [] + for i in range(len(prefixes)): + train_ds, valid_ds, test_ds = _build_train_valid_test_datasets( + prefixes[i], data_impl, splits_string, + datasets_train_valid_test_num_samples[i], + seq_length, seed, skip_warmup) + if train_ds: + train_datasets.append(train_ds) + if valid_ds: + valid_datasets.append(valid_ds) + if test_ds: + test_datasets.append(test_ds) + + all_train_datasets = BlendableDataset(train_datasets, weights) \ + if train_datasets else None + all_valid_datasets = BlendableDataset(valid_datasets, weights) \ + if valid_datasets else None + all_test_datasets = BlendableDataset(test_datasets, weights) \ + if test_datasets else None + + return all_train_datasets, all_valid_datasets, all_test_datasets + + +def build_dataset_group(dataset_group_name, paths, weights, splits, data_impl, + train_valid_test_num_samples, + seq_length, seed, skip_warmup, train_valid_test): + ''' + Build a single dataset group corresponding to Option 2 of data loading see arguments.py + a dataset group is passed on the following form + GIVEN_NAME WEIGHT1 START:END PATH1, WEIGHT2 START:END PATH2, WEIGHT2 START:END PATH2 + or alternatively + GIVEN_NAME PATH1 # for a single dataset to be used fully + ''' + + assert train_valid_test in ["train","valid","test"] + + # Single dataset. + if len(paths) == 1: + dataset = _build_single_datasets(paths[0], + splits[0], + data_impl, + train_valid_test_num_samples, + seq_length, seed, skip_warmup, + dataset_group_name, train_valid_test) + return dataset + # Blending dataset. + else: + + data_prefix = [] + # data_prefix is on the shape: + # ["WEIGHT1", "PATH1", "WEIGHT2", "PATH2", "WEIGHT3", "PATH3"] + for w,p in zip(weights, paths): + data_prefix += [w,p] + + output = get_datasets_weights_and_num_samples(data_prefix, + train_valid_test_num_samples) + prefixes, weights, datasets_train_valid_test_num_samples = output + + # Build individual datasets. 
+ datasets = [] + for i in range(len(prefixes)): + ds = _build_single_datasets(prefixes[i], + splits[i], + data_impl, + datasets_train_valid_test_num_samples[i], + seq_length, + seed, skip_warmup, + dataset_group_name, train_valid_test) + + datasets.append(ds) + all_datasets = BlendableDataset(datasets, weights) + + return all_datasets + +def _build_single_datasets(data_prefix, range_string, data_impl, train_valid_test_num_samples, + seq_length, seed, skip_warmup, dataset_group_name, train_valid_test): + """Build a single dataset""" + + assert train_valid_test in ["train","valid","test"] + index = ["train","valid","test"].index(train_valid_test) + + # Indexed dataset. + indexed_dataset = get_indexed_dataset_(data_prefix, + data_impl, + skip_warmup) + + total_num_of_documents = indexed_dataset.sizes.shape[0] + # this corresponds to option2 for data loading on the form + # WEIGHT1 START:END PATH1, WEIGHT2 START:END PATH2, WEIGHT3 START:END PATH3 + # splits here is an array of size 2 [start_index, end_index] + splits = get_split_by_range_(range_string=range_string, size=total_num_of_documents) + + # Print stats about the splits. + print_rank_0(' > dataset split:') + + print_rank_0(' {}:'.format(dataset_group_name)) + print_rank_0(' document indices in [{}, {}) total of {} ' + 'documents'.format(splits[0], splits[1], + splits[1] - splits[0])) + + def build_dataset(name): + dataset = None + if splits[1] > splits[0]: + documents = np.arange(start=splits[0], stop=splits[1], + step=1, dtype=np.int32) + dataset = NonCausalMTFDataset(name, data_prefix, + documents, indexed_dataset, + train_valid_test_num_samples[index], + seq_length, seed) + return dataset + + dataset = build_dataset(dataset_group_name) + + return dataset + + +def _build_train_valid_test_datasets(data_prefix, data_impl, splits_string, + train_valid_test_num_samples, + seq_length, seed, skip_warmup): + """Build train, valid, and test datasets.""" + + + # Indexed dataset. + indexed_dataset = get_indexed_dataset_(data_prefix, + data_impl, + skip_warmup) + + total_num_of_documents = indexed_dataset.sizes.shape[0] + # splits here is an array of size 4 [train_start_index, valid_start_index, test_start_index, test_end_index] + splits = get_train_valid_test_split_(splits_string, total_num_of_documents) + # Print stats about the splits. 
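As a side note on the `--split 949,50,1` argument used by all three launch scripts: a rough sketch (assuming `get_train_valid_test_split_` keeps upstream Megatron's behaviour of normalizing the comma-separated weights into cumulative document boundaries), with the roughly 10k-document OSCAR sample as an example, giving about 9,490 train / 500 valid / 10 test documents:

```python
# Illustrative split computation; the real get_train_valid_test_split_ lives in
# megatron/data/dataset_utils.py and may differ in rounding details.
def split_boundaries(splits_string, size):
    weights = [float(w) for w in splits_string.split(',')]
    total = sum(weights)
    bounds, acc = [0], 0.0
    for w in weights:
        acc += w / total
        bounds.append(int(round(acc * size)))
    bounds[-1] = size    # make sure the last boundary covers the full document range
    return bounds

print(split_boundaries("949,50,1", 10_000))    # [0, 9490, 9990, 10000]
```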
+ print_rank_0(' > dataset split:') + + def print_split_stats(name, index): + print_rank_0(' {}:'.format(name)) + print_rank_0(' document indices in [{}, {}) total of {} ' + 'documents'.format(splits[index], splits[index + 1], + splits[index + 1] - splits[index])) + print_split_stats('train', 0) + print_split_stats('validation', 1) + print_split_stats('test', 2) + + def build_dataset(index, name): + dataset = None + if splits[index + 1] > splits[index]: + documents = np.arange(start=splits[index], stop=splits[index + 1], + step=1, dtype=np.int32) + dataset = NonCausalMTFDataset(name, data_prefix, + documents, indexed_dataset, + train_valid_test_num_samples[index], + seq_length, seed) + return dataset + + train_dataset = build_dataset(0, 'train') + valid_dataset = build_dataset(1, 'valid') + test_dataset = build_dataset(2, 'test') + + return (train_dataset, valid_dataset, test_dataset) + + +def get_indexed_dataset_(path, data_impl, skip_warmup): + """Build indexed dataset.""" + print_rank_0(' > building dataset index ...') + start_time = time.time() + indexed_dataset = make_indexed_dataset(path, + data_impl, + skip_warmup) + print_rank_0(' > finished creating indexed dataset in {:4f} ' + 'seconds'.format(time.time() - start_time)) + print_rank_0(' number of documents: {}'.format( + indexed_dataset.sizes.shape[0])) + + return indexed_dataset + + +class NonCausalMTFDataset(torch.utils.data.Dataset): + + def __init__( + self, + name, + data_prefix, + documents, + indexed_dataset, + num_samples, + seq_length, + seed + ): + + # Params to store. + self.name = name + self.seq_length = seq_length + + # Dataset. + self.indexed_dataset = indexed_dataset + + # vocab + self.tokenizer = get_tokenizer() + + # Checks + assert np.min(documents) >= 0 + assert np.max(documents) < indexed_dataset.sizes.shape[0] + + # Build index mappings. + self.doc_idx, self.sample_idx, self.shuffle_idx = _build_index_mappings( + self.name, data_prefix, documents, self.indexed_dataset.sizes, + num_samples, seq_length, seed) + + def __len__(self): + # -1 is due to data structure used to retieve the index: + # sample i --> [sample_idx[i], sample_idx[i+1]) + return self.sample_idx.shape[0] - 1 + + def __getitem__(self, idx): + # Get the shuffled index. + idx = self.shuffle_idx[idx] + doc_idx = self.sample_idx[idx][0] + + sample = self.indexed_dataset.get( + self.doc_idx[doc_idx] + ) + + eod_idx = np.where(sample == self.tokenizer.eod)[0] + if len(eod_idx) > 0: + prefix_len = eod_idx[0] + else: + prefix_len = 0 + + sample = pad_and_convert_to_numpy( + sample, + self.tokenizer.pad, + self.seq_length + ) + + return { + 'text': np.array(sample, dtype=np.int64), + 'prefix_len': prefix_len + } + + +def _build_index_mappings(name, data_prefix, documents, sizes, + num_samples, seq_length, seed, cutoff_last_epoch=0.95): + """Build doc-idx, sample-idx, and shuffle-idx. + doc-idx: is an array (ordered) of documents to be used in training. + sample-idx: is the start document index and document offset for each + training sample. + shuffle-idx: maps the sample index into a random index into sample-idx. + """ + # Number of tokens in each epoch and number of required epochs. + tokens_per_epoch = _num_tokens(documents, sizes) + num_epochs = _num_epochs(tokens_per_epoch, seq_length, num_samples) + # rng state + np_rng = np.random.RandomState(seed=seed) + + # Filename of the index mappings. 
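`NonCausalMTFDataset.__getitem__` above keys the prefix boundary off the first EOD token in the packed sample; a toy sketch (hypothetical ids) of that rule, where the tokens before the first EOD presumably form the bidirectionally-attended prefix and everything after it is the causal continuation:

```python
# prefix_len rule from NonCausalMTFDataset.__getitem__: position of the first EOD,
# or 0 when no EOD is present in the sample.
import numpy as np

eod = 2
sample = np.array([5, 6, 7, eod, 8, 9])
eod_idx = np.where(sample == eod)[0]
prefix_len = int(eod_idx[0]) if len(eod_idx) > 0 else 0
print(prefix_len)    # 3 -> the first three tokens are treated as the prefix
```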
+ _filename = data_prefix + _filename += '_{}_indexmap'.format(name) + _filename += '_{}ns'.format(num_samples) + _filename += '_{}sl'.format(seq_length) + _filename += '_{}s'.format(seed) + doc_idx_filename = _filename + '_doc_idx.npy' + sample_idx_filename = _filename + '_sample_idx.npy' + shuffle_idx_filename = _filename + '_shuffle_idx.npy' + + # Build the indexed mapping if not exist. + if torch.distributed.get_rank() == 0: + if (not os.path.isfile(doc_idx_filename)) or \ + (not os.path.isfile(sample_idx_filename)) or \ + (not os.path.isfile(shuffle_idx_filename)): + + print_rank_0(' > WARNING: could not find index map files, building ' + 'the indices on rank 0 ...') + + # For the last epoch, decide whether include the entire epoch + # in the global shuffle or not. + + # If we need only one epoch, then separating last epoch does + # not mean anything. + if num_epochs == 1: + separate_last_epoch = False + print(' > only one epoch required, setting ' + 'separate_last_epoch to False', flush=True) + + else: + # Get the number of samples for the last epoch + num_samples_from_epochs_minus_one = ( + (num_epochs - 1) * tokens_per_epoch - 1) // seq_length + last_epoch_num_samples = num_samples - \ + num_samples_from_epochs_minus_one + assert last_epoch_num_samples >= 0, \ + f'last epoch number of samples {last_epoch_num_samples} should be non-negative.' + num_samples_per_epoch = (tokens_per_epoch - 1) // seq_length + assert last_epoch_num_samples <= num_samples_per_epoch, \ + f'last epoch number of samples {last_epoch_num_samples} exceeded max value {num_samples_per_epoch}.' + # If we have less than cutoff_last_epoch * samples_per_epoch of the samples for the last epoch, + # seperate out the epoch and treat it differently. + separate_last_epoch = (last_epoch_num_samples < + int(cutoff_last_epoch * num_samples_per_epoch)) + if separate_last_epoch: + string = ' > last epoch number of samples ({}) is smaller '\ + 'than {}% of number of samples per epoch ({}), '\ + 'setting separate_last_epoch to True' + else: + string = ' > last epoch number of samples ({}) is larger '\ + 'than {}% of number of samples per epoch ({}), '\ + 'setting separate_last_epoch to False' + print(string.format(last_epoch_num_samples, cutoff_last_epoch * 100, + num_samples_per_epoch), flush=True) + + # doc-idx. + start_time = time.time() + doc_idx = _build_doc_idx(documents, num_epochs, np_rng, + separate_last_epoch) + np.save(doc_idx_filename, doc_idx, allow_pickle=True) + print_rank_0(' > elasped time to build and save doc-idx mapping ' + '(seconds): {:4f}'.format(time.time() - start_time)) + # sample-idx. + start_time = time.time() + # Use C++ implementation for speed. + # First compile and then import. + from megatron.data import helpers + assert doc_idx.dtype == np.int32 + assert sizes.dtype == np.int32 + sample_idx = helpers.build_sample_idx(sizes, doc_idx, seq_length, + num_epochs, tokens_per_epoch) + + np.save(sample_idx_filename, sample_idx, allow_pickle=True) + print_rank_0(' > elasped time to build and save sample-idx mapping ' + '(seconds): {:4f}'.format(time.time() - start_time)) + # shuffle-idx. 
+ start_time = time.time() + # -1 is due to data structure used to retieve the index: + # sample i --> [sample_idx[i], sample_idx[i+1]) + if separate_last_epoch: + num_samples_ = num_samples_from_epochs_minus_one + else: + num_samples_ = sample_idx.shape[0] - 1 + shuffle_idx = _build_shuffle_idx(num_samples_, + sample_idx.shape[0] - 1, np_rng) + np.save(shuffle_idx_filename, shuffle_idx, allow_pickle=True) + print_rank_0(' > elasped time to build and save shuffle-idx mapping' + ' (seconds): {:4f}'.format(time.time() - start_time)) + + # This should be a barrier but nccl barrier assumes + # device_index=rank which is not the case for model + # parallel case + counts = torch.cuda.LongTensor([1]) + torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group()) + torch.distributed.all_reduce(counts, group=mpu.get_pipeline_model_parallel_group()) + assert counts[0].item() == ( + torch.distributed.get_world_size() // + torch.distributed.get_world_size(group=mpu.get_tensor_model_parallel_group())) + + # Load mappings. + start_time = time.time() + print_rank_0(' > loading doc-idx mapping from {}'.format( + doc_idx_filename)) + doc_idx = np.load(doc_idx_filename, allow_pickle=True, mmap_mode='r') + print_rank_0(' > loading sample-idx mapping from {}'.format( + sample_idx_filename)) + sample_idx = np.load(sample_idx_filename, allow_pickle=True, mmap_mode='r') + print_rank_0(' > loading shuffle-idx mapping from {}'.format( + shuffle_idx_filename)) + shuffle_idx = np.load(shuffle_idx_filename, allow_pickle=True, mmap_mode='r') + print_rank_0(' loaded indexed file in {:3.3f} seconds'.format( + time.time() - start_time)) + print_rank_0(' total number of samples: {}'.format( + sample_idx.shape[0])) + print_rank_0(' total number of epochs: {}'.format(num_epochs)) + + return doc_idx, sample_idx, shuffle_idx + + +def _num_tokens(documents, sizes): + """Total number of tokens in the dataset.""" + return np.sum(sizes[documents]) + + +def _num_epochs(tokens_per_epoch, seq_length, num_samples): + """Based on number of samples and sequence lenght, calculate how many + epochs will be needed.""" + num_epochs = 0 + total_tokens = 0 + while True: + num_epochs += 1 + total_tokens += tokens_per_epoch + # -1 is because we need to retrieve seq_length + 1 token each time + # but the last token will overlap with the first token of the next + # sample except for the last sample. + if ((total_tokens - 1) // seq_length) >= num_samples: + return num_epochs + + +def _build_doc_idx(documents, num_epochs, np_rng, separate_last_epoch): + """Build an array with length = number-of-epochs * number-of-dcuments. 
+ Each index is mapped to a corresponding document.""" + if not separate_last_epoch or num_epochs == 1: + doc_idx = np.mgrid[0:num_epochs, 0:len(documents)][1] + doc_idx[:] = documents + doc_idx = doc_idx.reshape(-1) + doc_idx = doc_idx.astype(np.int32) + np_rng.shuffle(doc_idx) + return doc_idx + + doc_idx_first = _build_doc_idx(documents, num_epochs-1, np_rng, False) + doc_idx_last = _build_doc_idx(documents, 1, np_rng, False) + return np.concatenate((doc_idx_first, doc_idx_last)) + + +def _build_shuffle_idx(num_samples, total_size, np_rng): + """Build the range [0, size) and shuffle.""" + print(' > building shuffle index with split [0, {}) and [{}, {}) ' + '...'.format(num_samples, num_samples, total_size), flush=True) + + dtype_ = np.uint32 + if total_size >= (np.iinfo(np.uint32).max - 1): + dtype_ = np.int64 + + shuffle_idx_first = np.arange(start=0, stop=num_samples, + step=1, dtype=dtype_) + np_rng.shuffle(shuffle_idx_first) + if num_samples == total_size: + return shuffle_idx_first + + shuffle_idx_last = np.arange(start=num_samples, stop=total_size, + step=1, dtype=dtype_) + np_rng.shuffle(shuffle_idx_last) + + return np.concatenate((shuffle_idx_first, shuffle_idx_last)) + +def pad_and_convert_to_numpy(tokens, pad_id, max_seq_length): + """Pad sequences and convert them to numpy.""" + + # Some checks. + num_tokens = len(tokens) + padding_length = max_seq_length - num_tokens + assert padding_length >= 0 + + # Tokens and token types. + filler = np.array([pad_id] * padding_length) + tokens_np = np.concatenate((tokens, filler), dtype=np.int64) + + return tokens_np diff --git a/megatron/tokenizer/tokenizer.py b/megatron/tokenizer/tokenizer.py index 66f6522f2..c45468951 100644 --- a/megatron/tokenizer/tokenizer.py +++ b/megatron/tokenizer/tokenizer.py @@ -329,15 +329,15 @@ def __init__(self, tokenizer_name_or_path): @property def vocab_size(self): - return self.tokenizer.vocab_size + return self.tokenizer.__len__() #self.tokenizer.vocab_size @property def vocab(self): - return self.tokenizer.encoder + return self.tokenizer.vocab @property def inv_vocab(self): - return self.tokenizer.decoder + return {v: k for k, v in self.tokenizer.vocab.items()} def tokenize(self, text): return self.tokenizer.encode(text) @@ -348,3 +348,34 @@ def detokenize(self, token_ids): @property def eod(self): return self.tokenizer.eos_token_id + + @property + def cls(self): + return self.tokenizer.cls_token_id + + @property + def sep(self): + return self.tokenizer.sep_token_id + + @property + def pad(self): + return self.tokenizer.pad_token_id + + @property + def mask(self): + return self.tokenizer.mask_token_id + + @property + def additional_special_tokens_ids(self): + """ All the additional special tokens you may want to use (list of strings).""" + return self.tokenizer.additional_special_tokens_ids + + @property + def bos_token_id(self): + """ Id of the beginning of sentence token in the vocabulary.""" + return self.tokenizer.bos_token_id + + @property + def eos_token_id(self): + """ Id of the end of sentence token in the vocabulary.""" + return self.tokenizer.eos_token_id diff --git a/prepare_tokenizer.py b/prepare_tokenizer.py new file mode 100644 index 000000000..280ba458d --- /dev/null +++ b/prepare_tokenizer.py @@ -0,0 +1,25 @@ +from transformers import AutoTokenizer, AddedToken + +tokenizer = AutoTokenizer.from_pretrained('bigscience/tokenizer') + +tokenizer.add_special_tokens({ + 'additional_special_tokens': [ + AddedToken( + ''.format(str(idx).zfill(3)), + lstrip=False, + rstrip=False, + 
normalization=False + ) for idx in reversed(range(0,200)) + ] + }) + +tokenizer.save_pretrained('bigscience-tokenizer-padded') + +# python tools/preprocess_data.py \ +# --input data/oscar-en-10k.jsonl \ +# --output-prefix data/meg-gpt2-oscar-en-10k \ +# --dataset-impl mmap \ +# --tokenizer-type PretrainedFromHF \ +# --tokenizer-name-or-path bigscience-tokenizer-padded \ +# --append-eod \ +# --workers 4 \ No newline at end of file diff --git a/train_ND_MLM_gpt.py b/train_ND_MLM_gpt.py new file mode 100644 index 000000000..3f23320e8 --- /dev/null +++ b/train_ND_MLM_gpt.py @@ -0,0 +1,267 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Non-Causal Decoder GPT MLM Adaptation""" + +import torch +from functools import partial +from megatron import get_args +from megatron import print_rank_0 +from megatron import get_timers +from megatron import get_tokenizer +from megatron import mpu + +from megatron.data.non_causal_mlm_dataset import build_train_valid_test_datasets #, build_dataset_group +from megatron.model import GPTModel, GPTModelPipe +from megatron.training import pretrain +from megatron.utils import get_ltor_masks_and_position_ids, get_prefix_indices, reweight_loss_mask_ +from megatron.utils import average_losses_across_data_parallel_group + +import deepspeed +from deepspeed.runtime.utils import see_memory_usage +import subprocess + +def model_provider(pre_process=True, post_process=True): + """Build the model.""" + + print_rank_0('building GPT model ...') + see_memory_usage(f"Before Building Model", force=True) + + args = get_args() + + with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(), + remote_device=None if args.remote_device == 'none' else args.remote_device, + config_dict_or_path=args.deepspeed_config, + enabled=args.zero_stage == 3, + mpu=mpu): + if args.deepspeed: + model = GPTModelPipe( + num_tokentypes=0, + parallel_output=True, + prefix_lm=True + ) + # loaded_dir, state_dict = model[0].load_checkpoint( + # args.finetune, load_optimizer_states=False) + # if loaded_dir is None: + # print_rank_0('WARNING: could not find the metadata file {} '.format( + # load_dir)) + # print_rank_0(' will not load any checkpoints and will start from ' + # 'random') + + # This is a hack to give us a reference to get_batch_pipe from within training.py + # We need to call model.set_batch_fn after deepspeed.initialize + model._megatron_batch_fn = get_batch_pipe + + else: + model = GPTModel( + num_tokentypes=0, + parallel_output=True, + pre_process=pre_process, + post_process=post_process, + prefix_lm=True + ) + see_memory_usage(f"After Building Model", force=True) + return model + +_KEYS = ['text', 'prefix_len'] + +def get_batch(data_iterator): + """Generate a batch""" + args = get_args() + tokenizer = get_tokenizer() + + # Items and their type. + keys = _KEYS + datatype = torch.int64 + + # Broadcast data. 
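A quick sanity check one might run against the padded tokenizer produced by `prepare_tokenizer.py` (a sketch; it assumes the `bigscience-tokenizer-padded` directory exists locally): `NonCausalMLMDataset` asserts that `additional_special_tokens_ids` is non-empty, and the Megatron wrapper's `vocab_size` now reports `len(tokenizer)` so the added ids are counted.

```python
# Sanity-check sketch for the padded tokenizer produced above.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained('bigscience-tokenizer-padded')
print(len(tok.additional_special_tokens_ids))   # expected: the 200 special tokens added above
print(len(tok))                                 # what the Megatron wrapper now reports as vocab_size
```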
+ if data_iterator is not None: + data = next(data_iterator) + else: + data = None + data_b = mpu.broadcast_data(keys, data, datatype) + + # Unpack. + tokens_ = data_b['text'].long() + labels = tokens_[:, 1:].contiguous() + tokens = tokens_[:, :-1].contiguous() + + # Prefix + prefix_indices = data_b['prefix_len'].cpu().tolist() + + # Get the masks and postition ids. + attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( + tokens, + tokenizer.eod, + args.reset_position_ids, + args.reset_attention_mask, + args.eod_mask_loss, + prefix_indices=prefix_indices, + loss_on_targets_only=args.loss_on_targets_only + ) + + # weight loss_mask + if args.reweight_loss_based_on_position_frequency: + reweight_loss_mask_(loss_mask, tokens) + + return tokens, labels, loss_mask, attention_mask, position_ids + +def get_batch_pipe(data): + """Modification of `get_batch` to work on `next(data_iterator)` instead of `data_iterator`""" + args = get_args() + tokenizer = get_tokenizer() + + # Items and their type. + keys = _KEYS + datatype = torch.int64 + + # Broadcast data. + data_b = mpu.broadcast_data(keys, data, datatype) + + # Unpack. + tokens_ = data_b['text'].long() + labels = tokens_[:, 1:].contiguous() + tokens = tokens_[:, :-1].contiguous() + + # Prefix + prefix_indices = data_b['prefix_len'].cpu().tolist() + + # Get the masks and position ids. + attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( + tokens, + tokenizer.eod, + args.reset_position_ids, + args.reset_attention_mask, + args.eod_mask_loss, + prefix_indices=prefix_indices, + loss_on_targets_only=args.loss_on_targets_only + ) + + # weight loss_mask + if args.reweight_loss_based_on_position_frequency: + reweight_loss_mask_(loss_mask, tokens) + + return (tokens, position_ids, attention_mask), (labels, loss_mask), prefix_indices + +def loss_func(loss_mask, output_tensor): + losses = output_tensor.float() + loss_mask = loss_mask.view(-1).float() + loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() + + # Reduce loss for logging. + averaged_loss = average_losses_across_data_parallel_group([loss]) + + return loss, {'lm loss': averaged_loss[0]} + + +def forward_step(data_iterator, model): + """Forward step.""" + args = get_args() + timers = get_timers() + + # Get the batch. 
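`get_batch` / `get_batch_pipe` above forward `prefix_indices` into `get_ltor_masks_and_position_ids`, which is what turns the decoder into a "non-causal" (prefix-LM) model, and `loss_on_targets_only` presumably restricts the loss to the tokens after the prefix. A minimal sketch of the resulting mask shape (a generic illustration, not the fork's actual implementation): prefix positions attend to each other freely, later positions stay strictly causal.

```python
# Generic prefix-LM attention mask: causal everywhere, except that the first
# prefix_len positions may also attend forward within the prefix block.
import torch

def prefix_lm_mask(seq_len: int, prefix_len: int) -> torch.Tensor:
    mask = torch.ones(seq_len, seq_len).tril().bool()    # causal base
    mask[:prefix_len, :prefix_len] = True                # bidirectional prefix block
    return mask

print(prefix_lm_mask(6, 3).int())   # rows 0-2 see all of 0-2; rows 3-5 remain causal
```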
+ timers('batch-generator').start() + tokens, labels, loss_mask, attention_mask, position_ids = get_batch( + data_iterator) + timers('batch-generator').stop() + + output_tensor = model(tokens, position_ids, attention_mask, + labels=labels) + + return output_tensor, partial(loss_func, loss_mask) + + +def train_valid_test_datasets_provider(train_val_test_num_samples): + """Build train, valid, and test datasets.""" + args = get_args() + train_ds, valid_ds, test_ds = None, None, None + + print_rank_0('> building train, validation, and test datasets for GPT ...') + # Option 1 of data loading using --data-path + + if args.data_path: + train_ds, valid_ds, test_ds = build_train_valid_test_datasets( + data_prefix=args.data_path, + data_impl=args.data_impl, + splits_string=args.split, + train_valid_test_num_samples=train_val_test_num_samples, + max_seq_length=args.input_length, + masked_lm_prob=args.mask_prob, + seed=args.seed, + skip_warmup=(not args.mmap_warmup), + ) + + # # Option 2 of data loading using --(train|valid|test)-weighted-split-paths + # elif args.train_weighted_split_paths: + # assigned_train_valid_test = [] + # if args.train_weighted_split_paths is not None: + # train_ds = [] + # assigned_train_valid_test.append("train") + # if args.valid_weighted_split_paths is not None: + # valid_ds = [] + # assigned_train_valid_test.append("valid") + # if args.test_weighted_split_paths is not None: + # test_ds = [] + # assigned_train_valid_test.append("test") + + # for s in assigned_train_valid_test: + # data_groups = zip(eval(f"args.{s}_weighted_split_paths"), + # eval(f"args.{s}_weighted_split_weights"), + # eval(f"args.{s}_weighted_split_splits"), + # eval(f"args.{s}_weighted_split_names")) + # for paths, weights, splits, name in data_groups: + # d = build_dataset_group(name, paths, weights, splits, + # args.data_impl, + # train_val_test_num_samples, + # args.seq_length, args.seed, + # (not args.mmap_warmup), + # train_valid_test=s) + # eval(f"{s}_ds").append(d) + # else: + # raise NotImplementedError("No dataloading argument passed") + + print_rank_0("> finished creating GPT datasets ...") + return train_ds, valid_ds, test_ds + +def command_exists(cmd): + result = subprocess.Popen(f'type {cmd}', stdout=subprocess.PIPE, shell=True) + return result.wait() == 0 + +def git_ds_info(): + from deepspeed.env_report import main as ds_report + ds_report() + + # Write out version/git info + git_hash_cmd = "git rev-parse --short HEAD" + git_branch_cmd = "git rev-parse --abbrev-ref HEAD" + if command_exists('git'): + try: + result = subprocess.check_output(git_hash_cmd, shell=True) + git_hash = result.decode('utf-8').strip() + result = subprocess.check_output(git_branch_cmd, shell=True) + git_branch = result.decode('utf-8').strip() + except subprocess.CalledProcessError: + git_hash = "unknown" + git_branch = "unknown" + else: + git_hash = "unknown" + git_branch = "unknown" + print(f'**** Git info for Megatron: git_hash={git_hash} git_branch={git_branch} ****') + + +if __name__ == "__main__": + git_ds_info() + pretrain(train_valid_test_datasets_provider, model_provider, forward_step, + args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}) diff --git a/train_ND_MTF_gpt.py b/train_ND_MTF_gpt.py new file mode 100644 index 000000000..69b8c825b --- /dev/null +++ b/train_ND_MTF_gpt.py @@ -0,0 +1,277 @@ +# coding=utf-8 +# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Non-Causal Decoder GPT Multitask Finetuning""" + +import torch +from functools import partial +from megatron import get_args +from megatron import print_rank_0 +from megatron import get_timers +from megatron import get_tokenizer +from megatron import mpu + +from megatron.data.non_causal_mtf_dataset import build_train_valid_test_datasets #, build_dataset_group +from megatron.model import GPTModel, GPTModelPipe +from megatron.training import pretrain +from megatron.utils import get_ltor_masks_and_position_ids, get_prefix_indices, reweight_loss_mask_ +from megatron.utils import average_losses_across_data_parallel_group + +import deepspeed +from deepspeed.runtime.utils import see_memory_usage +import subprocess + +def model_provider(pre_process=True, post_process=True): + """Build the model.""" + + print_rank_0('building GPT model ...') + see_memory_usage(f"Before Building Model", force=True) + + args = get_args() + + with deepspeed.zero.Init(data_parallel_group=mpu.get_data_parallel_group(), + remote_device=None if args.remote_device == 'none' else args.remote_device, + config_dict_or_path=args.deepspeed_config, + enabled=args.zero_stage == 3, + mpu=mpu): + if args.deepspeed: + model = GPTModelPipe( + num_tokentypes=0, + parallel_output=True, + prefix_lm=True + ) + # loaded_dir, state_dict = model[0].load_checkpoint( + # args.finetune, load_optimizer_states=False) + # if loaded_dir is None: + # print_rank_0('WARNING: could not find the metadata file {} '.format( + # load_dir)) + # print_rank_0(' will not load any checkpoints and will start from ' + # 'random') + + # This is a hack to give us a reference to get_batch_pipe from within training.py + # We need to call model.set_batch_fn after deepspeed.initialize + model._megatron_batch_fn = get_batch_pipe + + else: + model = GPTModel( + num_tokentypes=0, + parallel_output=True, + pre_process=pre_process, + post_process=post_process, + prefix_lm=True + ) + see_memory_usage(f"After Building Model", force=True) + return model + +_KEYS = ['text', 'prefix_len'] + +def get_batch(data_iterator): + """Generate a batch""" + args = get_args() + tokenizer = get_tokenizer() + + # Items and their type. + keys = _KEYS + datatype = torch.int64 + + # Broadcast data. + if data_iterator is not None: + data = next(data_iterator) + else: + data = None + data_b = mpu.broadcast_data(keys, data, datatype) + + # Unpack. + tokens_ = data_b['text'].long() + labels = tokens_[:, 1:].contiguous() + tokens = tokens_[:, :-1].contiguous() + + # Prefix + prefix_indices = data_b['prefix_len'].cpu().tolist() + + # Get the masks and position ids.
+ attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( + tokens, + tokenizer.eod, + args.reset_position_ids, + args.reset_attention_mask, + args.eod_mask_loss, + prefix_indices=prefix_indices, + loss_on_targets_only=args.loss_on_targets_only + ) + + # weight loss_mask + if args.reweight_loss_based_on_position_frequency: + reweight_loss_mask_(loss_mask, tokens) + + return tokens, labels, loss_mask, attention_mask, position_ids + +def get_batch_pipe(data): + """Modification of `get_batch` to work on `next(data_iterator)` instead of `data_iterator`""" + args = get_args() + tokenizer = get_tokenizer() + + # Items and their type. + keys = _KEYS + datatype = torch.int64 + + # Broadcast data. + data_b = mpu.broadcast_data(keys, data, datatype) + + # Unpack. + tokens_ = data_b['text'].long() + labels = tokens_[:, 1:].contiguous() + tokens = tokens_[:, :-1].contiguous() + + # Prefix + prefix_indices = data_b['prefix_len'].cpu().tolist() + + # Get the masks and position ids. + attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( + tokens, + tokenizer.eod, + args.reset_position_ids, + args.reset_attention_mask, + args.eod_mask_loss, + prefix_indices=prefix_indices, + loss_on_targets_only=args.loss_on_targets_only + ) + + # weight loss_mask + if args.reweight_loss_based_on_position_frequency: + reweight_loss_mask_(loss_mask, tokens) + + return (tokens, position_ids, attention_mask), (labels, loss_mask), prefix_indices + +def loss_func(loss_mask, output_tensor): + losses = output_tensor.float() + loss_mask = loss_mask.view(-1).float() + loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum() + + # Reduce loss for logging. + averaged_loss = average_losses_across_data_parallel_group([loss]) + + return loss, {'lm loss': averaged_loss[0]} + + +def forward_step(data_iterator, model): + """Forward step.""" + args = get_args() + timers = get_timers() + + # Get the batch. 
+ timers('batch-generator').start() + tokens, labels, loss_mask, attention_mask, position_ids = get_batch( + data_iterator) + timers('batch-generator').stop() + + output_tensor = model(tokens, position_ids, attention_mask, + labels=labels) + + return output_tensor, partial(loss_func, loss_mask) + + +def train_valid_test_datasets_provider(train_val_test_num_samples): + """Build train, valid, and test datasets.""" + args = get_args() + train_ds, valid_ds, test_ds = None, None, None + + print_rank_0('> building train, validation, and test datasets for GPT ...') + # Option 1 of data loading using --data-path + + if args.data_path: + # train_ds, valid_ds, test_ds = build_train_valid_test_datasets( + # data_prefix=args.data_path, + # data_impl=args.data_impl, + # splits_string=args.split, + # train_valid_test_num_samples=train_val_test_num_samples, + # seq_length=args.seq_length, + # seed=args.seed, + # skip_warmup=(not args.mmap_warmup)) + train_ds, valid_ds, test_ds = build_train_valid_test_datasets( + data_prefix=args.data_path, + data_impl=args.data_impl, + splits_string=args.split, + train_valid_test_num_samples=train_val_test_num_samples, + max_seq_length=args.encoder_seq_length, + max_seq_length_dec=args.decoder_seq_length, + masked_lm_prob=args.mask_prob, + short_seq_prob=args.short_seq_prob, + seed=args.seed, + skip_warmup=(not args.mmap_warmup), + dataset_type='t5') + + # # Option 2 of data loading using --(train|valid|test)-weighted-split-paths + # elif args.train_weighted_split_paths: + # assigned_train_valid_test = [] + # if args.train_weighted_split_paths is not None: + # train_ds = [] + # assigned_train_valid_test.append("train") + # if args.valid_weighted_split_paths is not None: + # valid_ds = [] + # assigned_train_valid_test.append("valid") + # if args.test_weighted_split_paths is not None: + # test_ds = [] + # assigned_train_valid_test.append("test") + + # for s in assigned_train_valid_test: + # data_groups = zip(eval(f"args.{s}_weighted_split_paths"), + # eval(f"args.{s}_weighted_split_weights"), + # eval(f"args.{s}_weighted_split_splits"), + # eval(f"args.{s}_weighted_split_names")) + # for paths, weights, splits, name in data_groups: + # d = build_dataset_group(name, paths, weights, splits, + # args.data_impl, + # train_val_test_num_samples, + # args.seq_length, args.seed, + # (not args.mmap_warmup), + # train_valid_test=s) + # eval(f"{s}_ds").append(d) + # else: + # raise NotImplementedError("No dataloading argument passed") + + print_rank_0("> finished creating GPT datasets ...") + return train_ds, valid_ds, test_ds + +def command_exists(cmd): + result = subprocess.Popen(f'type {cmd}', stdout=subprocess.PIPE, shell=True) + return result.wait() == 0 + +def git_ds_info(): + from deepspeed.env_report import main as ds_report + ds_report() + + # Write out version/git info + git_hash_cmd = "git rev-parse --short HEAD" + git_branch_cmd = "git rev-parse --abbrev-ref HEAD" + if command_exists('git'): + try: + result = subprocess.check_output(git_hash_cmd, shell=True) + git_hash = result.decode('utf-8').strip() + result = subprocess.check_output(git_branch_cmd, shell=True) + git_branch = result.decode('utf-8').strip() + except subprocess.CalledProcessError: + git_hash = "unknown" + git_branch = "unknown" + else: + git_hash = "unknown" + git_branch = "unknown" + print(f'**** Git info for Megatron: git_hash={git_hash} git_branch={git_branch} ****') + + +if __name__ == "__main__": + git_ds_info() + pretrain(train_valid_test_datasets_provider, model_provider, forward_step, + 
args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
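Note on the batch/loss plumbing added in these files: `get_batch`/`get_batch_pipe` shift the packed token tensor by one position to form (input, label) pairs, and `loss_func` averages the per-token losses weighted by `loss_mask`, so masked positions contribute nothing. Below is a minimal, self-contained sketch of that arithmetic in plain PyTorch; the batch size, sequence length, vocabulary size, and `prefix_len` values are illustrative assumptions only, and the real mask is built by Megatron's `get_ltor_masks_and_position_ids` rather than the toy loop shown here.

# Toy sketch (illustrative shapes/values only) of the token shift and
# mask-weighted loss reduction implemented by get_batch/loss_func above.
import torch

batch_size, seq_len_plus_one, vocab_size = 2, 9, 50  # assumed toy sizes
tokens_ = torch.randint(0, vocab_size, (batch_size, seq_len_plus_one))

# Next-token prediction: inputs are all tokens but the last,
# labels are all tokens but the first.
labels = tokens_[:, 1:].contiguous()
tokens = tokens_[:, :-1].contiguous()

# Pretend per-token losses came back from the model, shape [batch, seq].
losses = torch.rand(batch_size, seq_len_plus_one - 1)

# Prefix-LM style mask: with loss_on_targets_only, positions inside the
# prefix receive no loss. prefix_len is a made-up stand-in for
# data_b['prefix_len']; the real mask comes from get_ltor_masks_and_position_ids.
prefix_len = [3, 5]
loss_mask = torch.ones_like(losses)
for i, p in enumerate(prefix_len):
    loss_mask[i, :p] = 0.0

# Same reduction as loss_func: mean loss over unmasked positions only.
loss = torch.sum(losses.view(-1) * loss_mask.view(-1)) / loss_mask.view(-1).sum()
print(float(loss))

In the scripts above, this scalar is what `loss_func` returns to `pretrain`, with `average_losses_across_data_parallel_group` additionally averaging it across data-parallel ranks for logging, and `reweight_loss_mask_` optionally rescaling the mask by position frequency beforehand.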