Skip to content

Commit 3f09697

[megatron] support lora modules_to_save (#4916)
1 parent ee503b8 commit 3f09697

9 files changed: +148, -36 lines

examples/train/megatron/lora/dpo.sh

Lines changed: 3 additions & 2 deletions
@@ -1,4 +1,4 @@
-# 2 * 55GiB; 4.50s/it
+# 2 * 60GiB; 4.50s/it
 PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
 NPROC_PER_NODE=2 \
 CUDA_VISIBLE_DEVICES=0,1 \
@@ -10,6 +10,7 @@ megatron rlhf \
     --lora_rank 8 \
     --lora_alpha 32 \
     --target_modules all-linear \
+    --modules_to_save word_embeddings output_layer \
     --split_dataset_ratio 0.01 \
     --expert_model_parallel_size 2 \
     --moe_grouped_gemm true \
@@ -29,7 +30,7 @@ megatron rlhf \
     --save megatron_output/Qwen3-30B-A3B-Base \
     --eval_interval 100 \
     --save_interval 100 \
-    --max_length 8192 \
+    --max_length 2048 \
     --num_workers 8 \
     --dataset_num_proc 8 \
     --no_save_optim true \
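The functional change in this example is the new `--modules_to_save word_embeddings output_layer` flag, which trains and saves full copies of the embedding and output layers alongside the LoRA adapters; the memory estimate in the header comment and `--max_length` were updated in the same change. As a hedged sketch of what this roughly corresponds to on the peft side, assuming the CLI flag is forwarded into `LoraConfig.modules_to_save` (the exact wiring inside swift is not shown in this diff):

# Hypothetical illustration only: how modules_to_save is usually expressed in peft.
from peft import LoraConfig

lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules='all-linear',
    # peft wraps these modules in ModulesToSaveWrapper: the original stays
    # frozen, a deep copy is trained, and the copy is saved with the adapter.
    modules_to_save=['word_embeddings', 'output_layer'],
)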
Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+# 2 * 60GiB, 3.4s/it
+PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' \
+NPROC_PER_NODE=2 \
+CUDA_VISIBLE_DEVICES=0,1 \
+megatron sft \
+    --load Qwen3-30B-A3B-Base-mcore \
+    --train_type lora \
+    --dataset AI-ModelScope/function-calling-chatml#10000 \
+    --loss_scale hermes \
+    --agent_template hermes \
+    --lora_rank 8 \
+    --lora_alpha 32 \
+    --target_modules all-linear \
+    --modules_to_save word_embeddings output_layer \
+    --split_dataset_ratio 0.01 \
+    --expert_model_parallel_size 2 \
+    --moe_grouped_gemm true \
+    --moe_shared_expert_overlap true \
+    --moe_aux_loss_coeff 0.01 \
+    --micro_batch_size 8 \
+    --global_batch_size 16 \
+    --recompute_granularity full \
+    --recompute_method uniform \
+    --recompute_num_layers 1 \
+    --max_epochs 1 \
+    --finetune true \
+    --cross_entropy_loss_fusion true \
+    --lr 1e-4 \
+    --lr_warmup_fraction 0.05 \
+    --min_lr 1e-5 \
+    --save megatron_output/Qwen3-30B-A3B-Base \
+    --eval_interval 200 \
+    --save_interval 200 \
+    --max_length 2048 \
+    --num_workers 8 \
+    --dataset_num_proc 8 \
+    --no_save_optim true \
+    --no_save_rng true \
+    --sequence_parallel true \
+    --attention_backend flash

swift/megatron/init.py

Lines changed: 45 additions & 1 deletion
@@ -639,6 +639,49 @@ def sharded_state_dict(
     TEGroupedLinear.sharded_state_dict = sharded_state_dict


+def _patch_peft_ModulesToSaveWrapper():
+    from peft.tuners import tuners_utils
+    from megatron.core.dist_checkpointing.mapping import ShardedStateDict
+    from .utils import tuners_sharded_state_dict
+
+    ModulesToSaveWrapper = tuners_utils.ModulesToSaveWrapper
+
+    class NewModulesToSaveWrapper(ModulesToSaveWrapper):
+
+        def __init__(self, module_to_save, *args, **kwargs):
+            tp_group = getattr(module_to_save, 'tp_group', None)
+            if tp_group is not None:
+                module_to_save.tp_group = None
+            super().__init__(module_to_save, *args, **kwargs)
+            if tp_group is not None:
+                module_to_save.tp_group = tp_group
+                for module in self.modules_to_save.values():
+                    module.tp_group = tp_group
+
+        def sharded_state_dict(
+            self,
+            prefix: str = '',
+            sharded_offsets: Tuple[Tuple[int, int, int]] = (),
+            metadata: Optional[dict] = None,
+        ) -> ShardedStateDict:
+            sharded_state_dict = tuners_sharded_state_dict(self, prefix, sharded_offsets, metadata)
+            if prefix == 'output_layer.':
+                output_layer_extra_state_key = f'{prefix}modules_to_save.default._extra_state'
+
+                # Old GPT checkpoints only stored the output layer weight key. So we remove the
+                # _extra_state key but check that it doesn't contain any data anyway
+                output_extra_state = sharded_state_dict.pop(output_layer_extra_state_key, None)
+                assert not (output_extra_state and output_extra_state.data
+                            ), f'Expected output layer extra state to be empty, got: {output_extra_state}'
+                # fix error
+                if f'{prefix}modules_to_save.default.weight' in sharded_state_dict:
+                    sharded_state_dict[f'{prefix}weight'] = sharded_state_dict[
+                        f'{prefix}modules_to_save.default.weight']
+            return sharded_state_dict
+
+    tuners_utils.ModulesToSaveWrapper = NewModulesToSaveWrapper
+
+
 def _patch_megatron():
     _patch_transformer_engine()
     _patch__batched_p2p_ops()
@@ -647,7 +690,8 @@ def _patch_megatron():
     from swift.megatron import tuners  # patch lora
     try:
         _patch_peft_BaseTuner()
-        logger.info('Patch peft_BaseTuner successfully applied.')
+        _patch_peft_ModulesToSaveWrapper()
+        logger.info('Patch peft successfully applied.')
     except Exception:
         pass
     try:
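The `tp_group` juggling in `NewModulesToSaveWrapper.__init__` exists because peft's `ModulesToSaveWrapper` deep-copies the wrapped module into a trainable copy, and a tensor-parallel process-group handle generally cannot be deep-copied. The patch therefore detaches the handle, lets the parent constructor copy the module, then reattaches the handle to both the original and the copies. A minimal standalone sketch of that pattern, with a stand-in object instead of a real process group:

# Minimal sketch of the detach / deepcopy / reattach pattern; not swift code.
import copy

import torch.nn as nn


class FakeGroup:
    """Stands in for a torch.distributed process group, which deepcopy rejects."""

    def __deepcopy__(self, memo):
        raise TypeError('process groups cannot be deep-copied')


linear = nn.Linear(4, 4)
linear.tp_group = FakeGroup()

tp_group, linear.tp_group = linear.tp_group, None  # detach the handle
trainable_copy = copy.deepcopy(linear)             # now the copy succeeds
linear.tp_group = tp_group                         # reattach to the original
trainable_copy.tp_group = tp_group                 # and share it with the copy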

swift/megatron/trainers/base.py

Lines changed: 16 additions & 6 deletions
@@ -20,7 +20,7 @@
 from packaging import version

 from swift.utils import JsonlWriter, get_logger, is_master
-from ..utils import adapter_state_dict_context, prepare_mcore_model
+from ..utils import adapter_state_dict_context, copy_original_module_weight, prepare_mcore_model
 from .utils import get_swift_datasets_provider

 logger = get_logger()
@@ -124,12 +124,20 @@ def _load_base_checkpoint(*_args, **kwargs):
            state_dict_model = {}
            mapping = {}
            for k, v in sharded_state_dict['model'].items():
-                if 'lora_A' in k or 'lora_B' in k:
+                if 'lora_A' in k or 'lora_B' in k or 'original_module' in k:
                    continue
-                origin_k = k
-                k = k.replace('.base_layer', '')
-                mapping[k] = origin_k
-                v.key = v.key.replace('.base_layer', '')
+                # lora
+                if '.base_layer' in k:
+                    origin_k = k
+                    k = k.replace('.base_layer', '')
+                    mapping[k] = origin_k
+                    v.key = v.key.replace('.base_layer', '')
+                elif '.modules_to_save' in k:
+                    # modules to save
+                    origin_k = k
+                    k = k.replace('.modules_to_save.default', '')
+                    mapping[k] = origin_k
+                    v.key = v.key.replace('.modules_to_save.default', '')
                state_dict_model[k] = v
            sharded_state_dict['model'] = state_dict_model
            res = origin__load_base_checkpoint(*_args, **kwargs)
@@ -168,6 +176,8 @@ def new_model_provider_func(*args, **kwargs):
        if args.adapter_load is not None:
            with adapter_state_dict_context():
                load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='adapter_load', strict=False)
+        if args.train_type != 'full' and args.modules_to_save:
+            copy_original_module_weight(self.unwrapped_model)
        return model, optimizer, opt_param_scheduler

    def train_step(self, forward_step_func, data_iterator, model, optimizer, opt_param_scheduler, config):
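The key remapping above relies on the parameter names peft introduces when it wraps a model: LoRA-adapted layers keep the frozen weight under `.base_layer`, while `modules_to_save` targets hold a frozen `original_module` plus a trainable copy under `modules_to_save.default`. Stripping those infixes, and skipping `lora_A`/`lora_B`/`original_module` keys entirely, lets the sharded base checkpoint be loaded with its original key names. A hedged illustration of that naming on a plain peft model rather than a Megatron one (exact prefixes can vary across peft versions):

# Illustration of the key layout the remapping depends on; plain torch/peft,
# not Megatron. Toy, proj and head are made-up names for this sketch.
import torch.nn as nn
from peft import LoraConfig, get_peft_model


class Toy(nn.Module):

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)   # LoRA target
        self.head = nn.Linear(8, 8)   # modules_to_save target


model = get_peft_model(Toy(), LoraConfig(target_modules=['proj'], modules_to_save=['head']))
for key in model.state_dict():
    print(key)
# Typically prints keys such as:
#   ...proj.base_layer.weight               -> remapped by stripping '.base_layer'
#   ...proj.lora_A.default.weight           -> skipped when loading the base checkpoint
#   ...head.original_module.weight          -> skipped ('original_module' in key)
#   ...head.modules_to_save.default.weight  -> remapped by stripping '.modules_to_save.default'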

swift/megatron/tuners/lora.py

Lines changed: 3 additions & 16 deletions
@@ -14,13 +14,14 @@
 from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
 from megatron.core.transformer.mlp import apply_swiglu_sharded_factory
 from megatron.core.transformer.module import MegatronModule
-from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint, sharded_state_dict_default
 from packaging import version
 from peft.tuners.lora import model
 from peft.tuners.lora.layer import LoraLayer
 from peft.tuners.tuners_utils import BaseTunerLayer, check_adapters_to_merge
 from peft.utils.other import transpose

+from ..utils import tuners_sharded_state_dict
+

 class LoraParallelLinear(MegatronModule, LoraLayer):

@@ -271,21 +272,7 @@ def sharded_state_dict(
        sharded_offsets: Tuple[Tuple[int, int, int]] = (),
        metadata: Optional[dict] = None,
    ) -> ShardedStateDict:
-        sharded_state_dict = {}
-        # Save parameters
-        self._save_to_state_dict(sharded_state_dict, '', keep_vars=True)
-        sharded_state_dict = make_sharded_tensors_for_checkpoint(
-            sharded_state_dict, prefix, sharded_offsets=sharded_offsets)
-        # Recurse into submodules
-        for name, module in self.named_children():
-            if 'Dict' in module.__class__.__name__:
-                modules = module.named_children()
-            else:
-                modules = [(None, module)]
-            for n, m in modules:
-                _prefix = f'{prefix}{name}.' if n is None else f'{prefix}{name}.{n}.'
-                sharded_state_dict.update(sharded_state_dict_default(m, _prefix, sharded_offsets, metadata))
-
+        sharded_state_dict = tuners_sharded_state_dict(self, prefix, sharded_offsets, metadata)
        if prefix.endswith('linear_fc1.'):
            if isinstance(self.base_layer, TEGroupedLinear) and self.config.gated_linear_unit:
                num_global_experts = (parallel_state.get_expert_model_parallel_world_size() * self.base_layer.num_gemms)

swift/megatron/utils/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -2,4 +2,5 @@

 from .convert import convert_hf2mcore, convert_mcore2hf
 from .patcher import patch_megatron_tokenizer
-from .utils import adapter_state_dict_context, prepare_mcore_model
+from .utils import (adapter_state_dict_context, copy_original_module_weight, prepare_mcore_model,
+                    tuners_sharded_state_dict)

swift/megatron/utils/convert.py

Lines changed: 0 additions & 1 deletion
@@ -221,5 +221,4 @@ def convert_mcore2hf(args: ExportArguments) -> None:
        model_dirs=[ckpt_dir, args.model_dir],
        max_shard_size=args.max_shard_size,
        additional_saved_files=hf_model.model_meta.additional_saved_files)
-    args.save_args()
    logger.info(f'Successfully saved HF model weights in `{args.output_dir}`.')

swift/megatron/utils/utils.py

Lines changed: 35 additions & 2 deletions
@@ -1,11 +1,13 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 from contextlib import contextmanager
+from typing import Optional, Tuple

 import torch.distributed as dist
 from megatron.core import mpu
 from megatron.core.extensions.transformer_engine import TEGroupedLinear, TELayerNormColumnParallelLinear, TELinear
 from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
-from megatron.training import get_args
+from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint, sharded_state_dict_default
+from megatron.training import checkpointing, get_args

 from swift.utils import activate_parameters, find_layers, freeze_parameters, get_logger, get_model_parameter_info

@@ -96,7 +98,6 @@ def adapter_state_dict_context():
    if args.train_type == 'full':
        yield
        return
-    from megatron.training import checkpointing
    _origin_generate_state_dict = checkpointing.generate_state_dict

    def generate_state_dict(args, model, *_args, **kwargs):
@@ -121,3 +122,35 @@ def generate_state_dict(args, model, *_args, **kwargs):
        yield
    finally:
        checkpointing.generate_state_dict = _origin_generate_state_dict
+
+
+def tuners_sharded_state_dict(
+    module,
+    prefix: str = '',
+    sharded_offsets: Tuple[Tuple[int, int, int]] = (),
+    metadata: Optional[dict] = None,
+):
+    sharded_state_dict = {}
+    # Save parameters
+    module._save_to_state_dict(sharded_state_dict, '', keep_vars=True)
+    sharded_state_dict = make_sharded_tensors_for_checkpoint(
+        sharded_state_dict, prefix, sharded_offsets=sharded_offsets)
+    # Recurse into submodules
+    for name, module in module.named_children():
+        if 'Dict' in module.__class__.__name__:
+            modules = module.named_children()
+        else:
+            modules = [(None, module)]
+        for n, m in modules:
+            _prefix = f'{prefix}{name}.' if n is None else f'{prefix}{name}.{n}.'
+            sharded_state_dict.update(sharded_state_dict_default(m, _prefix, sharded_offsets, metadata))
+    return sharded_state_dict
+
+
+def copy_original_module_weight(model):
+    for module in model.modules():
+        if 'ModulesToSaveWrapper' in module.__class__.__name__ and hasattr(module, 'original_module'):
+            original_module = module.original_module
+            modules_to_save = module.modules_to_save
+            if 'default' in modules_to_save:
+                original_module.load_state_dict(modules_to_save['default'].state_dict())
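`copy_original_module_weight` pushes the weights of each trainable `modules_to_save['default']` copy back into the frozen `original_module`, so the two stay consistent after the base checkpoint has been remapped into the trainable copy (see the loader change in `trainers/base.py`). A toy run of the same loop on a peft `ModulesToSaveWrapper` around a plain `nn.Linear` (the Megatron modules are swapped out for this sketch):

# Toy demonstration of the copy_original_module_weight loop; assumes peft's
# ModulesToSaveWrapper wraps an ordinary nn.Linear instead of a Megatron layer.
import torch
import torch.nn as nn
from peft.tuners import tuners_utils

parent = nn.Module()
parent.head = tuners_utils.ModulesToSaveWrapper(nn.Linear(8, 8), 'default')

# Pretend a checkpoint landed only in the trainable copy.
with torch.no_grad():
    parent.head.modules_to_save['default'].weight.fill_(1.0)

# The same walk copy_original_module_weight performs over model.modules():
for module in parent.modules():
    if 'ModulesToSaveWrapper' in module.__class__.__name__ and hasattr(module, 'original_module'):
        if 'default' in module.modules_to_save:
            module.original_module.load_state_dict(module.modules_to_save['default'].state_dict())

assert torch.equal(parent.head.original_module.weight,
                   parent.head.modules_to_save['default'].weight)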

tests/megatron/test_lora.py

Lines changed: 4 additions & 7 deletions
@@ -12,14 +12,14 @@ def test_sft():
        loss_scale='hermes',
        split_dataset_ratio=0.01,
        tensor_model_parallel_size=2,
-        load_from_cache_file=False,
        train_type='lora',
        recompute_granularity='full',
        recompute_method='uniform',
        recompute_num_layers=1,
        # pipeline_model_parallel_size=2,
        # freeze_parameters_ratio=0.5,
        train_iters=100,
+        modules_to_save=['word_embeddings', 'output_layer'],
        eval_iters=5,
        save_interval=5,
        no_save_optim=True,
@@ -41,6 +41,7 @@ def test_moe():
        # expert_model_parallel_size=2,
        train_type='lora',
        recompute_granularity='full',
+        modules_to_save=['word_embeddings', 'output_layer'],
        recompute_method='uniform',
        recompute_num_layers=1,
        # pipeline_model_parallel_size=2,
@@ -67,15 +68,11 @@ def test_embedding():
    pass


-def test_modules_to_save():
-    pass
-
-
def test_resume():
    pass


if __name__ == '__main__':
-    # test_sft()
+    test_sft()
    # test_moe()
-    test_convert()
+    # test_convert()
