Commit 181e11e

[megatron] support megatron num_train_epochs (#4432)

1 parent 5bf6d1b commit 181e11e

8 files changed: +116 -88 lines

docs/source/Instruction/Megatron-SWIFT训练.md

Lines changed: 1 addition & 0 deletions

@@ -300,3 +300,4 @@ Megatron training parameters inherit from the Megatron parameters and the basic parameters. For the contents of the basic parameters
 - 🔥streaming: Stream reading and processing of the dataset, default is False. It is typically set to True when handling large datasets. For more streaming parameters, refer to the command-line parameters documentation.
 - lazy_tokenize: Default is False. If this parameter is set to False, all dataset samples are tokenized before training (this avoids errors during training); if set to True, tokenization occurs during training (this saves memory).
 - max_epochs: Forces training to exit when `max_epochs` is reached, and validates and saves the model weights. This parameter is especially useful with streaming datasets. Default is None.
+  - Note: If you use a non-streaming dataset, `train_iters` is computed automatically from this parameter, so there is no need to pass `train_iters` manually.

docs/source_en/Instruction/Megatron-SWIFT-Training.md

Lines changed: 1 addition & 0 deletions

@@ -311,3 +311,4 @@ Megatron training parameters inherit from Megatron parameters and basic parameters
 - 🔥streaming: Stream reading and processing of the dataset, default is False. It is typically set to True when handling large datasets. For more information on streaming parameters, refer to the command-line parameters documentation.
 - lazy_tokenize: Default is False. If this parameter is set to False, all dataset samples are tokenized before training (this avoids errors during training); if set to True, tokenization occurs during training (this saves memory).
 - max_epochs: Forces the training to exit after reaching `max_epochs`, and performs validation and saving of the model weights. This parameter is especially useful when using a streaming dataset. Default is None.
+  - Note: If you use a non-streaming dataset, `train_iters` is computed automatically from this parameter, so there is no need to pass `train_iters` manually.
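For reference, the automatic computation described above (implemented in `_get_train_iters` in swift/megatron/train/sft.py later in this diff) amounts to the following minimal sketch. The helper name and the concrete numbers here are illustrative only, not part of the commit:

def estimate_train_iters(dataset_len, max_epochs, micro_batch_size, data_parallel_size, global_batch_size):
    # Drop the tail that does not fill a full micro-batch step across data-parallel ranks,
    # then convert the total number of samples over all epochs into optimizer steps.
    step_batch_size = micro_batch_size * data_parallel_size
    dataset_sample = dataset_len // step_batch_size * step_batch_size
    return dataset_sample * max_epochs // global_batch_size + 1

# Example (hypothetical values): 10_000 samples, 3 epochs, micro_batch_size=4,
# data_parallel_size=8, global_batch_size=64 -> 9984 * 3 // 64 + 1 = 469 iterations.
print(estimate_train_iters(10_000, 3, 4, 8, 64))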

swift/cli/_megatron/main.py

Lines changed: 1 addition & 1 deletion

@@ -7,8 +7,8 @@
 logger = get_logger()

 ROUTE_MAPPING: Dict[str, str] = {
-    'sft': 'swift.cli._megatron.sft',
     'pt': 'swift.cli._megatron.pt',
+    'sft': 'swift.cli._megatron.sft',
 }
swift/llm/model/patcher.py

Lines changed: 2 additions & 1 deletion

@@ -350,7 +350,8 @@ def new_get_cached_module_file(pretrained_model_name_or_path, *args, **kwargs):

 @contextmanager
 def patch_tp_plan(load_model: bool):
-    if not load_model or not is_mp_ddp() or version.parse(transformers.__version__) < version.parse('4.50'):
+    if not load_model or not is_mp_ddp() or version.parse(
+            transformers.__version__) < version.parse('4.50') or 'WORLD_SIZE' not in os.environ:
         yield
         return
     WORLD_SIZE = os.environ.get('WORLD_SIZE')

swift/megatron/init.py

Lines changed: 0 additions & 46 deletions

@@ -1,7 +1,6 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import os
 import sys
-from contextlib import contextmanager

 from swift.llm import git_clone_github
 from swift.utils import get_logger, is_megatron_available, safe_ddp_context, subprocess_run
@@ -30,50 +29,6 @@ def _patch_transformer_engine():
         pass


-def new_cyclic_iter(iter):
-    from megatron.training import get_args
-    args = get_args()
-    max_epochs = args.max_epochs
-    i = 0
-    while True:
-        if getattr(args, 'is_training', False):
-            if max_epochs and i >= max_epochs:
-                logger.info(f'Training of {i} epochs has been completed, the training has finished.')
-                break
-            logger.info(f'The training of Epoch {i} starts...')
-        for x in iter:
-            yield x
-        i += 1
-
-
-@contextmanager
-def _training_context():
-    from megatron.training import get_args
-    args = get_args()
-    args.is_training = True
-    try:
-        yield
-    finally:
-        args.is_training = False
-
-
-def _patch_max_epochs():
-    # support max_epochs
-    from megatron.training import training
-    train_step_origin = training.train_step
-
-    def train_step(*args, **kwargs):
-        with _training_context():
-            try:
-                return train_step_origin(*args, **kwargs)
-            except StopIteration:
-                return {}, True, True, True, 0, None, None
-
-    training.train_step = train_step
-
-    training.cyclic_iter = new_cyclic_iter
-
-
 def _patch__batched_p2p_ops():
     from megatron.core.pipeline_parallel import p2p_communication

@@ -88,7 +43,6 @@ def _batched_p2p_ops(**kwargs):

 def _patch_megatron():
     _patch_transformer_engine()
-    _patch_max_epochs()
     _patch__batched_p2p_ops()

swift/megatron/train/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -1,2 +1,3 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
 from .pt import megatron_pt_main
 from .sft import megatron_sft_main

swift/megatron/train/sft.py

Lines changed: 94 additions & 4 deletions

@@ -1,19 +1,25 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import os
+from contextlib import contextmanager
+from functools import partial
 from typing import List, Union

+from megatron.core import mpu
 from megatron.core.enums import ModelType
-from megatron.training import pretrain
+from megatron.core.utils import StragglerDetector
+from megatron.training import get_args, get_timers, pretrain, training

 from swift.llm.train import SwiftSft
 from swift.utils import get_logger, is_master, plot_images
 from ..argument import MegatronTrainArguments
 from ..utils import patch_megatron_tokenizer
 from .patcher import patch_megatron_data_collator
-from .utils import build_streaming_dataloader, forward_step, get_swift_datasets_provider
+from .utils import build_streaming_dataloader, get_batch, get_swift_datasets_provider

 logger = get_logger()

+stimer = StragglerDetector()
+

 class MegatronSft(SwiftSft):
     args_class = MegatronTrainArguments
@@ -30,8 +36,92 @@ def __init__(self, args: Union[List[str], MegatronTrainArguments, None] = None)
         self.template.use_megatron = True
         args.save_args(args.save)

+    @contextmanager
+    def _get_train_iters(self, train_dataset):
+        from megatron.training import training
+        origin_initialize_megatron = training.initialize_megatron
+
+        def initialize_megatron(*_args, **kwargs):
+            res = origin_initialize_megatron(*_args, **kwargs)
+            args = get_args()
+            if args.train_iters is None and hasattr(train_dataset, '__len__'):
+                data_parallel_size = mpu.get_data_parallel_world_size()
+                step_batch_size = \
+                    args.micro_batch_size * data_parallel_size
+                dataset_sample = len(train_dataset) // step_batch_size * step_batch_size
+                args.train_iters = (dataset_sample * args.max_epochs // args.global_batch_size) + 1
+            return res
+
+        training.initialize_megatron = initialize_megatron
+        try:
+            yield
+        finally:
+            training.initialize_megatron = origin_initialize_megatron
+
+    @staticmethod
+    def new_cyclic_iter(iter):
+        args = get_args()
+        max_epochs = args.max_epochs
+        i = 0
+        while True:
+            if getattr(args, 'is_training', False):
+                if max_epochs and i >= max_epochs:
+                    logger.info(f'Training of {i} epochs has been completed, the training has finished.')
+                    break
+                logger.info(f'The training of Epoch {i} starts...')
+            for x in iter:
+                yield x
+            i += 1
+
+    @staticmethod
+    @contextmanager
+    def _training_context():
+        args = get_args()
+        args.is_training = True
+        try:
+            yield
+        finally:
+            args.is_training = False
+
+    def train_step(self, forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config):
+        return self._train_step_origin(forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config)
+
+    def _patch_train_step(self):
+        # support max_epochs
+        def train_step(*args, **kwargs):
+            with self._training_context():
+                try:
+                    return self.train_step(*args, **kwargs)
+                except StopIteration:
+                    return {}, True, True, True, 0, None, None
+
+        self._train_step_origin = training.train_step
+        training.train_step = train_step
+        training.cyclic_iter = MegatronSft.new_cyclic_iter
+
+    def forward_step(self, data_iterator, model):
+        from pretrain_gpt import loss_func
+
+        timers = get_timers()
+
+        # Get the batch.
+        timers('batch-generator', log_level=2).start()
+        global stimer
+        with stimer(bdata=True):
+            data = get_batch(data_iterator)
+            if not data:
+                raise StopIteration
+        timers('batch-generator').stop()
+
+        with stimer:
+            output_tensor = model(**data)
+        labels = data.get('labels')
+        loss_mask = None if labels is None else (labels != -100).float()
+        return output_tensor, partial(loss_func, loss_mask)
+
     def run(self):
         args = self.args
+        self._patch_train_step()

         train_dataset, val_dataset = self._get_dataset()
         train_dataset, val_dataset = self._encode_dataset(train_dataset, val_dataset)
@@ -46,13 +136,13 @@ def run(self):
         logging_path = os.path.join(args.save, 'logging.jsonl')
         logger.info(f'The logging file will be saved in: {logging_path}')
         try:
-            with patch_megatron_data_collator(data_collator):
+            with patch_megatron_data_collator(data_collator), self._get_train_iters(train_dataset):
                 extra_args_provider = args.megatron_model_meta.extra_args_provider
                 pretrain(
                     datasets_provider,
                     args.megatron_model_meta.model_provider,
                     ModelType.encoder_or_decoder,
-                    forward_step,
+                    self.forward_step,
                     extra_args_provider=extra_args_provider,
                     args_defaults=args.extra_args)
         finally:
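The epoch limit above works by pairing the patched cyclic iterator with the StopIteration guard around train_step: once `max_epochs` passes over the finite dataloader have completed, the iterator stops yielding, and the wrapper converts the resulting StopIteration into Megatron's "training finished" return value. A minimal standalone sketch of the same pattern, with illustrative names that are not part of the diff:

class _Args:
    max_epochs = 2
    is_training = True

def cyclic_with_epoch_limit(dataloader, args):
    # Re-yield the finite dataloader until max_epochs full passes are done.
    epoch = 0
    while True:
        if args.is_training and args.max_epochs and epoch >= args.max_epochs:
            break
        for batch in dataloader:
            yield batch
        epoch += 1

def guarded_train_step(step_fn, data_iter):
    # Translate iterator exhaustion into an explicit "stop training" signal.
    try:
        return step_fn(next(data_iter)), False
    except StopIteration:
        return None, True

data_iter = cyclic_with_epoch_limit([1, 2, 3], _Args())
done = False
while not done:
    _, done = guarded_train_step(lambda b: b * 10, data_iter)
# Two epochs over [1, 2, 3] -> 6 successful steps, then a stop signal.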

swift/megatron/train/utils.py

Lines changed: 16 additions & 36 deletions

@@ -1,18 +1,14 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-from functools import partial
 from typing import Any, Dict, Optional

 import torch
 from megatron.core import mpu
 from megatron.core.packed_seq_params import PackedSeqParams
-from megatron.core.utils import StragglerDetector
 from megatron.training import get_args, get_timers
 from megatron.training.training import cyclic_iter

 from swift.llm import DataLoaderDispatcher

-stimer = StragglerDetector()
-

 def get_swift_datasets_provider(train_dataset, val_dataset):

@@ -67,10 +63,10 @@ def _broadcast(item):
         except StopIteration:
             seq_length = -1
         else:
-            tokens = data['input_ids']
-            seq_length = tokens.shape[1]
+            input_ids = data['input_ids']
+            seq_length = input_ids.shape[1]
         batch = {
-            'tokens': tokens.cuda(non_blocking=True),
+            'input_ids': input_ids.cuda(non_blocking=True),
             'labels': data['labels'].cuda(non_blocking=True),
             'attention_mask':
             None if 'attention_mask' not in data else data['attention_mask'].cuda(non_blocking=True),
@@ -81,13 +77,13 @@ def _broadcast(item):
         if seq_length.item() == -1:
             return {}
         if args.pipeline_model_parallel_size == 1:
-            _broadcast(batch['tokens'])
+            _broadcast(batch['input_ids'])
             _broadcast(batch['labels'])
             _broadcast(batch['attention_mask'])
             _broadcast(batch['position_ids'])

         elif mpu.is_pipeline_first_stage():
-            _broadcast(batch['tokens'])
+            _broadcast(batch['input_ids'])
             _broadcast(batch['attention_mask'])
             _broadcast(batch['position_ids'])

@@ -102,7 +98,7 @@ def _broadcast(item):
         if seq_length.item() == -1:
             return {}
         micro_batch_size = 1  # use qkv_format 'thd'
-        tokens = torch.empty((micro_batch_size, seq_length), dtype=torch.int64, device=torch.cuda.current_device())
+        input_ids = torch.empty((micro_batch_size, seq_length), dtype=torch.int64, device=torch.cuda.current_device())
         labels = torch.empty((micro_batch_size, seq_length), dtype=torch.int64, device=torch.cuda.current_device())
         if args.create_attention_mask_in_dataloader:
             attention_mask = torch.empty((micro_batch_size, 1, seq_length, seq_length),
@@ -115,26 +111,31 @@ def _broadcast(item):
                                      device=torch.cuda.current_device())

         if args.pipeline_model_parallel_size == 1:
-            _broadcast(tokens)
+            _broadcast(input_ids)
             _broadcast(labels)
             _broadcast(attention_mask)
             _broadcast(position_ids)

         elif mpu.is_pipeline_first_stage():
             labels = None

-            _broadcast(tokens)
+            _broadcast(input_ids)
             _broadcast(attention_mask)
             _broadcast(position_ids)

         elif mpu.is_pipeline_last_stage():
-            tokens = None
+            input_ids = None

             _broadcast(labels)
             _broadcast(attention_mask)
             _broadcast(position_ids)  # compat packing & cp

-        batch = {'tokens': tokens, 'labels': labels, 'attention_mask': attention_mask, 'position_ids': position_ids}
+        batch = {
+            'input_ids': input_ids,
+            'labels': labels,
+            'attention_mask': attention_mask,
+            'position_ids': position_ids
+        }

     return batch

@@ -213,25 +214,4 @@ def get_batch(data_iterator):
         batch['packed_seq_params'] = get_packed_seq_params(batch['position_ids'])
     # slice batch along sequence dimension for context parallelism
     batch = get_batch_on_this_cp_rank(batch)
-    return batch.values()
-
-
-def forward_step(data_iterator, model):
-    from pretrain_gpt import loss_func
-
-    timers = get_timers()
-
-    # Get the batch.
-    timers('batch-generator', log_level=2).start()
-    global stimer
-    with stimer(bdata=True):
-        data = get_batch(data_iterator)
-        if not data:
-            raise StopIteration
-        tokens, labels, attention_mask, position_ids, packed_seq_params = data
-    timers('batch-generator').stop()
-
-    with stimer:
-        output_tensor = model(tokens, position_ids, attention_mask, labels=labels, packed_seq_params=packed_seq_params)
-    loss_mask = None if labels is None else (labels != -100).float()
-    return output_tensor, partial(loss_func, loss_mask)
+    return batch
