
Commit 0c2bb7d

fix grpo resume_from_checkpoint (#4035)
1 parent 8d4a925 commit 0c2bb7d

File tree

3 files changed: +9 -6 lines changed


docs/source/Instruction/Megatron-SWIFT训练.md

Lines changed: 2 additions & 1 deletion

```diff
@@ -1,7 +1,7 @@
 
 # Megatron-SWIFT Training
 
-SWIFT incorporates Megatron's parallelization techniques to accelerate the training of large models, including data parallelism, tensor parallelism, pipeline parallelism, sequence parallelism, context parallelism, and expert parallelism. It supports the pre-training and fine-tuning of models such as Qwen3, Qwen3-MoE, Llama3, and the Deepseek-R1 distillation series. For the complete list of supported models, refer to the [Supported Models and Datasets documentation](./支持的模型和数据集.md).
+SWIFT incorporates Megatron's parallelization techniques to accelerate the training of large models, including data parallelism, tensor parallelism, pipeline parallelism, sequence parallelism, context parallelism, and expert parallelism. It supports the pre-training and fine-tuning of models such as Qwen3, [Qwen3-MoE](https://github.com/modelscope/ms-swift/blob/main/examples/train/megatron/qwen3_moe.sh), Qwen2.5, Llama3, and the Deepseek-R1 distillation series. For the complete list of supported models, refer to the [Supported Models and Datasets documentation](./支持的模型和数据集.md).
 
 ## Environment Setup
 To use Megatron-SWIFT, in addition to installing the swift dependencies, you also need to install the following:
@@ -174,6 +174,7 @@ I am a language model developed by swift, you can call me swift-robot. How can I
 
 **Checkpoint parameters**:
 - 🔥save: Output directory for checkpoints, default is None. During training, if this parameter is not set, it defaults to `f'megatron_output/{model_suffix}'`, e.g., `'megatron_output/Qwen2.5-7B-Instruct'`.
+- Note: When training on multiple machines, ensure that the save paths on each node point to the same location. Otherwise, you will need to manually consolidate these weights after training.
 - 🔥save_interval: Checkpoint saving interval (steps), default is 500.
 - Note: Weights are always saved at the end of training.
 - 🔥no_save_optim: Do not save the optimizer, default is False.
```

docs/source_en/Instruction/Megatron-SWIFT-Training.md

Lines changed: 2 additions & 1 deletion

```diff
@@ -1,7 +1,7 @@
 
 # Megatron-SWIFT Training
 
-SWIFT incorporates Megatron's parallelization techniques to accelerate the training of large models, including data parallelism, tensor parallelism, pipeline parallelism, sequence parallelism, context parallelism, and expert parallelism. It supports the pre-training and fine-tuning of models such as Qwen3, Qwen3-MoE, Llama3, and the Deepseek-R1 distillation series. For a complete list of supported models, please refer to the [Supported Models and Datasets documentation](./Supported-models-and-datasets.md).
+SWIFT incorporates Megatron's parallelization techniques to accelerate the training of large models, including data parallelism, tensor parallelism, pipeline parallelism, sequence parallelism, context parallelism, and expert parallelism. It supports the pre-training and fine-tuning of models such as Qwen3, [Qwen3-MoE](https://github.com/modelscope/ms-swift/blob/main/examples/train/megatron/qwen3_moe.sh), Qwen2.5, Llama3, and the Deepseek-R1 distillation series. For a complete list of supported models, please refer to the [Supported Models and Datasets documentation](./Supported-models-and-datasets.md).
 
 ## Environment Setup
 
@@ -181,6 +181,7 @@ seq_length: Defaults to None, meaning it is set to `max_length`. To restrict the
 **Checkpoint Parameters**:
 
 - 🔥save: Output directory for checkpoints, default is None. During training, if this parameter is not set, it defaults to `f'megatron_output/{model_suffix}'`, e.g., `'megatron_output/Qwen2.5-7B-Instruct'`.
+- Note: When training on multiple machines, ensure that the save paths on each node point to the same location. Otherwise, you will need to manually consolidate these weights after training.
 - 🔥save_interval: Checkpoint saving interval (steps), default is 500.
 - Note: Weights will always be saved at the end of training.
 - 🔥no_save_optim: Do not save optimizer, default is False.
```
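The documented `save` default is easy to sanity-check. A minimal sketch, assuming `model_suffix` is simply the last path component of the model id (the actual derivation in ms-swift may differ):

```python
import os
from typing import Optional

def default_save_dir(model: str, save: Optional[str] = None) -> str:
    """Mirror the documented default: f'megatron_output/{model_suffix}'."""
    # Assumption: model_suffix is the final component of the model id/path,
    # e.g. 'Qwen/Qwen2.5-7B-Instruct' -> 'Qwen2.5-7B-Instruct'.
    model_suffix = os.path.basename(model.rstrip('/'))
    return save if save is not None else f'megatron_output/{model_suffix}'

assert default_save_dir('Qwen/Qwen2.5-7B-Instruct') == 'megatron_output/Qwen2.5-7B-Instruct'
```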

swift/llm/argument/train_args.py

Lines changed: 5 additions & 4 deletions

```diff
@@ -141,10 +141,11 @@ def __post_init__(self) -> None:
                              'Please specify `--attn_impl flash_attn`.')
         if self.resume_from_checkpoint:
             self.resume_from_checkpoint = to_abspath(self.resume_from_checkpoint, True)
-            if self.train_type == 'full':
-                self.model = self.resume_from_checkpoint
-            else:
-                self.adapters = [self.resume_from_checkpoint]
+            if self.resume_only_model:
+                if self.train_type == 'full':
+                    self.model = self.resume_from_checkpoint
+                else:
+                    self.adapters = [self.resume_from_checkpoint]
         BaseArguments.__post_init__(self)
         Seq2SeqTrainingOverrideArguments.__post_init__(self)
         TunerArguments.__post_init__(self)
```
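In effect, the checkpoint path now redirects `model`/`adapters` only when `resume_only_model` is set; otherwise `resume_from_checkpoint` is left in place for the trainer's full-resume path, which is the GRPO case this commit fixes. Below is a self-contained sketch of the fixed logic; `ResumeArgs` and this simplified `to_abspath` are illustrative stand-ins, not the actual ms-swift classes:

```python
import os
from dataclasses import dataclass, field
from typing import List, Optional

def to_abspath(path: str, check_path_exist: bool = False) -> str:
    # Simplified stand-in for the real helper: expand and absolutize the path.
    path = os.path.abspath(os.path.expanduser(path))
    if check_path_exist and not os.path.exists(path):
        raise FileNotFoundError(path)
    return path

@dataclass
class ResumeArgs:
    model: Optional[str] = None
    adapters: List[str] = field(default_factory=list)
    train_type: str = 'lora'
    resume_from_checkpoint: Optional[str] = None
    resume_only_model: bool = False

    def resolve_resume(self) -> None:
        if self.resume_from_checkpoint:
            self.resume_from_checkpoint = to_abspath(self.resume_from_checkpoint)
            # The fix: point model/adapters at the checkpoint only when
            # resuming weights alone; a full resume keeps them unchanged so
            # the checkpoint directory is handed through to the trainer.
            if self.resume_only_model:
                if self.train_type == 'full':
                    self.model = self.resume_from_checkpoint
                else:
                    self.adapters = [self.resume_from_checkpoint]

# With resume_only_model=False (the default), `model` keeps pointing at the
# original weights and the checkpoint path is passed through untouched.
args = ResumeArgs(model='Qwen/Qwen2.5-7B-Instruct', train_type='full',
                  resume_from_checkpoint='.')
args.resolve_resume()
assert args.model == 'Qwen/Qwen2.5-7B-Instruct'
```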

0 commit comments