# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
"""Pretrain and SFT GPT."""

# Capture the true program start time BEFORE any heavy imports.
import time

_PROGRAM_START_TIME = time.time()

import json

# Suppress warnings on all ranks but rank 0.
import os
import warnings

rank = int(os.environ.get('RANK', 0))
if rank != 0:
    warnings.filterwarnings("ignore", category=UserWarning)
    warnings.filterwarnings("ignore", category=FutureWarning)

from functools import partial
from typing import List, Optional, Tuple

import torch

from gpt_builders import gpt_builder
from megatron.core import parallel_state
from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig, MockGPTDataset
from megatron.core.enums import ModelType
from megatron.core.models.gpt import GPTModel
from megatron.core.rerun_state_machine import get_rerun_state_machine
from megatron.core.tokenizers.utils.build_tokenizer import build_tokenizer
from megatron.core.transformer.multi_token_prediction import get_mtp_ranks, mtp_on_this_rank
from megatron.core.utils import (
    StragglerDetector,
    get_attr_wrapped_model,
    get_batch_on_this_hybrid_cp_rank,
    get_thd_batch_on_this_cp_rank,
)
from megatron.training import (
    get_args,
    get_timers,
    inprocess_restart,
    pretrain,
    print_rank_0,
    set_startup_timestamps,
)
from megatron.training.arguments import core_transformer_config_from_args
from megatron.training.datasets.fim_dataset import GPTFIMDataset, GPTFIMDatasetConfig
from megatron.training.datasets.sft_dataset import SFTDataset
from megatron.training.utils import (
    get_batch_on_this_cp_rank,
    get_batch_on_this_tp_rank,
    get_blend_and_blend_per_split,
    is_first_or_last_pipeline_stage,
)
from model_provider import model_provider

try:
    from megatron.post_training.arguments import add_modelopt_args
    from megatron.post_training.loss_func import loss_func as loss_func_modelopt

    has_nvidia_modelopt = True
except ImportError:
    has_nvidia_modelopt = False
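
# Global straggler detector; used below to time batch generation and the forward pass.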
stimer = StragglerDetector()


def get_batch(data_iterator, vp_stage: Optional[int] = None):
    """Generate a batch."""
    args = get_args()
    config = core_transformer_config_from_args(args)

    # TODO: this is pretty hacky, find a better way
    if not is_first_or_last_pipeline_stage(vp_stage) and not mtp_on_this_rank(
        config, ignore_virtual=False, vp_stage=vp_stage
    ):
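        # Middle pipeline stages consume no data; return six placeholders to match
        # the (tokens, labels, loss_mask, attention_mask, position_ids,
        # packed_seq_params) unpacking in forward_step.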
        return None, None, None, None, None, None

    # Get batches based on the TP rank you are on.
    batch = get_batch_on_this_tp_rank(
        data_iterator,
        mtp_on_this_rank=mtp_on_this_rank(config, ignore_virtual=False, vp_stage=vp_stage),
    )

    cu_seqlens = batch.pop('cu_seqlens', None)
    cu_seqlens_padded = batch.pop('cu_seqlens_padded', None)
    max_seqlen = batch.pop('max_seqlen', None)
    local_cp_size = batch.pop('local_cp_size', None)
    if local_cp_size is not None:
        local_cp_size = int(local_cp_size.item())

    if cu_seqlens is None and local_cp_size is None:
        # Slice the batch along the sequence dimension for context parallelism.
        # The implementation of this function lives in MCore.
        batch = get_batch_on_this_cp_rank(batch)
        packed_seq_params = None
    elif local_cp_size is None:  # Packed THD format
        assert max_seqlen.dim() == 1
        batch, packed_seq_params = get_thd_batch_on_this_cp_rank(
            batch, cu_seqlens, cu_seqlens_padded, max_seqlen
        )
    else:  # Hybrid CP format
        batch, packed_seq_params = get_batch_on_this_hybrid_cp_rank(batch, local_cp_size)
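
    # Relies on the batch dict preserving insertion order: unpacking batch.values()
    # yields (tokens, labels, loss_mask, attention_mask, position_ids), the order
    # forward_step expects.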
    return (*batch.values(), packed_seq_params)


# Define spiky loss as a loss that's 10x the max loss observed.
SPIKY_LOSS_FACTOR = 10


def loss_func(
    loss_mask: torch.Tensor, output_tensor: torch.Tensor, model: Optional[GPTModel] = None
):
    """Loss function.

    Args:
        loss_mask (torch.Tensor): Used to mask out some portions of the loss
        output_tensor (torch.Tensor): The tensor with the losses
        model (GPTModel, optional): The model (can be wrapped)

    Returns:
        the loss scalar for this micro-batch
        the number of non-padded tokens in this micro-batch
        a dict containing reporting metrics on the loss and number of tokens across
            the data parallel ranks
    """
    args = get_args()

    if has_nvidia_modelopt and getattr(args, 'modelopt_enabled', False):  # [ModelOpt]
        loss, num_tokens, report = loss_func_modelopt(loss_mask, output_tensor, model=model)
    else:
        losses = output_tensor.view(-1).float()
        loss_mask = loss_mask.view(-1).float()
        loss = torch.sum(losses * loss_mask)
        num_tokens = loss_mask.sum().clone().detach().to(torch.int)
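        # Pack the detached loss sum and token count into a single tensor so both
        # can be reduced across data-parallel ranks together when reporting.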
        report = {'lm loss': torch.cat([loss.clone().detach().view(1), num_tokens.view(1)])}

    # Check that individual rank losses are not NaN prior to the DP all-reduce.
    rerun_state_machine = get_rerun_state_machine()
    if args.check_for_nan_in_loss_and_grad:
        rerun_state_machine.validate_result(
            result=loss,
            rejection_func=torch.isnan,
            message="found NaN in local forward loss calculation",
            tolerance=0.0,  # forward pass calculations are deterministic
            fatal=True,
        )
        rerun_state_machine.validate_result(
            result=loss,
            rejection_func=torch.isinf,
            message="found Inf in local forward loss calculation",
            tolerance=0.0,  # forward pass calculations are deterministic
            fatal=True,
        )
    # Check for spiky loss.
    if args.check_for_spiky_loss:
        rerun_state_machine.validate_result(
            result=loss,
            rejection_func=partial(
                rerun_state_machine.is_unexpectedly_large,
                threshold=SPIKY_LOSS_FACTOR,
                context="loss",
            ),
            message="Spiky loss",
            tolerance=0.0,  # forward pass calculations are deterministic
            fatal=False,
        )

    return loss, num_tokens, report


def forward_step(data_iterator, model: GPTModel, return_schedule_plan: bool = False):
    """Forward training step.

    Args:
        data_iterator: Input data iterator
        model (GPTModel): The GPT Model
        return_schedule_plan (bool): Whether to return the schedule plan instead of the output tensor
    """
    args = get_args()
    timers = get_timers()

    # Get the batch.
    timers('batch-generator', log_level=2).start()
    global stimer
    with stimer(bdata=True):
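        # vp_stage identifies this model chunk's virtual-pipeline stage;
        # get_batch uses it to decide whether this rank needs data at all.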
        vp_stage = get_attr_wrapped_model(model, "vp_stage")
        tokens, labels, loss_mask, attention_mask, position_ids, packed_seq_params = get_batch(
            data_iterator, vp_stage
        )
    timers('batch-generator').stop()

    with stimer:
        if args.use_legacy_models:
            output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
        else:
            if return_schedule_plan:
                assert args.overlap_moe_expert_parallel_comm, (
                    "overlap_moe_expert_parallel_comm must be enabled to return the schedule plan"
                )
                schedule_plan = model.build_schedule_plan(
                    tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask
                )
                return schedule_plan, partial(loss_func, loss_mask, model=model)
            else:
                output_tensor = model(
                    tokens,
                    position_ids,
                    attention_mask,
                    labels=labels,
                    loss_mask=loss_mask,
                    packed_seq_params=packed_seq_params,
                )

    # [ModelOpt]: model is needed to access ModelOpt distillation losses.
    return output_tensor, partial(loss_func, loss_mask, model=model)


def is_dataset_built_on_rank(vp_stage=None):
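    """Whether this rank builds datasets: first/last pipeline stage (or an MTP rank) with TP rank 0."""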
    args = get_args()
    config = core_transformer_config_from_args(args)
    return (
        is_first_or_last_pipeline_stage(vp_stage)
        or mtp_on_this_rank(config, ignore_virtual=False, vp_stage=vp_stage)
    ) and parallel_state.get_tensor_model_parallel_rank() == 0


def core_gpt_dataset_config_from_args(args):
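    """Build the GPTDatasetConfig (or GPTFIMDatasetConfig when args.fim_data is set) from args."""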
    tokenizer = build_tokenizer(args)

    # Sometimes --data-path is too long, so instead we parse it from a file.
    blend: Optional[Tuple[List[str], Optional[List[float]]]]
    blend_per_split: Optional[List[Optional[Tuple[List[str], Optional[List[float]]]]]]
    blend, blend_per_split = get_blend_and_blend_per_split(args)

    sequences_per_dataset = None
    if args.per_dataset_sequences_path is not None:
        with open(args.per_dataset_sequences_path, "r") as f:
            sequences_per_dataset = json.load(f)

    data_args = {
        "random_seed": args.seed,
        "sequence_length": args.seq_length,
        "blend": blend,
        "blend_per_split": blend_per_split,
        "split": args.split,
        "multiple_validation_sets": args.multiple_validation_sets,
        "full_validation": args.full_validation,
        "num_dataset_builder_threads": args.num_dataset_builder_threads,
        "path_to_cache": args.data_cache_path,
        "mmap_bin_files": args.mmap_bin_files,
        "tokenizer": tokenizer,
        "reset_position_ids": args.reset_position_ids,
        "reset_attention_mask": args.reset_attention_mask,
        "eod_mask_loss": args.eod_mask_loss,
        "create_attention_mask": args.create_attention_mask_in_dataloader,
        "object_storage_cache_path": args.object_storage_cache_path,
        "mid_level_dataset_surplus": args.mid_level_dataset_surplus,
        "allow_ambiguous_pad_tokens": args.allow_ambiguous_pad_tokens,
        "fast_cache_load": args.dataloader_fast_cache_load,
        "sequences_per_dataset": sequences_per_dataset,
        "defer_npy_index_mmap": args.dataloader_defer_npy_index_mmap,
        "context_parallel_size": args.context_parallel_size,
        "data_parallel_size": args.data_parallel_size,
"sequence_parallel_size": args.tensor_model_parallel_size*args.sequence_parallel,
"hybrid_context_parallel": args.hybrid_context_parallel,
}
# add FIM args to the config
if args.fim_data:
extra_tokens = {
"prefix": args.fim_prefix_token,
"middle": args.fim_middle_token,
"suffix": args.fim_suffix_token,
"pad": args.fim_pad_token,
"eod": args.fim_eod_token,
}
data_args.update(
{
"fim_rate": args.fim_rate,
"fim_spm_rate": args.fim_spm_rate,
"fim_extra_tokens": extra_tokens,
"fim_split_sample": args.fim_split_sample,
"fim_fragment_rate": args.fim_fragment_rate,
"fim_no_prefix": args.fim_no_prefix,
}
)
return GPTFIMDatasetConfig(**data_args)
return GPTDatasetConfig(**data_args)


def train_valid_test_datasets_provider(train_val_test_num_samples, vp_stage=None):
    """Build the train, validation, and test datasets.

    Args:
        train_val_test_num_samples: A list containing the number of samples in the
            train, validation, and test sets.
        vp_stage: Virtual pipeline stage of the model chunk requesting the datasets.
    """
    args = get_args()
    config = core_gpt_dataset_config_from_args(args)

    if args.sft:
        dataset_type = SFTDataset
    elif args.mock_data:
        dataset_type = MockGPTDataset
    elif args.fim_data:
        dataset_type = GPTFIMDataset
    else:
        dataset_type = GPTDataset

    print_rank_0("> building train, validation, and test datasets for GPT ...")

    is_dataset_built = partial(is_dataset_built_on_rank, vp_stage=vp_stage)
    train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder(
        dataset_type, train_val_test_num_samples, is_dataset_built, config
    ).build()

    print_rank_0("> finished creating GPT datasets ...")

    return train_ds, valid_ds, test_ds


def get_embedding_ranks(pp_ranks: List[int]):
    """Get the embedding ranks."""
    embedding_ranks = [pp_ranks[0]]
    if len(pp_ranks) > 1:
        args = get_args()
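        # With tied embedding and output weights, the last pipeline stage also
        # holds a copy of the embedding, so it joins the embedding group.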
        if not args.untie_embeddings_and_output_weights:
            embedding_ranks.append(pp_ranks[-1])
        config = core_transformer_config_from_args(args)
        mtp_ranks = get_mtp_ranks(pp_ranks, config)
        embedding_ranks.extend(mtp_ranks)
    embedding_ranks = sorted(set(embedding_ranks))
    return embedding_ranks


if __name__ == "__main__":
    # Timestamp right after entering the __main__ block (after all imports/library setup).
    _MAIN_ENTRY_TIME = time.time()
    # Register startup timestamps for the timing report in pretrain().
    set_startup_timestamps(program_start=_PROGRAM_START_TIME, main_entry=_MAIN_ENTRY_TIME)

    # Temporary for transition to core datasets.
    train_valid_test_datasets_provider.is_distributed = True

    # Optionally enable in-process restart on pretrain.
    pretrain, store = inprocess_restart.maybe_wrap_for_inprocess_restart(pretrain)

    pretrain(
        train_valid_test_datasets_provider,
        partial(model_provider, gpt_builder),
        ModelType.encoder_or_decoder,
        forward_step,
        args_defaults={'tokenizer_type': 'GPT2BPETokenizer'},
        extra_args_provider=add_modelopt_args if has_nvidia_modelopt else None,
        store=store,
        get_embedding_ranks=get_embedding_ranks,
    )