docs/source/Megatron-SWIFT/快速开始.md (2 changes: 1 addition & 1 deletion)
@@ -65,7 +65,7 @@ modelscope-registry.us-west-1.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu2
 | torch | >=2.0 | 2.6.0/2.7.1 | |
 | transformer_engine | >=2.3 | | |
 | apex | | 0.1 | |
-| megatron_core | >=0.12 | 0.13 | |
+| megatron_core | | 0.13 | |
 | flash_attn | | 2.8.1/3.0.0b1 | |
 | transformers | >=4.33 | 4.56.2 | |
 | modelscope | >=1.23 | | |

docs/source_en/Megatron-SWIFT/Quick-start.md (2 changes: 1 addition & 1 deletion)
@@ -65,7 +65,7 @@ Recommended Operating Environment:
 | torch | >=2.0 | 2.6.0/2.7.1 | |
 | transformer_engine | >=2.3 | | |
 | apex | | 0.1 | |
-| megatron_core | >=0.12 | 0.13 | |
+| megatron_core | | 0.13 | |
 | flash_attn | | 2.8.1/3.0.0b1 | |
 | transformers | >=4.33 | 4.56.2 | |
 | modelscope | >=1.23 | | |

swift/llm/model/patcher.py (2 changes: 1 addition & 1 deletion)
@@ -467,7 +467,7 @@ def patch_tp_plan(load_model: bool):
             transformers.__version__) < version.parse('4.50') or 'WORLD_SIZE' not in os.environ:
         yield
         return
-    logger.info('Patch tp_plan.')
+    logger.info_once('Patch tp_plan.')
     WORLD_SIZE = os.environ.get('WORLD_SIZE')
     os.environ['_PATCH_WORLD_SIZE'] = WORLD_SIZE
     os.environ.pop('WORLD_SIZE')

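The info → info_once change matters because patch_tp_plan can be entered repeatedly (for example, once per model load), so the same message would otherwise show up many times. As a rough illustration only, not swift's actual logger code, a once-only INFO helper can be as small as this:

import logging

_seen_messages = set()


def info_once(logger: logging.Logger, msg: str) -> None:
    # Log msg at INFO level only the first time it is seen in this process.
    if msg not in _seen_messages:
        _seen_messages.add(msg)
        logger.info(msg)


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('example')
info_once(logger, 'Patch tp_plan.')  # emitted
info_once(logger, 'Patch tp_plan.')  # suppressed
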
swift/megatron/__init__.py (6 changes: 4 additions & 2 deletions)
@@ -13,15 +13,17 @@
 
 if TYPE_CHECKING:
     from .train import megatron_sft_main, megatron_pt_main, megatron_rlhf_main
-    from .utils import convert_hf2mcore, convert_mcore2hf, prepare_mcore_model, adapter_state_dict_context
+    from .convert import convert_hf2mcore, convert_mcore2hf
+    from .utils import prepare_mcore_model, adapter_state_dict_context
     from .argument import MegatronTrainArguments, MegatronRLHFArguments
     from .model import MegatronModelType, MegatronModelMeta, get_megatron_model_meta, register_megatron_model
     from .trainers import MegatronTrainer, MegatronDPOTrainer
     from .tuners import LoraParallelLinear
 else:
     _import_structure = {
         'train': ['megatron_sft_main', 'megatron_pt_main', 'megatron_rlhf_main'],
-        'utils': ['convert_hf2mcore', 'convert_mcore2hf', 'prepare_mcore_model', 'adapter_state_dict_context'],
+        'convert': ['convert_hf2mcore', 'convert_mcore2hf'],
+        'utils': ['prepare_mcore_model', 'adapter_state_dict_context'],
         'argument': ['MegatronTrainArguments', 'MegatronRLHFArguments'],
         'model': ['MegatronModelType', 'MegatronModelMeta', 'get_megatron_model_meta', 'register_megatron_model'],
         'trainers': ['MegatronTrainer', 'MegatronDPOTrainer'],

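Both the TYPE_CHECKING imports and the _import_structure map now point convert_hf2mcore and convert_mcore2hf at the new swift/megatron/convert.py module, so the public API (from swift.megatron import convert_hf2mcore) is unchanged. The map is consumed lazily; the sketch below shows the general pattern with a PEP 562 module-level __getattr__, which is an assumption for illustration rather than swift's actual loader:

# Simplified lazy-import pattern, written as if inside a package's __init__.py.
import importlib

_import_structure = {
    'convert': ['convert_hf2mcore', 'convert_mcore2hf'],
    'utils': ['prepare_mcore_model', 'adapter_state_dict_context'],
}


def __getattr__(name):
    # Resolve e.g. package.convert_hf2mcore to .convert.convert_hf2mcore on first access.
    for submodule, names in _import_structure.items():
        if name in names:
            module = importlib.import_module(f'.{submodule}', __name__)
            return getattr(module, name)
    raise AttributeError(f'module {__name__!r} has no attribute {name!r}')
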
swift/megatron/argument/train_args.py (3 changes: 2 additions & 1 deletion)
@@ -9,6 +9,7 @@
 from swift.llm.argument.base_args import to_abspath
 from swift.utils import add_version_to_work_dir, get_logger, init_process_group, is_master
 from ..model import get_megatron_model_meta
+from ..utils import convert_hf_config
 from .megatron_args import MegatronArguments
 
 logger = get_logger()
@@ -23,7 +24,7 @@ def init_model_args(self, tokenizer, config):
         if self.task_type == 'seq_cls':
             self.problem_type = self.problem_type or getattr(config, 'problem_type', None)
             logger.info(f'args.problem_type: {self.problem_type}')
-        kwargs = self.megatron_model_meta.convert_hf_config(config)
+        kwargs = convert_hf_config(config)
         if self.new_special_tokens and kwargs['padded_vocab_size'] < len(tokenizer):
             kwargs['padded_vocab_size'] = math.ceil(len(tokenizer) / 128) * 128
             self.initialize_embedding = True

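Only the source of the config conversion changed here (the module-level convert_hf_config replaces the per-model-meta callable); init_model_args still pads the vocabulary up to a multiple of 128 when new special tokens push len(tokenizer) past the checkpoint's padded_vocab_size. The padding arithmetic, shown with a hypothetical tokenizer size:

import math


def pad_vocab_size(vocab_size: int, multiple: int = 128) -> int:
    # Round the vocabulary size up to the next multiple of 128, as in
    # init_model_args above (keeps the embedding table shape alignment-friendly).
    return math.ceil(vocab_size / multiple) * multiple


# e.g. a tokenizer grown to 151,700 entries by new special tokens:
assert pad_vocab_size(151_700) == 151_808
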
swift/megatron/utils/convert.py → swift/megatron/convert.py (24 changes: 15 additions & 9 deletions)
@@ -16,9 +16,9 @@
 from swift.llm import (ExportArguments, HfConfigFactory, prepare_model_template, save_checkpoint, to_device,
                        to_float_dtype)
 from swift.utils import get_logger, get_n_params_grads
-from ..argument import MegatronArguments
-from ..model import get_megatron_model_meta
-from .patcher import patch_megatron_tokenizer, patch_torch_dist_shard
+from .argument import MegatronArguments
+from .model import get_megatron_model_meta
+from .utils import convert_hf_config, patch_torch_dist_shard
 
 logger = get_logger()
 
@@ -238,15 +238,14 @@ def convert_hf2mcore(args: ExportArguments) -> None:
 
     megatron_model_meta = get_megatron_model_meta(args.model_type)
     assert megatron_model_meta is not None, f'Model: {args.model} is not supported.'
-    kwargs = megatron_model_meta.convert_hf_config(processor.model_info.config)
+    kwargs = convert_hf_config(processor.model_info.config)
     logger.info(f'megatron_config: {kwargs}')
     _check_megatron_kwargs(kwargs)
     current_convert_kwargs = convert_kwargs.copy()
     if args.model_info.is_moe_model:
         current_convert_kwargs['moe_grouped_gemm'] = True
     megatron_args = MegatronArguments(
         **kwargs, **current_convert_kwargs, save=args.output_dir, torch_dtype=args.torch_dtype)
-    patch_megatron_tokenizer(processor)
     extra_args = megatron_args.parse_to_megatron()
     extra_args['model_info'] = args.model_info
     extra_args['model_meta'] = args.model_meta
@@ -256,7 +255,11 @@ def convert_hf2mcore(args: ExportArguments) -> None:
 
     mg_model = megatron_model_meta.model_provider()
     logger.info('Megatron model created successfully.')
-    megatron_model_meta.convert_hf2mcore(hf_model, mg_model)
+    bridge = megatron_model_meta.bridge_cls()
+    incompatible_keys = mg_model.load_state_dict(bridge.convert_hf2mcore(hf_model.state_dict()), strict=False)
+    missing_keys = [k for k in incompatible_keys.missing_keys if not k.endswith('._extra_state')]
+    assert len(incompatible_keys.unexpected_keys) == 0, f'unexpected_keys: {incompatible_keys.unexpected_keys}'
+    assert len(missing_keys) == 0, f'missing_keys: {missing_keys}'
     if args.test_convert_precision:
         test_convert_precision(hf_model, mg_model, template, args.test_convert_dtype)
     del hf_model
@@ -274,7 +277,7 @@ def convert_mcore2hf(args: ExportArguments) -> None:
 
     megatron_model_meta = get_megatron_model_meta(args.model_type)
     assert megatron_model_meta is not None, f'Model: {args.model} is not supported.'
-    kwargs = megatron_model_meta.convert_hf_config(processor.model_info.config)
+    kwargs = convert_hf_config(processor.model_info.config)
     logger.info(f'megatron_config: {kwargs}')
     _check_megatron_kwargs(kwargs)
     current_convert_kwargs = convert_kwargs.copy()
@@ -291,7 +294,6 @@ def convert_mcore2hf(args: ExportArguments) -> None:
         **current_convert_kwargs,
         save=args.output_dir if args.to_mcore else None,
         torch_dtype=args.torch_dtype)
-    patch_megatron_tokenizer(processor)
     extra_args = megatron_args.parse_to_megatron()
     extra_args['model_info'] = args.model_info
     extra_args['model_meta'] = args.model_meta
@@ -312,7 +314,11 @@
     logger.info('Megatron model created successfully.')
     if args.to_hf:
         hf_model, template = prepare_model_template(args, patch_offload=not args.test_convert_precision)
-        megatron_model_meta.convert_mcore2hf(hf_model, mg_model)
+        bridge = megatron_model_meta.bridge_cls()
+        incompatible_keys = hf_model.load_state_dict(bridge.convert_mcore2hf(mg_model.state_dict()), strict=False)
+        missing_keys = [k for k in incompatible_keys.missing_keys if not k.endswith('._extra_state')]
+        assert len(incompatible_keys.unexpected_keys) == 0, f'unexpected_keys: {incompatible_keys.unexpected_keys}'
+        assert len(missing_keys) == 0, f'missing_keys: {missing_keys}'
         if args.test_convert_precision:
             test_convert_precision(hf_model, mg_model, template, args.test_convert_dtype)
         del mg_model

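Both conversion directions now go through a bridge object: the bridge turns a source state dict into destination-keyed tensors, and the result is loaded with strict=False so that transformer_engine's *._extra_state buffers (which carry no converted weights) are the only tolerated missing keys. A condensed sketch of that load-and-verify step, with the bridge call abstracted away since the real key mappings are model-specific:

import torch
from torch import nn


def load_converted(dst_model: nn.Module, converted_state_dict: dict) -> nn.Module:
    # Load a converted state dict into the destination model without requiring
    # an exact key match, then verify the mismatch is benign.
    incompatible = dst_model.load_state_dict(converted_state_dict, strict=False)
    # transformer_engine modules register '*._extra_state' buffers that are not
    # part of the converted weights, so only those misses are allowed.
    missing = [k for k in incompatible.missing_keys if not k.endswith('._extra_state')]
    assert not incompatible.unexpected_keys, f'unexpected_keys: {incompatible.unexpected_keys}'
    assert not missing, f'missing_keys: {missing}'
    return dst_model


# Tiny usage: a converted dict whose keys already match the destination module.
layer = nn.Linear(4, 4)
load_converted(layer, {k: v.clone() for k, v in layer.state_dict().items()})
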
swift/megatron/init.py (10 changes: 10 additions & 0 deletions)
@@ -379,6 +379,15 @@ def sharded_state_dict(
     TEGroupedLinear.sharded_state_dict = sharded_state_dict
 
 
+def _patch_megatron_tokenizer():
+    from megatron.training import global_vars
+
+    def build_tokenizer(args):
+        return 'dummy_tokenizer'
+
+    global_vars.build_tokenizer = build_tokenizer
+
+
 def _patch_peft_ModulesToSaveWrapper():
     if version.parse(peft.__version__) >= version.parse('0.16'):
         from peft.utils import other as peft_module
@@ -664,6 +673,7 @@ def _patch_megatron():
     _patch_compile_helpers()
     _patch_build_train_valid_test_datasets()
     _patch_mrope()
+    _patch_megatron_tokenizer()
     logging.root.setLevel(logging_level)  # revert logger level
     from swift.megatron import tuners  # patch lora
     try:

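The new _patch_megatron_tokenizer stub means Megatron no longer needs to build a real tokenizer during setup (tokenization stays on the HF processor side), which is also why the explicit patch_megatron_tokenizer(processor) calls could be dropped from convert.py above. A standalone illustration of the same idea, under the assumption that Megatron's global-variable setup resolves build_tokenizer through this module attribute, as the patch implies:

# Replace the module-level build_tokenizer that Megatron's global-variable setup
# calls, so no vocab/merges files are required when swift handles tokenization.
from megatron.training import global_vars


def _dummy_build_tokenizer(args):
    return 'dummy_tokenizer'


# Must run before Megatron initialization (e.g. before
# megatron.training.initialize.initialize_megatron), or the stub is never seen.
global_vars.build_tokenizer = _dummy_build_tokenizer
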
swift/megatron/model/__init__.py (3 changes: 2 additions & 1 deletion)
@@ -1,4 +1,5 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-from . import gpt, mm_gpt
+# from . import gpt, mm_gpt
+from . import gpt
 from .constant import MegatronModelType
 from .register import MegatronModelMeta, get_megatron_model_meta, register_megatron_model

swift/megatron/model/config.py (85 changes: 0 additions & 85 deletions)
This file was deleted.

swift/megatron/model/gpt/__init__.py (7 changes: 0 additions & 7 deletions)
@@ -4,9 +4,6 @@
 from ..gpt_model import GPTModel
 from ..register import MegatronModelMeta, register_megatron_model
 from . import qwen3_next
-from .config import convert_gpt_hf_config
-from .hf2mcore import convert_hf2mcore
-from .mcore2hf import convert_mcore2hf
 
 register_megatron_model(
     MegatronModelMeta(
@@ -58,8 +55,4 @@
         ModelType.deepseek_v3_1,
         ModelType.ernie_thinking,
     ],
-    model_cls=GPTModel,
-    convert_hf_config=convert_gpt_hf_config,
-    convert_mcore2hf=convert_mcore2hf,
-    convert_hf2mcore=convert_hf2mcore,
 ))
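
With the per-model convert_hf_config/convert_hf2mcore/convert_mcore2hf callables removed from the registration (and config.py, hf2mcore.py, mcore2hf.py no longer imported here), weight conversion is expected to come from the bridge class that convert.py obtains via megatron_model_meta.bridge_cls(). A deliberately toy bridge, only to show the interface convert.py relies on; the class name and the single key mapping are made up:

# Hypothetical key mapping for illustration; the real bridges in swift are
# model-specific and cover every parameter, not just the embedding.
_HF_TO_MCORE = {
    'model.embed_tokens.weight': 'embedding.word_embeddings.weight',
}


class DummyGPTBridge:
    # Matches the usage in convert.py: bridge.convert_hf2mcore(hf_state_dict)
    # returns an mcore-keyed state dict, and convert_mcore2hf goes back.

    def convert_hf2mcore(self, hf_state_dict):
        return {_HF_TO_MCORE.get(k, k): v for k, v in hf_state_dict.items()}

    def convert_mcore2hf(self, mg_state_dict):
        reverse = {v: k for k, v in _HF_TO_MCORE.items()}
        return {reverse.get(k, k): v for k, v in mg_state_dict.items()}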