
Commit e85fd56

[bugfix] fix megatron pp4 max_epochs (#5432)

1 parent 8d89d50 commit e85fd56

5 files changed: +17 -12 lines changed

docs/source/Instruction/支持的模型和数据集.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -207,7 +207,7 @@
 |[swift/Qwen3-32B-AWQ](https://modelscope.cn/models/swift/Qwen3-32B-AWQ)|qwen3|qwen3|transformers>=4.51|✘|-|-|
 |[Qwen/Qwen3-4B-Instruct-2507](https://modelscope.cn/models/Qwen/Qwen3-4B-Instruct-2507)|qwen3|qwen3|transformers>=4.51|✔|-|[Qwen/Qwen3-4B-Instruct-2507](https://huggingface.co/Qwen/Qwen3-4B-Instruct-2507)|
 |[Qwen/Qwen3-4B-Instruct-2507-FP8](https://modelscope.cn/models/Qwen/Qwen3-4B-Instruct-2507-FP8)|qwen3|qwen3|transformers>=4.51|✘|-|[Qwen/Qwen3-4B-Instruct-2507-FP8](https://huggingface.co/Qwen/Qwen3-4B-Instruct-2507-FP8)|
-|[Qwen/Qwen3-4B-Thinking-2507](https://modelscope.cn/models/Qwen/Qwen3-4B-Thinking-2507)|qwen3_thinking|qwen3_thinking|transformers>=4.51|✘|-|[Qwen/Qwen3-4B-Thinking-2507](https://huggingface.co/Qwen/Qwen3-4B-Thinking-2507)|
+|[Qwen/Qwen3-4B-Thinking-2507](https://modelscope.cn/models/Qwen/Qwen3-4B-Thinking-2507)|qwen3_thinking|qwen3_thinking|transformers>=4.51|✔|-|[Qwen/Qwen3-4B-Thinking-2507](https://huggingface.co/Qwen/Qwen3-4B-Thinking-2507)|
 |[Qwen/Qwen3-4B-Thinking-2507-FP8](https://modelscope.cn/models/Qwen/Qwen3-4B-Thinking-2507-FP8)|qwen3_thinking|qwen3_thinking|transformers>=4.51|✘|-|[Qwen/Qwen3-4B-Thinking-2507-FP8](https://huggingface.co/Qwen/Qwen3-4B-Thinking-2507-FP8)|
 |[Qwen/Qwen3-30B-A3B-Base](https://modelscope.cn/models/Qwen/Qwen3-30B-A3B-Base)|qwen3_moe|qwen3|transformers>=4.51|✔|-|[Qwen/Qwen3-30B-A3B-Base](https://huggingface.co/Qwen/Qwen3-30B-A3B-Base)|
 |[Qwen/Qwen3-30B-A3B](https://modelscope.cn/models/Qwen/Qwen3-30B-A3B)|qwen3_moe|qwen3|transformers>=4.51|✔|-|[Qwen/Qwen3-30B-A3B](https://huggingface.co/Qwen/Qwen3-30B-A3B)|
```

docs/source_en/Instruction/Supported-models-and-datasets.md

Lines changed: 1 addition & 1 deletion

```diff
@@ -207,7 +207,7 @@ The table below introduces the models integrated with ms-swift:
 |[swift/Qwen3-32B-AWQ](https://modelscope.cn/models/swift/Qwen3-32B-AWQ)|qwen3|qwen3|transformers>=4.51|✘|-|-|
 |[Qwen/Qwen3-4B-Instruct-2507](https://modelscope.cn/models/Qwen/Qwen3-4B-Instruct-2507)|qwen3|qwen3|transformers>=4.51|✔|-|[Qwen/Qwen3-4B-Instruct-2507](https://huggingface.co/Qwen/Qwen3-4B-Instruct-2507)|
 |[Qwen/Qwen3-4B-Instruct-2507-FP8](https://modelscope.cn/models/Qwen/Qwen3-4B-Instruct-2507-FP8)|qwen3|qwen3|transformers>=4.51|✘|-|[Qwen/Qwen3-4B-Instruct-2507-FP8](https://huggingface.co/Qwen/Qwen3-4B-Instruct-2507-FP8)|
-|[Qwen/Qwen3-4B-Thinking-2507](https://modelscope.cn/models/Qwen/Qwen3-4B-Thinking-2507)|qwen3_thinking|qwen3_thinking|transformers>=4.51|✘|-|[Qwen/Qwen3-4B-Thinking-2507](https://huggingface.co/Qwen/Qwen3-4B-Thinking-2507)|
+|[Qwen/Qwen3-4B-Thinking-2507](https://modelscope.cn/models/Qwen/Qwen3-4B-Thinking-2507)|qwen3_thinking|qwen3_thinking|transformers>=4.51|✔|-|[Qwen/Qwen3-4B-Thinking-2507](https://huggingface.co/Qwen/Qwen3-4B-Thinking-2507)|
 |[Qwen/Qwen3-4B-Thinking-2507-FP8](https://modelscope.cn/models/Qwen/Qwen3-4B-Thinking-2507-FP8)|qwen3_thinking|qwen3_thinking|transformers>=4.51|✘|-|[Qwen/Qwen3-4B-Thinking-2507-FP8](https://huggingface.co/Qwen/Qwen3-4B-Thinking-2507-FP8)|
 |[Qwen/Qwen3-30B-A3B-Base](https://modelscope.cn/models/Qwen/Qwen3-30B-A3B-Base)|qwen3_moe|qwen3|transformers>=4.51|✔|-|[Qwen/Qwen3-30B-A3B-Base](https://huggingface.co/Qwen/Qwen3-30B-A3B-Base)|
 |[Qwen/Qwen3-30B-A3B](https://modelscope.cn/models/Qwen/Qwen3-30B-A3B)|qwen3_moe|qwen3|transformers>=4.51|✔|-|[Qwen/Qwen3-30B-A3B](https://huggingface.co/Qwen/Qwen3-30B-A3B)|
```

swift/megatron/argument/megatron_args.py

Lines changed: 4 additions & 1 deletion

```diff
@@ -357,7 +357,10 @@ def create_group(ranks=None, timeout=None, *args, **kwargs):
     def __post_init__(self):
         require_version('numpy<2.0', 'Please install numpy<2.0 by running: `pip install "numpy<2.0"`.')
         if self.train_type == 'lora':
-            require_version('peft>=0.12')
+            if self.num_experts is not None:
+                require_version('peft>=0.15')
+            else:
+                require_version('peft>=0.12')
         MegatronTunerMixin.__post_init__(self)
         os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] = '1'
         self._set_default()
```
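The gate above means MoE LoRA training (when `num_experts` is set) now requires a newer peft release. A minimal standalone sketch of the same check, assuming `require_version` is the helper from `transformers.utils.versions`, which raises `ImportError` when the installed version misses the specifier; `check_lora_requirements` is a hypothetical wrapper, not ms-swift code:

```python
# Minimal sketch of the peft version gate (assumes transformers'
# require_version helper; check_lora_requirements is hypothetical).
from transformers.utils.versions import require_version


def check_lora_requirements(train_type: str, num_experts=None) -> None:
    if train_type == 'lora':
        if num_experts is not None:
            # MoE + LoRA needs the newer peft release.
            require_version('peft>=0.15')
        else:
            require_version('peft>=0.12')


check_lora_requirements('lora', num_experts=8)  # raises if peft < 0.15 is installed
```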

swift/megatron/trainers/base.py

Lines changed: 5 additions & 3 deletions

```diff
@@ -74,17 +74,18 @@ def initialize_megatron(*_args, **kwargs):
 def new_cyclic_iter(iterable):
     args = get_args()
     i = 0
+    n_batch = 0
     while True:
         is_training = getattr(args, 'is_training', False)
         if is_training:
             logger.info(f'The training of Epoch {i} starts...')
         if is_training and args.max_epochs and i >= args.max_epochs - 1:
             it = iter(iterable)
-            num_batches = args.global_batch_size // (args.micro_batch_size * args.data_parallel_size)
-            x = [next(it) for _ in range(num_batches)]
+            num_microbatches = args.global_batch_size // (args.micro_batch_size * args.data_parallel_size)
+            x = [next(it) for _ in range(num_microbatches - n_batch % num_microbatches)]
             while True:
                 try:
-                    next_x = [next(it) for _ in range(num_batches)]
+                    next_x = [next(it) for _ in range(num_microbatches)]
                 except StopIteration:
                     break
                 yield from x
@@ -94,6 +95,7 @@ def new_cyclic_iter(iterable):
             yield from x
         else:
             for x in iterable:
+                n_batch += 1
                 yield x
         i += 1
```
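This is the core of the pp4 max_epochs fix: the final-epoch branch hands out data in whole global batches (`num_microbatches` micro-batches per fetch), but the old code ignored how many micro-batches earlier epochs had already yielded, so the last epoch could start mid-global-batch. The new `n_batch` counter trims the first fetch so the stream realigns to a global-batch boundary. A toy sketch of the arithmetic, with made-up sizes rather than ms-swift defaults:

```python
# Toy illustration of the realignment arithmetic (hypothetical sizes,
# not ms-swift code).
num_microbatches = 4   # micro-batches consumed per global batch on this rank
n_batch = 10           # micro-batches yielded by the earlier, full epochs

# 10 % 4 == 2: the stream sits two micro-batches into a global batch,
# so the last epoch's first fetch takes only 4 - 2 == 2 items to land
# back on a global-batch boundary; every later fetch takes a full 4.
first_fetch = num_microbatches - n_batch % num_microbatches
assert first_fetch == 2
```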

swift/megatron/trainers/utils.py

Lines changed: 6 additions & 6 deletions

```diff
@@ -73,6 +73,9 @@ def _broadcast(item):
             _broadcast(batch['attention_mask'])
             _broadcast(batch['position_ids'])
             _broadcast(batch['loss_scale'])
+        else:
+            for key in ('input_ids', 'labels', 'attention_mask', 'position_ids', 'loss_scale'):
+                batch[key] = None

     else:
         flags = torch.empty((3), dtype=torch.int64, device=torch.cuda.current_device())
@@ -117,6 +120,8 @@ def _broadcast(item):
             _broadcast(attention_mask)
             _broadcast(position_ids)  # compat packing & cp
             _broadcast(loss_scale)
+        else:
+            input_ids, labels, attention_mask, position_ids, loss_scale = (None, ) * 5

         batch = {
             'input_ids': input_ids,
@@ -187,15 +192,10 @@ def get_batch_on_this_cp_rank(batch: Dict[str, Any]):

 def get_batch(data_iterator):
     """Generate a batch."""
-
-    # TODO: this is pretty hacky, find a better way
-    if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
-        return {key: None for key in ['input_ids', 'attention_mask', 'position_ids', 'loss_scale']}
-
     # get batches based on the TP rank you are on
     batch = get_batch_on_this_tp_rank(data_iterator)
     args = get_args()
-    if args.padding_free:
+    if args.padding_free and batch.get('position_ids') is not None:
         batch['packed_seq_params'] = get_packed_seq_params(batch['position_ids'])
     # slice batch along sequence dimension for context parallelism
     batch = get_batch_on_this_cp_rank(batch)
```
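Previously `get_batch` returned early with an all-None dict on intermediate pipeline stages (the "pretty hacky" TODO above), skipping `get_batch_on_this_tp_rank` and `get_batch_on_this_cp_rank` entirely and omitting the `labels` key. The None-filling now lives inside `get_batch_on_this_tp_rank`'s stage checks, and the `padding_free` path is guarded because `position_ids` is legitimately None on middle stages. A small sketch of that guard, with a placeholder `get_packed_seq_params` standing in for the real helper:

```python
# Sketch of the padding_free guard (get_packed_seq_params here is a
# placeholder, not the real ms-swift implementation).
from typing import Any, Dict, Optional


def get_packed_seq_params(position_ids) -> Dict[str, Any]:
    return {'cu_seqlens': position_ids}  # stand-in payload


def maybe_pack(batch: Dict[str, Optional[Any]], padding_free: bool) -> Dict[str, Any]:
    # On intermediate pipeline stages position_ids is None, so skip packing.
    if padding_free and batch.get('position_ids') is not None:
        batch['packed_seq_params'] = get_packed_seq_params(batch['position_ids'])
    return batch


print(maybe_pack({'position_ids': None}, padding_free=True))        # unchanged
print(maybe_pack({'position_ids': [0, 1, 2]}, padding_free=True))   # packed
```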
