Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion megatron/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,17 @@
import types

from megatron.global_vars import set_retro_args, get_retro_args
from tools.retro.utils import get_args_path as get_retro_args_path


import megatron
from megatron.model.enums import PositionEmbeddingType


def get_args_path(workdir):
    """Return the location of the args copy stored in a retro workdir.

    Args:
        workdir: Path to the retro working directory.

    Returns:
        The path of the ``args.json`` file inside ``workdir``.
    """
    args_filename = "args.json"
    return os.path.join(workdir, args_filename)


def parse_args(extra_args_provider=None, ignore_unknown_args=False):
"""Parse all arguments."""
parser = argparse.ArgumentParser(description='Megatron-LM Arguments',
Expand Down Expand Up @@ -400,6 +404,8 @@ def validate_args(args, defaults={}):
if args.use_flash_attn:
assert not args.reset_attention_mask, \
"Flash Attention doesn't support arbitrary attention masks. Please turn off reset-attention-mask"
else:
assert args.window_size is None

if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
if args.sequence_parallel:
Expand Down Expand Up @@ -579,6 +585,7 @@ def _add_network_size_args(parser):
group.add_argument('--max-position-embeddings', type=int, default=None,
help='Maximum number of position embeddings to use. '
'This is the size of position embedding.')
group.add_argument('--window-size', type=int, default=None)
group.add_argument('--use-rotary-position-embeddings', action='store_true',
help='Use rotary positional embeddings or not')
group.add_argument('--rotary-percent', type=float, default=1.0,
Expand Down Expand Up @@ -697,13 +704,21 @@ def _add_logging_args(parser):
help="Name of wandb entity for reporting")
group.add_argument('--wandb-project-name', type=str, default=None,
help="Name of wandb project")
group.add_argument('--wandb-group-name', type=str, default="default",
help="Name of wandb entity for reporting")
group.add_argument('--transformer-timers', action='store_true',
help="If set, activate the timers within the transformer layers."
"Only for debugging, as this slows down the model.")
group.add_argument('--structured-logs', action="store_true",
help='Add timestamp and worker name to stdout and stderr.')
group.add_argument('--structured-logs-dir', type=str, default=None,
help='Directory to save the logs.')
group.add_argument('--debug_layer_outputs', '--debug-layer-outputs', type=int, default=0)
group.add_argument('--debug_layer_gradients', '--debug-layer-gradients', type=int, default=0)
group.add_argument('--debug_all_param_gradients', '--debug-all-param-gradients', type=int, default=0)
group.add_argument('--debug_param_init', '--debug-param-init', type=int, default=0)
group.add_argument('--debug_param_update', '--debug-param-update', type=int, default=0)
group.add_argument('--debug_transformer', '--debug-transformer', type=int, default=0)

return parser

Expand Down
129 changes: 64 additions & 65 deletions megatron/data/gpt_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,71 +394,70 @@ def __getitem__(self, idx):
eod = self.tokenizer.eod
segment_breaks = np.argwhere(sample == eod) # split sample by document

if self.fim_rate == 0:
return sample.astype(np.int64)

def fim_permute_sequence(sequence, rate):
return permute(
sequence,
self.np_rng,
rate,
self.fim_spm_rate,
self.tokenizer,
truncate_or_pad=False,
suffix_tok_id=self.suffix_tok_id,
prefix_tok_id=self.prefix_tok_id,
middle_tok_id=self.middle_tok_id,
pad_tok_id=self.pad_tok_id,
)

def fim_split_and_permute_sequence(sequence):
"""
If self.fim_split_sample is not None, split the sequence.
Then apply FIM on the fragments, or the whole sequence if self.fim_split_sample is None.
"""
if self.fim_split_sample is None:
return fim_permute_sequence(sequence, self.fim_rate)
# fim_split_sample is set: split the sample on this token and permute each fragment separately.
# Typically, if each sample is a repository, then we split again on the file level.
# Each fragment is a file, and we permute the files.
fragment_breaks = np.argwhere(sequence == self.fim_split_sample)
if fragment_breaks.shape == (0, 1):
# no split token in this sample
return fim_permute_sequence(sequence, self.fim_rate)
if not self.np_rng.binomial(1, self.fim_rate):
# don't do FIM preproc
return sequence
# Do FIM on each fragment
curr_start_position = 0
new_samples = []
for loc in np.nditer(fragment_breaks):
if loc - curr_start_position > 0:
permuted = fim_permute_sequence(sequence[curr_start_position:loc], self.fragment_fim_rate)
new_samples += [permuted, [self.fim_split_sample]]
curr_start_position = loc + 1 # Jump over the split token
# Permute the segment after the last split token
permuted = fim_permute_sequence(sequence[curr_start_position:], self.fragment_fim_rate)
new_samples.append(permuted)
return np.concatenate(new_samples)

if segment_breaks.shape != (0, 1): # then there is an EOD token in this example
curr_start_position = 0
new_samples = []
for loc in np.nditer(segment_breaks):
# Only permute non-empty segments.
if loc - curr_start_position > 0:
# permute {prefix, suffix, middle} or {suffix, prefix, middle}
permuted = fim_split_and_permute_sequence(sample[curr_start_position:loc])
new_samples += [permuted, [eod]]

curr_start_position = loc + 1 # jump over the EOD token
# Permute the segment after the last EOD
permuted = fim_split_and_permute_sequence(sample[curr_start_position:])
new_samples.append(permuted)

sample = np.concatenate(new_samples)
else:
sample = fim_split_and_permute_sequence(sample)
if self.fim_rate != 0:

def fim_permute_sequence(sequence, rate):
    # Apply a fill-in-the-middle (FIM) transformation to one token sequence.
    # `rate` is forwarded to `permute` as the transformation rate — presumably
    # the probability the sequence gets rearranged (cf. the binomial draw on
    # self.fim_rate in the sibling helper) — TODO confirm against permute().
    # Sentinel token ids (prefix/suffix/middle/pad) and the RNG come from the
    # enclosing dataset object; truncate_or_pad=False leaves length handling
    # to the caller.
    return permute(
        sequence,
        self.np_rng,
        rate,
        self.fim_spm_rate,
        self.tokenizer,
        truncate_or_pad=False,
        suffix_tok_id=self.suffix_tok_id,
        prefix_tok_id=self.prefix_tok_id,
        middle_tok_id=self.middle_tok_id,
        pad_tok_id=self.pad_tok_id,
    )

def fim_split_and_permute_sequence(sequence):
    """
    If self.fim_split_sample is not None, split the sequence.
    Then apply FIM on the fragments, or the whole sequence if self.fim_split_sample is None.
    """
    if self.fim_split_sample is None:
        # No fragment-level splitting configured: FIM the whole sequence.
        return fim_permute_sequence(sequence, self.fim_rate)
    # fim_split_sample is set: split the sample on this token and permute each fragment separately.
    # Typically, if each sample is a repository, then we split again on the file level.
    # Each fragment is a file, and we permute the files.
    fragment_breaks = np.argwhere(sequence == self.fim_split_sample)
    if fragment_breaks.shape == (0, 1):
        # no split token in this sample
        return fim_permute_sequence(sequence, self.fim_rate)
    if not self.np_rng.binomial(1, self.fim_rate):
        # don't do FIM preproc
        return sequence
    # Do FIM on each fragment
    curr_start_position = 0
    new_samples = []
    for loc in np.nditer(fragment_breaks):
        # Only non-empty fragments are permuted; the per-fragment rate
        # (self.fragment_fim_rate) applies, not the sample-level rate.
        # NOTE(review): when a fragment is empty (consecutive split tokens)
        # the split token itself is not re-emitted here — confirm intended.
        if loc - curr_start_position > 0:
            permuted = fim_permute_sequence(sequence[curr_start_position:loc], self.fragment_fim_rate)
            new_samples += [permuted, [self.fim_split_sample]]
        curr_start_position = loc + 1  # Jump over the split token
    # Permute the segment after the last split token
    permuted = fim_permute_sequence(sequence[curr_start_position:], self.fragment_fim_rate)
    new_samples.append(permuted)
    return np.concatenate(new_samples)

if segment_breaks.shape != (0, 1): # then there is an EOD token in this example
curr_start_position = 0
new_samples = []
for loc in np.nditer(segment_breaks):
# Only permute non-empty segments.
if loc - curr_start_position > 0:
# permute {prefix, suffix, middle} or {suffix, prefix, middle}
permuted = fim_split_and_permute_sequence(sample[curr_start_position:loc])
new_samples += [permuted, [eod]]

curr_start_position = loc + 1 # jump over the EOD token
# Permute the segment after the last EOD
permuted = fim_split_and_permute_sequence(sample[curr_start_position:])
new_samples.append(permuted)

sample = np.concatenate(new_samples)
else:
sample = fim_split_and_permute_sequence(sample)

# Truncate or pad sequence to max-length
diff = sample.shape[0] - sample_len
Expand Down
36 changes: 25 additions & 11 deletions megatron/initialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@

import logging
import logging.config
import math
import random
import os
import sys
import time

import numpy as np
import torch
import torch.distributed
from datetime import timedelta

try:
Expand Down Expand Up @@ -96,16 +98,20 @@ def finish_mpu_init():

def _configure_logging():
args=get_args()
if not args.structured_logs:
return
rank = torch.distributed.get_rank()
if args.structured_logs:
world_size=torch.distributed.get_world_size()
rank_str = str(rank).zfill(math.ceil(math.log10(world_size)))
format = f"%(asctime)s {'' if world_size==1 else f'[Rank {rank_str}] '}%(message)s"
else:
format=None

logging_config = {
"version": 1,
"disable_existing_loggers": False,
"formatters": {
"default": {
"format": f"%(asctime)s [Rank {rank}]: %(message)s",
"format": format,
"use_colors": True,
}
},
Expand Down Expand Up @@ -133,13 +139,13 @@ def _configure_logging():
logging_config["loggers"]["default"]["handlers"].append("file")
logging.config.dictConfig(logging_config)

# Add these methods so that stdout can be redirected to logging.
logging.write = lambda msg: logging.info(msg) if msg != '\n' else None
logging.flush = lambda : None

sys.stdout=logging
sys.stderr=logging
if args.structured_logs:
# Add these methods so that stdout can be redirected to logging.
logging.write = lambda msg: logging.info(msg) if msg != '\n' else None
logging.flush = lambda : None

sys.stdout=logging
sys.stderr=logging


def _compile_dependencies():
Expand Down Expand Up @@ -298,6 +304,10 @@ def write_args_to_tensorboard():

def init_wandb():
args = get_args()
# Wandb login from file
api_key_path = os.environ.get("WANDB_API_KEY_PATH")
if api_key_path:
os.environ["WANDB_API_KEY"]=open(api_key_path,"r").read().strip()
if args.rank == (args.world_size - 1):
if not (args.wandb_entity_name and args.wandb_project_name):
print('> Skipping wandb init ...', flush=True)
Expand All @@ -306,7 +316,7 @@ def init_wandb():
name=os.path.basename(args.save),
entity=args.wandb_entity_name,
project=args.wandb_project_name,
group="mini_cluster",
group=args.wandb_group_name,
config=args
)

Expand All @@ -332,7 +342,11 @@ def set_jit_fusion_options():
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_override_can_fuse_on_gpu(True)

_warmup_jit_function()
# Prevent the function from messing up the random state.

tensor_parallel.get_cuda_rng_tracker().add("Warmup jit", 0)
with tensor_parallel.get_cuda_rng_tracker().fork("Warmup jit"):
_warmup_jit_function()


def _warmup_jit_function():
Expand Down
34 changes: 33 additions & 1 deletion megatron/model/language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def __init__(self,
else:
self.tokentype_embeddings = None

self.fp32_residual_connection = args.fp32_residual_connection
self.fp32_residual_connection = args.fp32_residual_connection
self.sequence_parallel = args.sequence_parallel
# Embeddings dropout
self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)
Expand Down Expand Up @@ -431,6 +431,38 @@ def __init__(self,
perform_initialization=args.perform_initialization)
self._output_layer_key = 'output_layer'

for i, (key, value) in enumerate(self.named_parameters()):
# Store standardized parameter names for debug purposes.
args=get_args()
key=key.split(".")
if key[0]=="encoder":
# Remove "encoder" prefix.
key=key[1:]
if key[0]=="layers":
# Shift layer index.
key[1]=str(int(key[1])+1)
if key[2]=="input_layernorm":
key[2]="layer_norm_1"
elif key[2]=="post_attention_layernorm":
key[2]="layer_norm_2"
elif key[2]=="self_attention":
key[2]="self_attn"
elif key[3]=="dense_h_to_4h":
key[3]="layer_1"
elif key[3]=="dense_4h_to_h":
key[3]="layer_2"
else:
assert key[0]=="final_layernorm"
key=["layers",str(args.encoder_num_layers+1)]+key
elif key[0]=="embedding":
key=["layers", "0", "_".join(key[1:])]
else:
# Not implemented but still ok
pass

value.param_name = ".".join(key)
value.param_idx = i

def set_input_tensor(self, input_tensor):
""" See megatron.model.transformer.set_input_tensor()"""

Expand Down
Loading