
Commit a5fea81

update
1 parent 92ab184 · commit a5fea81


6 files changed: +17 -15 lines changed


swift/megatron/convert.py

Lines changed: 2 additions & 2 deletions
@@ -74,7 +74,7 @@ def convert_hf2mcore(args: ExportArguments) -> None:
     logger.info(f'Successfully saved Megatron model weights in `{args.output_dir}`.')
     # Place it at the end to avoid test_convert_precision affecting precision.
     if args.test_convert_precision:
-        test_convert_precision(hf_model, mg_model, template, args.test_convert_dtype)
+        test_convert_precision(megatron_args, hf_model, mg_model, template, test_convert_dtype=args.test_convert_dtype)


 def convert_mcore2hf(args: ExportArguments) -> None:
@@ -131,7 +131,7 @@ def convert_mcore2hf(args: ExportArguments) -> None:
         logger.info(f'Successfully saved HF model weights in `{args.output_dir}`.')
         if args.test_convert_precision:
             hf_model, template = prepare_model_template(args, model=args.output_dir)
-            test_convert_precision(hf_model, mg_model, template, args.test_convert_dtype)
+            test_convert_precision(megatron_args, hf_model, mg_model, template, test_convert_dtype=args.test_convert_dtype)
     elif args.to_mcore:
         if args.thread_count is None:
             checkpoint_size = sum(get_n_params_grads(mg_model)[0]) * torch.finfo(args.torch_dtype).bits // 8e9
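Both hunks switch these call sites to the helper's updated signature (defined in swift/megatron/utils/convert_utils.py further down): the Megatron argument namespace becomes the first positional argument and the conversion dtype is passed by keyword. A hedged sketch of the resulting call shape; `megatron_args` is whatever namespace the surrounding conversion code builds and is not shown in this diff:

# Call-shape sketch only; every name here comes from the hunks above, not a verified API.
if args.test_convert_precision:
    test_convert_precision(
        megatron_args,                               # Megatron args, now the first positional parameter
        hf_model, mg_model, template,
        test_convert_dtype=args.test_convert_dtype)  # dtype is now an explicit keyword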

swift/megatron/init.py

Lines changed: 5 additions & 4 deletions
@@ -518,6 +518,7 @@ def sharded_state_dict(
 def _patch_TransformerLayer():
     import megatron.core
     from megatron.core.transformer import TransformerLayer
+    from megatron.core import mpu
     _origin_forward = TransformerLayer.forward
     mcore_013 = version.parse(megatron.core.__version__) >= version.parse('0.13.0rc0')

@@ -531,7 +532,7 @@ def forward(self, *_args, **kwargs):
         if not mcore_013:
             return _origin_forward(self, *_args, **kwargs)
         hidden_states, context = self._forward_attention(*_args, **kwargs)
-        args = get_args()
+        args = self.config.args
         mlp_padding_free = args.mlp_padding_free and 'attention_mask' in kwargs
         mask = None
         if mlp_padding_free and hidden_states.shape[1] > 1:
@@ -660,6 +661,7 @@ def _write_item(self, *args, **kwargs):
 def _patch_mrope():
     from megatron.core.models.common.embeddings.rotary_pos_embedding import MultimodalRotaryEmbedding
     import megatron.core
+    from megatron.core import mpu
     from megatron.core.models.common.embeddings.rope_utils import _apply_rotary_pos_emb_bshd
     from megatron.core.models.common.embeddings import rope_utils

@@ -696,7 +698,7 @@ def forward(self, position_ids, mrope_section: List[int], packed_seq: bool = Fal
         seq_expanded = seq[:, :, None, :].float()
         # shape (3, bs, seq_length, dim)
         freqs = (inv_freq_expanded @ seq_expanded).transpose(2, 3)
-        args = get_args()
+        args = self.config.args
         if args.mrope_interleaved:
             freqs = apply_interleaved_mrope(freqs, mrope_section)
         emb = torch.cat((freqs, freqs), dim=-1)
@@ -744,8 +746,7 @@ def _apply_rotary_pos_emb_thd(
         if cp_group is not None:
             cp_size = cp_group.size()
         else:
-            args = get_args()
-            cp_size = args.context_parallel_size
+            cp_size = mpu.get_context_parallel_world_size()
         cu_seqlens_for_batched = cu_seqlens // cp_size
         use_batched_rope = (freqs.dim() >= 1 and freqs.shape[0] == cu_seqlens_for_batched[-1]).item()
         if not use_batched_rope:
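Two patterns recur in these hunks: arguments are read from self.config.args (an attribute the surrounding swift code evidently attaches to the transformer config) rather than the global get_args() store, and the context-parallel size comes from Megatron's parallel state. A minimal sketch of the latter, assuming only an initialized Megatron-Core process group; the standalone function name is illustrative:

from megatron.core import mpu  # legacy alias of megatron.core.parallel_state

def _context_parallel_size(cp_group=None):
    # Prefer an explicitly supplied process group; otherwise query Megatron's
    # parallel state instead of reading args.context_parallel_size.
    if cp_group is not None:
        return cp_group.size()
    return mpu.get_context_parallel_world_size()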

swift/megatron/pipelines/export/export.py

Lines changed: 2 additions & 2 deletions
@@ -82,7 +82,7 @@ def convert_mcore2hf(self) -> None:
             device_map = args.device_map or 'auto'
             hf_model, template = prepare_model_template(
                 args, device_map=device_map, **kwargs) if is_last_rank() else (None, template)
-            test_convert_precision(hf_model, mg_model, template, args.test_convert_dtype)
+            test_convert_precision(args, hf_model, mg_model, template, test_convert_dtype=args.test_convert_dtype)
             dist.barrier()

     def convert_hf2mcore(self) -> None:
@@ -135,7 +135,7 @@ def convert_hf2mcore(self) -> None:
             device_map = args.device_map or 'auto'
             hf_model, template = prepare_model_template(
                 args, device_map=device_map) if is_last_rank() else (None, template)
-            test_convert_precision(hf_model, mg_model, template, args.test_convert_dtype)
+            test_convert_precision(args, hf_model, mg_model, template, test_convert_dtype=args.test_convert_dtype)
             dist.barrier()
         else:
             logger.warning('Skip test_convert_precision because `--adapter_load` is specified.')

swift/megatron/trainers/gkd_trainer.py

Lines changed: 1 addition & 1 deletion
@@ -441,7 +441,7 @@ def _compute_teacher_logits(self, encoded_batches: List[Dict], vp_stage: Optiona
             teacher_data.pop('labels', None)
             # Teacher forward with args override for correct hidden_size
             with self.load_teacher_model_context(), self._teacher_args_context(), torch.no_grad():
-                teacher_logits = forward_step_helper(teacher_model, teacher_data)
+                teacher_logits = forward_step_helper(self.args, teacher_model, teacher_data)
             if teacher_logits is not None:
                 teacher_logits = teacher_logits.detach()
             encoded_batch['teacher_logits'] = teacher_logits

swift/megatron/trainers/grpo_trainer.py

Lines changed: 1 addition & 1 deletion
@@ -1504,7 +1504,7 @@ def model_forward(self, model, data_iterator, no_grad=True, per_token=False):
         context = torch.no_grad() if no_grad else nullcontext()

         with context:
-            output_tensor = forward_step_helper(model, data)
+            output_tensor = forward_step_helper(self.args, model, data)

         # packed_seq_params only exists in padding_free mode
         packed_seq_params = data.get('packed_seq_params')
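This hunk and the gkd_trainer.py one above make the same one-line change: forward_step_helper now receives the trainer's own self.args as its first argument, consistent with the commit's move away from global argument lookups. A toy, self-contained illustration of that calling convention; the body below is a stand-in, not the real helper from swift.megatron:

from types import SimpleNamespace

def forward_step_helper(args, model, data):
    # Toy stand-in: the real helper runs a Megatron forward step; this only
    # shows that `args` is supplied by the caller (e.g. a trainer's self.args).
    return model(data) * args.loss_scale

trainer_args = SimpleNamespace(loss_scale=2.0)
print(forward_step_helper(trainer_args, lambda x: x + 1, 20.0))  # 42.0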

swift/megatron/utils/convert_utils.py

Lines changed: 6 additions & 5 deletions
@@ -143,8 +143,9 @@ def get_examples(is_multimodal: bool) -> Dict[str, Any]:
     return data


-def test_convert_precision(args, hf_model, mg_model, template):
-    torch_dtype = args.test_convert_dtype
+def test_convert_precision(args, hf_model, mg_model, template, test_convert_dtype=None):
+    if test_convert_dtype is None:
+        test_convert_dtype = getattr(args, 'test_convert_dtype', torch.float32)
     template.set_mode('train')
     _test_params_sum(mg_model)

@@ -166,7 +167,7 @@ def test_convert_precision(args, hf_model, mg_model, template):
     ignore_modules = (model_arch.vision_tower + model_arch.aligner) if is_multimodal else []
     hf_modules = _find_modules(hf_model, ignore_modules=ignore_modules)
     with torch.inference_mode(), _model_cpu_forward_context(
-            hf_modules, torch_dtype, share_embedding=share_embedding):
+            hf_modules, test_convert_dtype, share_embedding=share_embedding):
         hf_inputs.pop('text_position_ids', None)
         hf_logits = hf_model(**hf_inputs).logits
         hf_logits = hf_logits.to('cuda')
@@ -195,8 +196,8 @@ def test_convert_precision(args, hf_model, mg_model, template):
         if n.endswith('router'):
             m.to(mg_dtype)
     with torch.inference_mode(), _model_cpu_forward_context(
-            mg_modules, torch_dtype, 'cuda', share_embedding=share_embedding, target_device=mg_device):
-        mg_logits = forward_step_helper(mg_model, mg_inputs, dtype=torch_dtype)
+            mg_modules, test_convert_dtype, 'cuda', share_embedding=share_embedding, target_device=mg_device):
+        mg_logits = forward_step_helper(args, mg_model, mg_inputs, dtype=test_convert_dtype)
     if args.tensor_model_parallel_size > 1 and args.task_type != 'seq_cls':
         from megatron.core.tensor_parallel.mappings import gather_from_tensor_model_parallel_region
         if mg_logits is not None:
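With the new optional keyword, callers can pass the dtype explicitly (as the updated call sites above do) or rely on the fallback: args.test_convert_dtype when present, otherwise torch.float32. A small self-contained sketch of just that resolution logic; _resolve_convert_dtype and the SimpleNamespace stand-ins are illustrative, not repo code:

import torch
from types import SimpleNamespace

def _resolve_convert_dtype(args, test_convert_dtype=None):
    # Mirrors the fallback added at the top of test_convert_precision.
    if test_convert_dtype is None:
        test_convert_dtype = getattr(args, 'test_convert_dtype', torch.float32)
    return test_convert_dtype

print(_resolve_convert_dtype(SimpleNamespace()))                                    # torch.float32
print(_resolve_convert_dtype(SimpleNamespace(test_convert_dtype=torch.bfloat16)))   # torch.bfloat16
print(_resolve_convert_dtype(SimpleNamespace(), test_convert_dtype=torch.float16))  # torch.float16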
