Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
2a019a1
Update megatron version
jlamypoirier Dec 21, 2023
1ac3039
Merge branch 'main' into compare_tensors_updated
jlamypoirier Dec 21, 2023
40f251c
fixes
jlamypoirier Dec 21, 2023
efb3c3a
profiling
jlamypoirier Jan 9, 2024
716204e
misc
jlamypoirier Jan 19, 2024
1e0e58e
Merge branch 'main' into compare_tensors_updated
jlamypoirier Jan 24, 2024
98fbb42
Fix
jlamypoirier Jan 24, 2024
0bfeeae
rename output layer
maxmatical Jan 30, 2024
bb53cf9
Merge pull request #3 from ServiceNow/max/rename-output-layer
maxmatical Jan 30, 2024
9760e11
Tokenizer fix
jlamypoirier Feb 6, 2024
94ce57b
Merge remote-tracking branch 'nvidia/main' into compare_tensors_updated
jlamypoirier Feb 6, 2024
b22634d
fix
jlamypoirier Feb 6, 2024
2165919
Better wandb
jlamypoirier Feb 7, 2024
c478f48
misc
jlamypoirier Feb 7, 2024
5566742
Merge branch 'main' into compare_tensors_updated
jlamypoirier Feb 13, 2024
63d9d3e
MOE support
jlamypoirier Mar 8, 2024
40a134a
stuff
jlamypoirier Mar 8, 2024
1a96a99
Merge branch 'main' into compare_tensors_updated
jlamypoirier Mar 8, 2024
fdd668c
Support megatron core models
jlamypoirier Mar 11, 2024
4238a80
Fix arg
jlamypoirier Mar 11, 2024
fe38434
fixes
jlamypoirier Mar 12, 2024
3c6652e
fix
jlamypoirier May 29, 2024
f6b9b4b
fix
jlamypoirier Sep 19, 2024
cb6baf1
update
jlamypoirier Feb 11, 2025
fe1f23c
misc
jlamypoirier Mar 5, 2025
2e23b9b
Revert "misc"
jlamypoirier May 16, 2025
511e8f5
version
jlamypoirier May 16, 2025
75b0d97
fix
jlamypoirier Jul 23, 2025
f02b413
misc
jlamypoirier Aug 14, 2025
89f391e
stuff
jlamypoirier Sep 3, 2025
30e7aec
fix
jlamypoirier Sep 9, 2025
dee2745
misc
jlamypoirier Oct 2, 2025
67d0694
misc
jlamypoirier Mar 7, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 45 additions & 3 deletions megatron/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,16 @@

import torch.nn.functional as F
from megatron.global_vars import set_retro_args, get_retro_args
from tools.retro.utils import get_args_path as get_retro_args_path

from megatron.core.models.retro import RetroConfig
from megatron.core.transformer import TransformerConfig


def get_args_path(workdir):
    """Return the location of the argument copy kept inside a retro workdir."""
    filename = "args.json"
    return os.path.join(workdir, filename)


def parse_args(extra_args_provider=None, ignore_unknown_args=False):
"""Parse all arguments."""
parser = argparse.ArgumentParser(description='Megatron-LM Arguments',
Expand Down Expand Up @@ -213,6 +217,9 @@ def validate_args(args, defaults={}):
if args.dataloader_type is None:
args.dataloader_type = 'single'

if args.valid_num_workers is None:
args.valid_num_workers = args.num_workers

# Consumed tokens.
args.consumed_train_samples = 0
args.consumed_valid_samples = 0
Expand Down Expand Up @@ -364,6 +371,9 @@ def validate_args(args, defaults={}):
if args.sequence_parallel:
args.async_tensor_model_parallel_allreduce = False

if not args.use_flash_attn:
assert args.window_size is None

if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
if args.sequence_parallel:
raise RuntimeError(
Expand Down Expand Up @@ -640,6 +650,8 @@ def _add_network_size_args(parser):
'Deprecated: use --position-embedding-type')
group.add_argument('--rotary-percent', type=float, default=1.0,
help='Percent of rotary dimension to use, default 100%%')
group.add_argument('--rotary-theta', type=int, default=10000,
help='Theta/frequency value for rotary positional embeddings')
group.add_argument('--rotary-interleaved', action='store_true',
help='Use interleaved rotary embedding.')
group.add_argument('--rotary-seq-len-interpolation-factor', type=int, default=None,
Expand Down Expand Up @@ -679,6 +691,7 @@ def _add_network_size_args(parser):
dest='bert_binary_head')
group.add_argument('--untie-embeddings-and-output-weights', action='store_true',
help='Untie embeddings and output weights.'),
group.add_argument('--window-size', type=int, default=None)
return parser


Expand Down Expand Up @@ -749,12 +762,25 @@ def _add_logging_args(parser):
group.add_argument('--log-world-size-to-tensorboard',
action='store_true',
help='Enable world size logging to tensorboard.')
group.add_argument('--wandb-project', type=str, default='',
group.add_argument('--wandb-project', '--wandb-project-name', type=str, default='',
help='The wandb project name. Ignore wandb by default.')
group.add_argument('--wandb-exp-name', type=str, default='',
help='The wandb experiment name.')
group.add_argument('--wandb-save-dir', type=str, default='',
help='Path to save the wandb results locally.')
group.add_argument('--wandb-group-name', type=str, default="default")
group.add_argument('--wandb-entity-name', type=str, default=None,
help="Name of wandb entity for reporting")
group.add_argument('--structured-logs', action="store_true",
help='Add timestamp and worker name to stdout and stderr.')
group.add_argument('--structured-logs-dir', type=str, default=None,
help='Directory to save the logs.')
group.add_argument('--debug_layer_outputs', '--debug-layer-outputs', type=int, default=0)
group.add_argument('--debug_layer_gradients', '--debug-layer-gradients', type=int, default=0)
group.add_argument('--debug_all_param_gradients', '--debug-all-param-gradients', type=int, default=0)
group.add_argument('--debug_param_init', '--debug-param-init', type=int, default=0)
group.add_argument('--debug_param_update', '--debug-param-update', type=int, default=0)
group.add_argument('--debug_transformer', '--debug-transformer', type=int, default=0)
group.add_argument('--enable-one-logger', action='store_true',
help='If set, use one_logger to track E2E metrics'
'Note that one_logger is an internal tool and not available externally. '
Expand Down Expand Up @@ -883,6 +909,7 @@ def _add_training_args(parser):
help='Global step to stop profiling.')
group.add_argument('--profile-ranks', nargs='+', type=int, default=[0],
help='Global ranks to profile.')
group.add_argument('--torch-profile-dir', type=str, default=None)
group.add_argument('--tp-comm-overlap', action='store_true', help = 'Enables the '
' overlap of Tensor parallel communication and GEMM kernels.')
group.add_argument('--tp-comm-overlap-cfg', type=str, default=None,
Expand Down Expand Up @@ -1272,6 +1299,8 @@ def _add_data_args(parser):
help='Path to the vocab file.')
group.add_argument('--merge-file', type=str, default=None,
help='Path to the BPE merge file.')
group.add_argument('--tokenizer-file', type=str, default=None,
help='Path to the tokenizer.json file. Used for the TokenizerFromFile[...] tokenizers')
group.add_argument('--vocab-extra-ids', type=int, default=0,
help='Number of additional vocabulary tokens. '
'They are used for span masking in the T5 model')
Expand All @@ -1294,11 +1323,16 @@ def _add_data_args(parser):
help='Probability of producing a short sequence.')
group.add_argument('--num-workers', type=int, default=2,
help="Dataloader number of workers.")
group.add_argument('--valid-num-workers', type=int, default=None,
help="Dataloader number of workers for validation.")
group.add_argument('--tokenizer-type', type=str,
default=None,
choices=['BertWordPieceLowerCase',
'BertWordPieceCase',
'GPT2BPETokenizer',
'GPT2BPETokenizerWithFIM',
'TokenizerFromFile',
'TokenizerFromFileWithFIM',
'SentencePieceTokenizer',
'GPTSentencePieceTokenizer',
'Llama2Tokenizer',
Expand All @@ -1313,7 +1347,15 @@ def _add_data_args(parser):
'end-of-document token.')
group.add_argument('--eod-mask-loss', action='store_true',
help='Mask loss for the end of document tokens.')

group.add_argument('--fim-rate', type=float, default=0.,
help='Probability to convert a training sample into a "Fill-in-the-Middle" format. Must be between 0 and 1.')
group.add_argument('--fim-spm-rate', type=float, default=0.5,
help='Probability that the a FIM sample uses the SPM format over the PSM format. '
'At 1, exclusively train with SPM. At 0, exclusively train with PSM')
group.add_argument('--fim-split-sample', type=str, default=None,
help='String around which to split the sample for FIM. If None (default), FIM is applied on the sample-level')
group.add_argument('--fragment-fim-rate', type=float, default=0.5,
help='Rate of FIM on each fragment when fim_split_sample is not None.')
return parser


Expand Down
2 changes: 1 addition & 1 deletion megatron/checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def _compare(arg_name, old_arg_name=None, default=None):
_compare('hidden_size')
_compare('num_attention_heads')
_compare('add_position_embedding', default=True)
if args.vocab_file:
if args.vocab_file or args.tokenizer_file:
_compare('max_position_embeddings')
_compare('make_vocab_size_divisible_by')
_compare('padded_vocab_size')
Expand Down
2 changes: 1 addition & 1 deletion megatron/core/datasets/blended_megatron_dataset_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ def _build_megatron_dataset_splits(
# Build the mid level dataset
mid_level_datasets = []
for i, _split in enumerate(Split):
if not self.config.mock and split[i] is None:
if not self.config.mock and (split[i] is None or sizes[i] == 0):
mid_level_datasets.append(None)
else:
mid_level_datasets.append(
Expand Down
178 changes: 177 additions & 1 deletion megatron/core/datasets/gpt_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@
import numpy
import torch

from megatron import get_args, get_tokenizer
from megatron.core.datasets.blended_megatron_dataset_config import BlendedMegatronDatasetConfig
from megatron.core.datasets.indexed_dataset import MMapIndexedDataset
from megatron.core.datasets.megatron_dataset import MegatronDataset, MockDataset
from megatron.core.datasets.utils import Split, log_single_rank
from megatron.tokenizer.tokenizer import FIM_MIDDLE, FIM_PAD, FIM_PREFIX, FIM_SUFFIX

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -131,6 +133,21 @@ def __init__(
super().__init__(
indexed_dataset, dataset_path, indexed_indices, num_samples, index_split, config
)
self.args = get_args()
self.tokenizer = get_tokenizer()
self.np_rng = numpy.random.RandomState(seed=self.config.random_seed) # rng state for FIM

self.use_fim = self.args.fim_rate!=0
if self.use_fim:
self.fim_rate = self.args.fim_rate
self.fim_spm_rate = self.args.fim_spm_rate
self.fragment_fim_rate = self.args.fragment_fim_rate
self.fim_split_sample = self.tokenizer.vocab[self.args.fim_split_sample] if self.args.fim_split_sample is not None else None

try:
self.suffix_tok_id, self.prefix_tok_id, self.middle_tok_id, self.pad_tok_id = (self.tokenizer.special_tokens[tok] for tok in [FIM_SUFFIX, FIM_PREFIX, FIM_MIDDLE, FIM_PAD])
except KeyError:
self.suffix_tok_id, self.prefix_tok_id, self.middle_tok_id, self.pad_tok_id = (self.tokenizer.vocab[tok] for tok in [FIM_SUFFIX, FIM_PREFIX, FIM_MIDDLE, FIM_PAD])

self.vocab_size = config.vocab_size

Expand Down Expand Up @@ -265,8 +282,101 @@ def _query_document_sample_shuffle_indices(
self.dataset.get(self.document_index[i], offset=offset, length=length)
)

sample=numpy.concatenate(sample_parts)

# Code from: https://github.com/EleutherAI/gpt-neox/blob/FIM-clean/megatron/data/gpt2_dataset.py#L109
# TODO(Hailey): can merge the code below this line with code above this line.
# TODO(Hailey), cont: above already iterates through loop, so just add the permuting in there?
sample = numpy.array(sample, dtype=numpy.int64)
sample_len = sample.shape[0]
# # print(sample, sample.shape)
# # do FIM here, if enabled
# TODO: Do we handle the following point from FIM paper?
# To transform data in the character space for context-level FIM, the tokenized documents have to be decoded back into strings before FIM augmentation. Depending on the vocabulary, some care has to be given to ensure decoding does not introduce any spurious characters into training. For example, utf-8 characters are encoded as multiple tokens with a BPE vocabulary; they can result in fragments from chunking and fail to decode. To prevent unforeseen errors midway through training, we encourage checking for these fragments at the beginning or end of a context and removing them.
eod = self.tokenizer.eod
segment_breaks = numpy.argwhere(sample == eod) # split sample by document

if not self.use_fim:
return (
numpy.array(sample, dtype=numpy.int64),
numpy.array(document_ids, dtype=numpy.int64),
)

def fim_permute_sequence(sequence, rate):
    # Apply the FIM permutation to one (sub-)sequence with probability `rate`,
    # using the sentinel token ids and rng resolved at dataset init.
    # truncate_or_pad=False: the caller restores the overall sample length
    # once, after all fragments have been permuted.
    return permute(
        sequence,
        self.np_rng,
        rate,
        self.fim_spm_rate,
        self.tokenizer,
        truncate_or_pad=False,
        suffix_tok_id=self.suffix_tok_id,
        prefix_tok_id=self.prefix_tok_id,
        middle_tok_id=self.middle_tok_id,
        pad_tok_id=self.pad_tok_id,
    )

def fim_split_and_permute_sequence(sequence):
    """
    If self.fim_split_sample is not None, split the sequence.
    Then apply FIM on the fragments, or the whole sequence if self.fim_split_sample is None.
    """
    if self.fim_split_sample is None:
        return fim_permute_sequence(sequence, self.fim_rate)
    # fim_split_sample is set: split the sample on this token and permute each fragment separately.
    # Typically, if each sample is a repository, then we split again on the file level.
    # Each fragment is a file, and we permute the files.
    # argwhere on a 1-D sequence yields shape (k, 1); (0, 1) means no match.
    fragment_breaks = numpy.argwhere(sequence == self.fim_split_sample)
    if fragment_breaks.shape == (0, 1):
        # no split token in this sample
        return fim_permute_sequence(sequence, self.fim_rate)
    if not self.np_rng.binomial(1, self.fim_rate):
        # don't do FIM preproc
        return sequence
    # Do FIM on each fragment. Note the two-level gating: fim_rate decides
    # whether the whole sample is FIM-processed, fragment_fim_rate is then
    # applied independently to each fragment inside fim_permute_sequence.
    curr_start_position = 0
    new_samples = []
    for loc in numpy.nditer(fragment_breaks):
        if loc - curr_start_position > 0:
            # NOTE(review): when a fragment is empty (consecutive split
            # tokens), the split token itself is dropped from the output —
            # confirm this is intended.
            permuted = fim_permute_sequence(sequence[curr_start_position:loc], self.fragment_fim_rate)
            new_samples += [permuted, [self.fim_split_sample]]
        curr_start_position = loc + 1  # Jump over the split token
    # Permute the segment after the last split token
    permuted = fim_permute_sequence(sequence[curr_start_position:], self.fragment_fim_rate)
    new_samples.append(permuted)
    return numpy.concatenate(new_samples)

if segment_breaks.shape != (0, 1): # then there is an EOD token in this example
curr_start_position = 0
new_samples = []
for loc in numpy.nditer(segment_breaks):
# Only permute non-empty segments.
if loc - curr_start_position > 0:
# permute {prefix, suffix, middle} or {suffix, prefix, middle}
permuted = fim_split_and_permute_sequence(sample[curr_start_position:loc])
new_samples += [permuted, [eod]]

curr_start_position = loc + 1 # jump over the EOD token
# Permute the segment after the last EOD
permuted = fim_split_and_permute_sequence(sample[curr_start_position:])
new_samples.append(permuted)

sample = numpy.concatenate(new_samples)
else:
sample = fim_split_and_permute_sequence(sample)

# Truncate or pad sequence to max-length
diff = sample.shape[0] - sample_len
if diff > 0: # too long
sample = sample[:sample_len]
elif diff < 0: # too short
sample = numpy.concatenate([sample, numpy.full((-1 * diff), self.pad_tok_id)])

assert sample.shape[0] == sample_len
# end FIM-specific code

return (
numpy.array(numpy.concatenate(sample_parts), dtype=numpy.int64),
numpy.array(sample, dtype=numpy.int64),
numpy.array(document_ids, dtype=numpy.int64),
)

Expand Down Expand Up @@ -573,6 +683,71 @@ def _build_shuffle_index(
return numpy.concatenate((shuffle_idx_first, shuffle_idx_last))


# From https://github.com/EleutherAI/gpt-neox/blob/FIM-clean/megatron/data/gpt2_dataset.py#L339
def permute(sample, np_rng, fim_rate, fim_spm_rate, tokenizer, truncate_or_pad=True,
            suffix_tok_id=None, prefix_tok_id=None, middle_tok_id=None, pad_tok_id=None):
    """
    Take in a sample (np array w/ size (0,chunklength)) and perform a FIM transformation on it.
    Maintain the same sample length (if transform creates a few extra tokens, drop them).

    Args:
        sample: 1-D numpy array of token ids.
        np_rng: numpy RandomState used for all random draws.
        fim_rate: probability of applying the FIM transformation at all.
        fim_spm_rate: probability of SPM ordering over PSM, given FIM applies.
        tokenizer: must provide detokenize(ids) -> str and tokenize(str) -> ids.
        truncate_or_pad: when True, force the output back to len(sample) by
            truncating the suffix or padding it with pad_tok_id.
        suffix_tok_id, prefix_tok_id, middle_tok_id, pad_tok_id: sentinel ids.

    Returns:
        A 1-D numpy array of token ids. The original `sample` is returned
        unchanged when FIM is not applied, or when there is no room to
        truncate the suffix.
    """
    if np_rng.binomial(1, fim_rate):  # sample bernoulli dist

        contents = tokenizer.detokenize(sample)

        try:
            # A boundary can be =0 (prefix will be empty)
            # a boundary can be =len(contents) (suffix will be empty)
            # The two boundaries can be equal (middle will be empty)
            boundaries = list(np_rng.randint(low=0, high=len(contents) + 1, size=2))
            boundaries.sort()
        except ValueError as e:
            # Dump the offending contents before propagating, for debugging.
            print(len(contents), contents)
            print(e)
            raise e

        prefix = contents[:boundaries[0]]
        middle = contents[boundaries[0]:boundaries[1]]
        suffix = contents[boundaries[1]:]

        # Re-tokenize each piece: splitting in character space may change the
        # total token count, hence the truncate/pad step below.
        prefix = numpy.array([*tokenizer.tokenize(prefix)], dtype=numpy.int64)
        middle = numpy.array([*tokenizer.tokenize(middle)], dtype=numpy.int64)
        suffix = numpy.array([*tokenizer.tokenize(suffix)], dtype=numpy.int64)

        # here we truncate each given segment to fit the same length as it was before
        # A consequence is that we never reach the end of a file?
        # we should rather truncate at the context-level
        if truncate_or_pad:
            # need to make same length as the input. Take the 3 sentinel tokens into account
            new_length = suffix.shape[0] + prefix.shape[0] + middle.shape[0] + 3
            diff = new_length - sample.shape[0]
            if diff > 0:  # too long
                if suffix.shape[0] <= diff:
                    # No room to truncate the suffix: skip FIM for this sample.
                    # Bug fix: this path previously returned `(sample, np_rng)`,
                    # a tuple, while every other path returns a plain array.
                    return sample
                suffix = suffix[:suffix.shape[0] - diff]
            elif diff < 0:  # too short
                suffix = numpy.concatenate([suffix, numpy.full((-1 * diff), pad_tok_id)])

        if np_rng.binomial(1, fim_spm_rate):
            # SPM (variant 2 from FIM paper)
            new_sample = numpy.concatenate([
                [prefix_tok_id, suffix_tok_id], suffix,
                [middle_tok_id], prefix, middle
            ])
        else:
            # PSM
            new_sample = numpy.concatenate([
                [prefix_tok_id], prefix,
                [suffix_tok_id], suffix,
                [middle_tok_id], middle
            ])

    else:
        # don't do FIM preproc
        new_sample = sample

    return new_sample


def _get_ltor_masks_and_position_ids(
data: torch.Tensor,
eod_token: int,
Expand Down Expand Up @@ -640,3 +815,4 @@ def _get_ltor_masks_and_position_ids(
attention_mask = attention_mask < 0.5

return attention_mask, loss_mask, position_ids

7 changes: 6 additions & 1 deletion megatron/core/datasets/indexed_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,12 @@ def __init__(self, idx_path: str, multimodal: bool) -> None:
assert header == _INDEX_HEADER, f"bad header, cannot read: {idx_path}"

version = struct.unpack("<Q", stream.read(8))[0]
assert version == 1, f"bad version, cannot read: {idx_path}"
assert version in (1, 2, 3), f"bad version, cannot read: {idx_path}"

if version >= 2:
_ = struct.unpack("<B", stream.read(1))[0]
if version >= 3:
_ = struct.unpack("<B", stream.read(1))[0]

code = struct.unpack("<B", stream.read(1))[0]
self.dtype = DType.dtype_from_code(code)
Expand Down
2 changes: 1 addition & 1 deletion megatron/core/datasets/megatron_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __init__(
self.unique_identifiers[attr] = getattr(self.config, attr)

self.unique_description = json.dumps(
self.unique_identifiers, indent=4, default=lambda obj: obj.unique_identifiers
self.unique_identifiers, indent=4, default=lambda obj: getattr(obj,"unique_identifiers",None)
)
self.unique_description_hash = hashlib.md5(
self.unique_description.encode("utf-8")
Expand Down
Loading