diff --git a/benchmarks/compatibility/scripts/megatron/README.md b/benchmarks/compatibility/scripts/megatron/README.md
new file mode 100644
index 0000000..054f7dd
--- /dev/null
+++ b/benchmarks/compatibility/scripts/megatron/README.md
@@ -0,0 +1,31 @@
+### How to run Megatron-LM on the Ascend platform?
+
+1. Install MindSpeed by following the [Ascend official guide](https://gitcode.com/Ascend/MindSpeed/blob/master/docs/user-guide/installation.md). Python 3.10.x and PyTorch 2.1.0 are recommended; some code in MindSpeed may not be compatible with Python 3.9.x.
+
+2. With MindSpeed Core, you can run Megatron-LM on Ascend training devices by adding just a single line of code.
+   ##### Example usage
+
+   Using the GPT model as an example, modify the `pretrain_gpt.py` file located in your Megatron-LM directory. Add the new line `import mindspeed.megatron_adaptor` right after `import torch`.
+
+   See the example modification below:
+   ```
+   import torch
+   import mindspeed.megatron_adaptor  # new code added
+   from functools import partial
+   from contextlib import nullcontext
+   import inspect
+   ```
+   After making this change, you can launch your Megatron-LM training tasks on Ascend devices.
+   For detailed instructions, please refer to the [Quick Start Guide](https://gitcode.com/Ascend/MindSpeed/blob/master/docs/user-guide/getting_started.md).
+
+
+### How to test this repo's Megatron-LM scripts on the Ascend platform?
+1. With MindSpeed installed, initialize the Ascend-related environment variables:
+   ```
+   source /usr/local/Ascend/nnal/atb/set_env.sh
+   source /usr/local/Ascend/ascend-toolkit/set_env.sh
+   ```
+2. Go to the `$INFINIPERF_ROOT/benchmarks/compatibility/scripts/megatron` folder and run the test script with the following command:
+   ```
+   bash run_megatron_test.sh
+   ```
\ No newline at end of file
diff --git a/benchmarks/compatibility/scripts/megatron/build.sh b/benchmarks/compatibility/scripts/megatron/build.sh
index b2da2cb..76e7588 100644
--- a/benchmarks/compatibility/scripts/megatron/build.sh
+++ b/benchmarks/compatibility/scripts/megatron/build.sh
@@ -5,17 +5,36 @@ if [ -z "$INFINIPERF_ROOT" ]; then
     exit 1
 fi
 
-export OSCAR_DATA_PATH=/workspace/bolunz/Data/oscar-en-10k.jsonl
-export LLAMA2_7B_MODEL_PATH=/workspace/bolunz/Models/Llama-2-7b-hf
+export OSCAR_DATA_PATH=/home/libaoming/workplace/oscar-en-10k.jsonl
+export LLAMA2_7B_MODEL_PATH=/home/libaoming/workplace/Llama-2-7b-hf
 export MEGATRON_ROOT=$INFINIPERF_ROOT/benchmarks/compatibility/Megatron-LM
 
+if [ "$INFINIPERF_PLATFORM" = "ASCEND_NPU" ]; then
+    echo "INFO: ASCEND_NPU platform detected. Switching Megatron-LM to the core_v0.12.1 branch..."
+    cd $MEGATRON_ROOT
+    # Discard unstaged changes
+    git reset --hard HEAD
+    git clean -fd
+    git checkout main
+    git pull --rebase
+    git checkout core_v0.12.1
+    cd -
+    echo "INFO: Megatron-LM repo switched to the core_v0.12.1 branch successfully."
+
+    echo "INFO: ASCEND_NPU platform detected. Copying preprocess_data.py..."
+    cp -v $INFINIPERF_ROOT/benchmarks/compatibility/scripts/megatron/preprocess_data.py $MEGATRON_ROOT/tools/
+    echo "INFO: preprocess_data.py for ASCEND_NPU platform copied successfully."
+else
+    echo "INFO: ASCEND_NPU platform not detected (INFINIPERF_PLATFORM is '$INFINIPERF_PLATFORM'). Skipping file copy."
+fi
+
 cp $INFINIPERF_ROOT/benchmarks/compatibility/scripts/megatron/training.py $MEGATRON_ROOT/megatron/training
-cp $MEGATRON_ROOT/pretrain_gpt.py $MEGATRON_ROOT/pretrain_llama.py
+cp $INFINIPERF_ROOT/benchmarks/compatibility/scripts/megatron/pretrain_llama.py $MEGATRON_ROOT/pretrain_llama.py
 
 if ! python -c "import sentencepiece" &> /dev/null; then
python -c "import sentencepiece" &> /dev/null; then echo "sentencepiece 未安装,正在安装..." pip install sentencepiece else echo "sentencepiece 已安装。" -fi \ No newline at end of file +fi diff --git a/benchmarks/compatibility/scripts/megatron/preprocess_data.py b/benchmarks/compatibility/scripts/megatron/preprocess_data.py new file mode 100644 index 0000000..682aab7 --- /dev/null +++ b/benchmarks/compatibility/scripts/megatron/preprocess_data.py @@ -0,0 +1,398 @@ +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. + +"""Processing large data for pretraining.""" +import argparse +import math +import json +import os +import sys +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), + os.path.pardir))) +import time +import gzip +import glob +import torch +import mindspeed.megatron_adaptor +import numpy as np +import multiprocessing +try: + import nltk + from nltk.tokenize.punkt import PunktLanguageVars + nltk_available = True +except ImportError: + PunktLanguageVars = object # Fallback to the built-in object class + nltk_available = False + +from megatron.training.tokenizer import build_tokenizer +from megatron.training.arguments import _add_tokenizer_args +from megatron.core.datasets import indexed_dataset + + +# https://stackoverflow.com/questions/33139531/preserve-empty-lines-with-nltks-punkt-tokenizer +class CustomLanguageVars(PunktLanguageVars): + + _period_context_fmt = r""" + \S* # some word material + %(SentEndChars)s # a potential sentence ending + \s* # <-- THIS is what I changed + (?=(?P + %(NonWord)s # either other punctuation + | + (?P\S+) # <-- Normally you would have \s+ here + ))""" + +class IdentitySplitter(object): + def tokenize(self, *text): + return text + + +class Encoder(object): + def __init__(self, args): + self.args = args + + def initializer(self): + # Use Encoder class as a container for global data + Encoder.tokenizer = build_tokenizer(self.args) + if self.args.split_sentences: + if not nltk_available: + print("NLTK is not available to split sentences.") + exit() + if os.environ.get("NLTK_DATA"): + library = os.path.join(os.environ.get("NLTK_DATA"), "tokenizers", "punkt", f"{self.args.lang}.pickle") + url = f"file:{library}" + else: + library = os.path.join("tokenizers", "punkt", f"{self.args.lang}.pickle") + url = f"nltk:{library}" + splitter = nltk.load(url) + if self.args.keep_newlines: + # this prevents punkt from eating newlines after sentences + Encoder.splitter = nltk.tokenize.punkt.PunktSentenceTokenizer( + train_text = splitter._params, + lang_vars = CustomLanguageVars()) + else: + Encoder.splitter = splitter + + else: + Encoder.splitter = IdentitySplitter() + + def split(self, json_line): + data = json.loads(json_line) + output = {} + for key in self.args.json_keys: + text = data[key] + max_len = 1000000 + tokens_list = [Encoder.splitter.tokenize(text[i:i+max_len]) for i in range(0, len(text), max_len)] + output[key] = [tokens for partial in tokens_list for tokens in partial] + return json.dumps(output), len(json_line) + + def encode(self, json_line): + data = json.loads(json_line) + ids = {} + lens = {} + for key in self.args.json_keys: + text = data[key] + if isinstance(text, list): + sentences = text + else: + sentences = [text] + doc_ids = [] + sentence_lens = [] + for sentence in sentences: + sentence_ids = Encoder.tokenizer.tokenize(sentence) + if len(sentence_ids) > 0: + doc_ids.extend(sentence_ids) + sentence_lens.append(len(sentence_ids)) + if len(doc_ids) > 0 and self.args.append_eod: + doc_ids.append(Encoder.tokenizer.eod) 
+ sentence_lens[-1] += 1 + ids[key] = doc_ids + lens[key] = sentence_lens + return ids, lens, len(json_line) + + +class Partition(object): + def __init__(self, args, workers): + self.args = args + self.workers = workers + + def print_processing_stats(self, count, proc_start, total_bytes_processed): + if count % self.args.log_interval == 0: + current = time.time() + elapsed = current - proc_start + mbs = total_bytes_processed/elapsed/1024/1024 + print(f"Processed {count} documents", + f"({count/elapsed} docs/s, {mbs} MB/s).", + file=sys.stderr) + + def split_sentences(self, file_name): + input_file_name, output_file_name = file_name + print("Opening", input_file_name) + fin = open(input_file_name, 'r', encoding='utf-8') + fout = open(output_file_name, 'w') + + encoder = Encoder(self.args) + pool = multiprocessing.Pool(self.workers, initializer=encoder.initializer) + split_docs = pool.imap(encoder.split, fin, 32) + + proc_start = time.time() + total_bytes_processed = 0 + for i, (doc, bytes_processed) in enumerate(split_docs, start=1): + total_bytes_processed += bytes_processed + fout.write(doc + "\n") + self.print_processing_stats(i, proc_start, total_bytes_processed) + + fin.close() + fout.close() + + + def process_json_file(self, file_name): + input_file_name, output_prefix = file_name + print("Opening", input_file_name) + fin = open(input_file_name, 'r', encoding='utf-8') + + startup_start = time.time() + encoder = Encoder(self.args) + tokenizer = build_tokenizer(self.args) + pool = multiprocessing.Pool(self.workers, initializer=encoder.initializer) + encoded_docs = pool.imap(encoder.encode, fin, 32) + + level = "document" + if self.args.split_sentences: + level = "sentence" + + output_bin_files = {} + output_idx_files = {} + builders = {} + + for key in self.args.json_keys: + output_bin_files[key] = "{}_{}_{}.bin".format(output_prefix, + key, level) + output_idx_files[key] = "{}_{}_{}.idx".format(output_prefix, + key, level) + builders[key] = indexed_dataset.IndexedDatasetBuilder( + output_bin_files[key], + dtype=indexed_dataset.DType.optimal_dtype(tokenizer.vocab_size), + ) + + startup_end = time.time() + proc_start = time.time() + total_bytes_processed = 0 + print("Time to startup:", startup_end - startup_start) + for i, (doc, sentence_lens, bytes_processed) in enumerate(encoded_docs, start=1): + total_bytes_processed += bytes_processed + for key in doc.keys(): + builders[key].add_document(doc[key], sentence_lens[key]) + self.print_processing_stats(i, proc_start, total_bytes_processed) + + fin.close() + builders[key].finalize(output_idx_files[key]) + + +def get_args(): + parser = argparse.ArgumentParser() + parser = _add_tokenizer_args(parser) + group = parser.add_argument_group(title='input data') + group.add_argument('--input', type=str, required=True, + help='Path to input JSON') + group.add_argument('--json-keys', nargs='+', default=['text'], + help='space separate listed of keys to extract from json') + group.add_argument('--split-sentences', action='store_true', + help='Split documents into sentences.') + group.add_argument('--keep-newlines', action='store_true', + help='Keep newlines between sentences when splitting.') + group = parser.add_argument_group(title='tokenization process') + group.add_argument('--append-eod', action='store_true', + help='Append an token to the end of a document.') + group.add_argument('--lang', type=str, default='english', + help='Language to use for NLTK-powered sentence splitting.') + group = parser.add_argument_group(title='output data') + 
group.add_argument('--output-prefix', type=str, required=True, + help='Path to binary output file without suffix') + group = parser.add_argument_group(title='runtime') + group.add_argument('--workers', type=int, required=True, + help=('Number of worker processes to launch.' + 'A good default for fast pre-processing ' + 'is: (workers * partitions) = available CPU cores.')) + group.add_argument('--partitions', type=int, default=1, + help='Number of file partitions') + group.add_argument('--log-interval', type=int, default=1000, + help='Interval between progress updates') + group.add_argument('--keep-sequential-samples', action='store_true', + help='Ensure ordering of samples in .jsonl files is ' + 'preserved when using partitions>1.') + args = parser.parse_args() + args.keep_empty = False + + if args.tokenizer_type.lower().startswith('bert') and not args.split_sentences: + print("Are you sure you don't want to split sentences?") + + # some default/dummy values for the tokenizer + args.rank = 1 + args.make_vocab_size_divisible_by = 128 + args.tensor_model_parallel_size = 1 + args.vocab_extra_ids = 0 + + return args + + +def get_file_name(args, file_id): + file_name, extension = os.path.splitext(args.input) + input_file_name = file_name + "_" + str(file_id) + extension + sentence_split_file = file_name + "_ss_" + str(file_id) + extension + output_prefix = args.output_prefix + "_" + str(file_id) + file_names = { + 'partition': input_file_name, + 'sentence_split': sentence_split_file, + 'output_prefix': output_prefix} + return file_names + + +def check_files_exist(in_ss_out_names, key, num_partitions): + for i in range(num_partitions): + if not os.path.exists(in_ss_out_names[i][key]): + return False + return True + + +def main(): + args = get_args() + + if args.split_sentences: + if nltk_available: + nltk.download("punkt", quiet=True, download_dir=os.environ.get("NLTK_DATA")) + else: + raise Exception( + "nltk library required for sentence splitting is not available.") + + in_ss_out_names = [] + if args.partitions == 1: + file_name, extension = os.path.splitext(args.input) + sentence_split_file = file_name + "_ss" + extension + file_names = { + 'partition': args.input, + 'sentence_split': sentence_split_file, + 'output_prefix': args.output_prefix} + in_ss_out_names.append(file_names) + else: + in_file_names = glob.glob(args.input) + + # Count total number of lines across .jsonl files + if args.keep_sequential_samples: + total_sample_count = 0 + for filename in in_file_names: + with open(filename, "r") as fin: + for fc, _ in enumerate(fin): + pass + total_sample_count += (fc + 1) + partition_size = math.ceil(total_sample_count / args.partitions) + + # create .jsonl parition files + for idx in range(args.partitions): + in_ss_out_name = get_file_name(args, idx) + in_ss_out_names.append(in_ss_out_name) + + # check to see if paritions were already created + partitions_present = check_files_exist(in_ss_out_names, 'partition', args.partitions) + + # check to see if paritions with split sentences already created + split_sentences_present = check_files_exist(in_ss_out_names, 'sentence_split', args.partitions) + + if not partitions_present and not split_sentences_present: + # populate .jsonl partition files from parent files + partitioned_input_files = [] + for idx in range(args.partitions): + partitioned_input_file = open(in_ss_out_names[idx]['partition'], 'w') + partitioned_input_files.append(partitioned_input_file) + + index = 0 + if args.keep_sequential_samples: line_count = 0 + for in_file_name in 
in_file_names: + # support for gzip files + if in_file_name.endswith(".gz"): + fin = gzip.open(in_file_name, 'rt') + else: + fin = open(in_file_name, 'r', encoding='utf-8') + + for line in fin: + partitioned_input_files[index].write(line) + if args.keep_sequential_samples: + line_count += 1 + if line_count % partition_size == 0: + index += 1 + else: + index = (index + 1)%args.partitions + + fin.close() + + for idx in range(args.partitions): + partitioned_input_files[idx].close() + + assert args.workers % args.partitions == 0 + partition = Partition(args, args.workers//args.partitions) + + # check to see if paritions with split sentences already created + split_sentences_present = check_files_exist(in_ss_out_names, 'sentence_split', args.partitions) + + # split sentences in partition files + if args.split_sentences and not split_sentences_present: + processes = [] + for name in in_ss_out_names: + p = multiprocessing.Process(target=partition.split_sentences, + args=((name['partition'], name['sentence_split']),)) + p.start() + processes.append(p) + + for p in processes: + p.join() + + if args.partitions == 1: + return + + + # encode partition files in parallel + processes = [] + input_key = 'sentence_split' if args.split_sentences else 'partition' + for name in in_ss_out_names: + p = multiprocessing.Process(target=partition.process_json_file, + args=((name[input_key], name['output_prefix']),)) + p.start() + processes.append(p) + + for p in processes: + p.join() + + if args.partitions == 1: + return + + # merge bin/idx partitions + level = "document" + if args.split_sentences: + level = "sentence" + + output_bin_files = {} + output_idx_files = {} + builders = {} + tokenizer = build_tokenizer(args) + + for key in args.json_keys: + output_bin_files[key] = "{}_{}_{}.bin".format(args.output_prefix, + key, level) + output_idx_files[key] = "{}_{}_{}.idx".format(args.output_prefix, + key, level) + builders[key] = indexed_dataset.IndexedDatasetBuilder( + output_bin_files[key], + dtype=indexed_dataset.DType.optimal_dtype(tokenizer.vocab_size), + ) + + for name in in_ss_out_names: + parition_output_prefix = name['output_prefix'] + full_partition_output_prefix = "{}_{}_{}".format(parition_output_prefix, + key, level) + builders[key].add_index(full_partition_output_prefix) + builders[key].finalize(output_idx_files[key]) + + +if __name__ == '__main__': + + main() + diff --git a/benchmarks/compatibility/scripts/megatron/pretrain_llama.py b/benchmarks/compatibility/scripts/megatron/pretrain_llama.py new file mode 100644 index 0000000..36e35ec --- /dev/null +++ b/benchmarks/compatibility/scripts/megatron/pretrain_llama.py @@ -0,0 +1,334 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. 
+"""Pretrain GPT.""" +import os +import torch +if os.getenv('INFINIPERF_PLATFORM') == 'ASCEND_NPU': + try: + import mindspeed.megatron_adaptor + print("INFO: MindSpeed adaptor for ASCEND_NPU platform has been enabled.") + except ImportError: + print("WARNING: PLATFORM is set to 'ASCEND_NPU', but the mindspeed package could not be found.") + +from functools import partial +from contextlib import nullcontext +import inspect + +from typing import List, Optional, Tuple, Union +from megatron.training import get_args +from megatron.training import print_rank_0 +from megatron.training import get_timers +from megatron.training import get_tokenizer +from megatron.core import mpu +from megatron.core.enums import ModelType +from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder +from megatron.core.datasets.gpt_dataset import GPTDatasetConfig +from megatron.core.datasets.gpt_dataset import MockGPTDataset, GPTDataset +from megatron.core.models.gpt.heterogeneous.heterogeneous_layer_specs import ( + get_gpt_heterogeneous_layer_spec, +) +from megatron.core.rerun_state_machine import get_rerun_state_machine +import megatron.legacy.model +from megatron.core.models.gpt import GPTModel +from megatron.training import pretrain +from megatron.core.utils import StragglerDetector +from megatron.core.transformer.spec_utils import import_module +from megatron.training.utils import ( + get_batch_on_this_cp_rank, + get_batch_on_this_tp_rank, + get_blend_and_blend_per_split, +) +from megatron.training.arguments import core_transformer_config_from_args +from megatron.training.yaml_arguments import core_transformer_config_from_yaml +from megatron.core.models.gpt.gpt_layer_specs import ( + get_gpt_decoder_block_spec, + get_gpt_layer_local_spec, + get_gpt_layer_with_transformer_engine_spec, + get_gpt_mtp_block_spec, +) +from megatron.core.transformer.transformer_block import TransformerBlockSubmodules + + +stimer = StragglerDetector() + +def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.legacy.model.GPTModel]: + """Builds the model. + + If you set the use_legacy_models to True, it will return the legacy GPT model and if not the mcore GPT model. + + Args: + pre_process (bool, optional): Set to true if you need to compute embedings. Defaults to True. + post_process (bool, optional): Set to true if you need to want to compute output logits/loss. Defaults to True. 
+ + + Returns: + Union[GPTModel, megatron.legacy.model.GPTModel]: The returned model + """ + args = get_args() + use_te = args.transformer_impl == "transformer_engine" + + if args.record_memory_history: + torch.cuda.memory._record_memory_history(True, + # keep 100,000 alloc/free events from before the snapshot + trace_alloc_max_entries=100000, + + # record stack information for the trace events + trace_alloc_record_context=True) + + def oom_observer(device, alloc, device_alloc, device_free): + # snapshot right after an OOM happened + print('saving allocated state during OOM') + snapshot = torch.cuda.memory._snapshot() + from pickle import dump + dump(snapshot, open(f"oom_rank-{torch.distributed.get_rank()}_{args.memory_snapshot_path}", 'wb')) + + torch._C._cuda_attach_out_of_memory_observer(oom_observer) + + print_rank_0('building GPT model ...') + # Experimental loading arguments from yaml + if args.yaml_cfg is not None: + config = core_transformer_config_from_yaml(args, "language_model") + else: + config = core_transformer_config_from_args(args) + + if args.use_legacy_models: + model = megatron.legacy.model.GPTModel( + config, + num_tokentypes=0, + parallel_output=True, + pre_process=pre_process, + post_process=post_process, + ) + else: # using core models + if args.spec is not None: + transformer_layer_spec = import_module(args.spec) + else: + if args.num_experts: + # Define the decoder block spec + transformer_layer_spec = get_gpt_decoder_block_spec(config, use_transformer_engine=use_te, normalization=args.normalization) + elif args.heterogeneous_layers_config_path is not None: + transformer_layer_spec = get_gpt_heterogeneous_layer_spec(config, use_te) + else: + # Define the decoder layer spec + if use_te: + transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec( + args.num_experts, args.moe_grouped_gemm, + args.qk_layernorm, args.multi_latent_attention, args.moe_use_legacy_grouped_gemm) + else: + transformer_layer_spec = get_gpt_layer_local_spec( + args.num_experts, args.moe_grouped_gemm, + args.qk_layernorm, args.multi_latent_attention, args.moe_use_legacy_grouped_gemm, + normalization=args.normalization) + mtp_block_spec = None + if args.mtp_num_layers is not None: + mtp_block_spec = get_gpt_mtp_block_spec(config, transformer_layer_spec, use_transformer_engine=use_te) + + model = GPTModel( + config=config, + transformer_layer_spec=transformer_layer_spec, + vocab_size=args.padded_vocab_size, + max_sequence_length=args.max_position_embeddings, + pre_process=pre_process, + post_process=post_process, + fp16_lm_cross_entropy=args.fp16_lm_cross_entropy, + parallel_output=True, + share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights, + position_embedding_type=args.position_embedding_type, + rotary_percent=args.rotary_percent, + rotary_base=args.rotary_base, + rope_scaling=args.use_rope_scaling, + mtp_block_spec=mtp_block_spec, + ) + + return model + + +def get_batch(data_iterator): + """Generate a batch.""" + + # TODO: this is pretty hacky, find a better way + if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()): + return None, None, None, None, None + + # get batches based on the TP rank you are on + batch = get_batch_on_this_tp_rank(data_iterator) + + # slice batch along sequence dimension for context parallelism + batch = get_batch_on_this_cp_rank(batch) + + return batch.values() + + +# define spiky loss as a loss that's 10x the max loss observed +SPIKY_LOSS_FACTOR = 10 + + +def loss_func(loss_mask: torch.Tensor, 
output_tensor: torch.Tensor): + """Loss function. + + Args: + loss_mask (torch.Tensor): Used to mask out some portions of the loss + output_tensor (torch.Tensor): The tensor with the losses + + Returns: + the loss scalar for this micro-batch + the number of non-padded tokens in this microbatch + a dict containing reporting metrics on the loss and number of tokens across + the data parallel ranks + """ + args = get_args() + + losses = output_tensor.float() + loss_mask = loss_mask.view(-1).float() + total_tokens = loss_mask.sum() + loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)]) + + if args.context_parallel_size > 1: + torch.distributed.all_reduce(loss, group=mpu.get_context_parallel_group()) + + # Check individual rank losses are not NaN prior to DP all-reduce. + rerun_state_machine = get_rerun_state_machine() + if args.check_for_nan_in_loss_and_grad: + rerun_state_machine.validate_result( + result=loss[0], + rejection_func=torch.isnan, + message="found NaN in local forward loss calculation", + tolerance=0.0, # forward pass calculations are determinisic + fatal=True, + ) + rerun_state_machine.validate_result( + result=loss[0], + rejection_func=torch.isinf, + message="found Inf in local forward loss calculation", + tolerance=0.0, # forward pass calculations are determinisic + fatal=True, + ) + # Check for spiky loss + if args.check_for_spiky_loss: + rerun_state_machine.validate_result( + result=loss[0], + rejection_func=partial( + rerun_state_machine.is_unexpectedly_large, + threshold=SPIKY_LOSS_FACTOR, + context="loss", + ), + message="Spiky loss", + tolerance=0.0, # forward pass calculations are determinisic + fatal=False, + ) + # Reduce loss for logging. + reporting_loss = loss.clone().detach() + torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group()) + + # loss[0] is a view of loss, so it has ._base not None, which triggers assert error + # in core/pipeline_parallel/schedule.py::deallocate_output_tensor, calling .clone() + # on loss[0] fixes this + local_num_tokens = loss[1].clone().detach().to(torch.int) + return ( + loss[0].clone(), + local_num_tokens, + {'lm loss': (reporting_loss[0], reporting_loss[1])}, + ) + + +def forward_step(data_iterator, model: GPTModel): + """Forward training step. + + Args: + data_iterator : Input data iterator + model (GPTModel): The GPT Model + """ + args = get_args() + timers = get_timers() + + # Get the batch. + timers('batch-generator', log_level=2).start() + global stimer + with stimer(bdata=True): + tokens, labels, loss_mask, attention_mask, position_ids = get_batch( + data_iterator) + timers('batch-generator').stop() + + with stimer: + if args.use_legacy_models: + output_tensor = model(tokens, position_ids, attention_mask, + labels=labels) + else: + output_tensor = model(tokens, position_ids, attention_mask, + labels=labels, loss_mask=loss_mask) + + return output_tensor, partial(loss_func, loss_mask) + + +def is_dataset_built_on_rank(): + return ( + mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage() + ) and mpu.get_tensor_model_parallel_rank() == 0 + + +def core_gpt_dataset_config_from_args(args): + tokenizer = get_tokenizer() + + # Sometimes --data-path is too long, instead we parse it from a file. 
+ blend: Optional[Tuple[List[str], Optional[List[float]]]] + blend_per_split: Optional[List[Optional[Tuple[List[str], Optional[List[float]]]]]] + blend, blend_per_split = get_blend_and_blend_per_split(args) + + return GPTDatasetConfig( + random_seed=args.seed, + sequence_length=args.seq_length, + blend=blend, + blend_per_split=blend_per_split, + split=args.split, + num_dataset_builder_threads=args.num_dataset_builder_threads, + path_to_cache=args.data_cache_path, + mmap_bin_files=args.mmap_bin_files, + tokenizer=tokenizer, + reset_position_ids=args.reset_position_ids, + reset_attention_mask=args.reset_attention_mask, + eod_mask_loss=args.eod_mask_loss, + create_attention_mask=args.create_attention_mask_in_dataloader, + s3_cache_path=args.s3_cache_path, + ) + + +def train_valid_test_datasets_provider(train_val_test_num_samples): + """Build the train test and validation datasets. + + Args: + train_val_test_num_samples : A list containing the number of samples in train test and validation. + """ + args = get_args() + + config = core_gpt_dataset_config_from_args(args) + + if args.mock_data: + dataset_type = MockGPTDataset + else: + dataset_type = GPTDataset + + print_rank_0("> building train, validation, and test datasets for GPT ...") + + train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder( + dataset_type, + train_val_test_num_samples, + is_dataset_built_on_rank, + config + ).build() + + print_rank_0("> finished creating GPT datasets ...") + + return train_ds, valid_ds, test_ds + + +if __name__ == "__main__": + + # Temporary for transition to core datasets + train_valid_test_datasets_provider.is_distributed = True + + pretrain( + train_valid_test_datasets_provider, + model_provider, + ModelType.encoder_or_decoder, + forward_step, + args_defaults={'tokenizer_type': 'GPT2BPETokenizer'}, + ) \ No newline at end of file diff --git a/benchmarks/compatibility/scripts/megatron/pretrain_llama.sh b/benchmarks/compatibility/scripts/megatron/pretrain_llama.sh index 2bb7712..ce311c8 100644 --- a/benchmarks/compatibility/scripts/megatron/pretrain_llama.sh +++ b/benchmarks/compatibility/scripts/megatron/pretrain_llama.sh @@ -155,7 +155,7 @@ CHECKPOINTING_ARGS=" \ --no-load-optim \ --no-load-rng \ --save ${SAVE_PATH} \ - --save-interval 10 \ + --save-interval 100 \ " MIXED_PRECISION_ARGS=" \ @@ -193,4 +193,4 @@ CMD="${LAUNCHER} \ ${MOE_ARGS} \ " echo ${CMD} -${CMD} 2>&1 | tee ${LOG_PATH} \ No newline at end of file +${CMD} 2>&1 | tee ${LOG_PATH} diff --git a/benchmarks/compatibility/scripts/megatron/training.py b/benchmarks/compatibility/scripts/megatron/training.py index 16af31f..320659e 100644 --- a/benchmarks/compatibility/scripts/megatron/training.py +++ b/benchmarks/compatibility/scripts/megatron/training.py @@ -21,6 +21,12 @@ # The earliest we can measure the start time. 
 _TRAIN_START_TIME = time.time()
 import torch
+if os.getenv('INFINIPERF_PLATFORM') == 'ASCEND_NPU':
+    try:
+        import mindspeed.megatron_adaptor
+        print("INFO: MindSpeed adaptor for ASCEND_NPU platform has been enabled.")
+    except ImportError:
+        print("WARNING: PLATFORM is set to 'ASCEND_NPU', but the mindspeed package could not be found.")
 
 from megatron.core import mpu, tensor_parallel
 from megatron.core.utils import (
@@ -33,7 +39,12 @@
 from megatron.training.checkpointing import load_checkpoint
 from megatron.training.checkpointing import save_checkpoint
 from megatron.training.checkpointing import checkpoint_exists
-from megatron.legacy.model import Float16Module
+if os.getenv('INFINIPERF_PLATFORM') == 'ASCEND_NPU':
+    from megatron.core.transformer.module import Float16Module
+    print("INFO: Using megatron.core Float16Module for the ASCEND_NPU platform.")
+else:
+    from megatron.legacy.model import Float16Module
+
 from megatron.core.distributed import DistributedDataParallelConfig
 from megatron.core.distributed import DistributedDataParallel as DDP
 from megatron.core.distributed.custom_fsdp import FullyShardedDataParallel as custom_FSDP
@@ -897,7 +908,11 @@ def build_model():
 
     # Fp16 conversion.
     if args.fp16 or args.bf16:
-        model = [Float16Module(model_module, args) for model_module in model]
+        if os.getenv('INFINIPERF_PLATFORM') == 'ASCEND_NPU':
+            config = get_model_config(model[0])
+            model = [Float16Module(config, model_module) for model_module in model]
+        else:
+            model = [Float16Module(model_module, args) for model_module in model]
 
         # The model_module.bfloat16()/model_module.half() above will call the inplace copy of TE's
         # Float8Tensor, which will write an unwanted value (amax calculated from the current fp8
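
The Ascend-specific code paths in this patch all hinge on the `INFINIPERF_PLATFORM` environment variable and on `mindspeed` being importable. As a rough pre-flight check before running `build.sh` and `run_megatron_test.sh`, a small standalone script along the following lines could be used; this is a minimal sketch, not part of the patch, and the file name `check_ascend_env.py` and the helper are illustrative only.

```
# check_ascend_env.py -- hypothetical pre-flight check, not part of this patch.
# It mirrors the INFINIPERF_PLATFORM / mindspeed gating used in training.py and
# pretrain_llama.py above; adjust it to your environment as needed.
import os
import sys


def check_ascend_setup() -> bool:
    """Return True if the ASCEND_NPU branches in this patch would be taken."""
    platform = os.getenv("INFINIPERF_PLATFORM")
    if platform != "ASCEND_NPU":
        print(f"INFINIPERF_PLATFORM is {platform!r}; the ASCEND_NPU branches will be skipped.")
        return False
    try:
        # Same import the patched scripts rely on to enable the MindSpeed adaptor.
        import mindspeed.megatron_adaptor  # noqa: F401
    except ImportError:
        print("mindspeed is not importable; install it per the README before running build.sh.")
        return False
    print("MindSpeed adaptor import succeeded; the Megatron scripts should enable it.")
    return True


if __name__ == "__main__":
    sys.exit(0 if check_ascend_setup() else 1)
```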