
Commit 7dd5f63

Set parallel config for megatron bridge (#1184)

1 parent 6dcd36e

6 files changed (+30, -13)

docker/patch/latest/megatron.patch

Lines changed: 5 additions & 7 deletions
@@ -384,16 +384,15 @@ index a8f4abfcd..f33f6f05e 100755
 
         if self.config.recompute_method == 'uniform':
 diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py
-index e2705bd9f..83a947c00 100644
+index e2705bd9f..a0aa109b5 100644
 --- a/megatron/core/transformer/transformer_config.py
 +++ b/megatron/core/transformer/transformer_config.py
-@@ -210,6 +210,10 @@ class TransformerConfig(ModelParallelConfig):
+@@ -210,6 +210,9 @@ class TransformerConfig(ModelParallelConfig):
     attention_output_gate: bool = False
     """Whether to apply output gate to the attention layers."""
 
 +    post_self_attn_layernorm: bool = False
 +    post_mlp_layernorm: bool = False
-+    use_gated_attention: bool = False
 +
     test_mode: bool = False
     """Whether to run real-time tests."""
@@ -469,21 +468,20 @@ index 3ea405770..5a42001b9 100644
     # discard the output of the pre-mlp layernorm and register the recompute
     # as a gradient hook of mlp_output_with_bias[0]
 diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py
-index b267c8a81..def4ce809 100644
+index b267c8a81..83736acdc 100644
 --- a/megatron/training/arguments.py
 +++ b/megatron/training/arguments.py
-@@ -1398,6 +1398,10 @@ def core_transformer_config_from_args(args, config_class=None):
+@@ -1398,6 +1398,9 @@ def core_transformer_config_from_args(args, config_class=None):
 
     kw_args['inference_sampling_seed'] = args.seed
 
 +    kw_args['post_self_attn_layernorm'] = args.post_self_attn_layernorm
 +    kw_args['post_mlp_layernorm'] = args.post_mlp_layernorm
-+    kw_args['use_gated_attention'] = args.use_gated_attention
 +
     # handle quantization config
     # NOTE: Kitchen arguments are only added to the namespace when
     # Kitchen library is available.
-@@ -1764,6 +1768,12 @@ def _add_network_size_args(parser):
+@@ -1764,6 +1767,12 @@ def _add_network_size_args(parser):
     action='store_true',
     help='If set, use original BERT residula connection '
          'ordering.')
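
The patch above adds two opt-in flags, post_self_attn_layernorm and post_mlp_layernorm, to Megatron's TransformerConfig and plumbs them through core_transformer_config_from_args, while dropping the earlier use_gated_attention field. A minimal PyTorch-style sketch of where such post-sublayer norms usually sit (illustration only, not the patched Megatron layer; the module names here are made up):

```python
import torch.nn as nn

class BlockSketch(nn.Module):
    """Simplified pre-norm transformer block with optional post-sublayer norms."""

    def __init__(self, hidden_size, attn, mlp,
                 post_self_attn_layernorm=False, post_mlp_layernorm=False):
        super().__init__()
        self.attn, self.mlp = attn, mlp
        self.input_norm = nn.LayerNorm(hidden_size)
        self.pre_mlp_norm = nn.LayerNorm(hidden_size)
        # The two new flags gate an extra norm on each sublayer's output.
        self.post_attn_norm = nn.LayerNorm(hidden_size) if post_self_attn_layernorm else nn.Identity()
        self.post_mlp_norm = nn.LayerNorm(hidden_size) if post_mlp_layernorm else nn.Identity()

    def forward(self, x):
        h = x + self.post_attn_norm(self.attn(self.input_norm(x)))
        return h + self.post_mlp_norm(self.mlp(self.pre_mlp_norm(h)))
```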

examples/geo3k_vlm/run_geo3k_vlm.sh

Lines changed: 2 additions & 0 deletions
@@ -80,6 +80,8 @@ fi
 # Common args
 CKPT_ARGS=(
     --hf-checkpoint /root/models/${MODEL_NAME}
+    # vl model has rotary base 5000000
+    --rotary-base 5000000
 )
 
 ROLLOUT_ARGS=(
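
The only functional change here is passing --rotary-base 5000000 for the VL checkpoint. The value matters because RoPE inverse frequencies scale as base^(-2i/d): a larger base stretches the positional wavelengths. A quick sketch of that relationship (a head_dim of 128 is just an example):

```python
import torch

def rope_inv_freq(base: float, head_dim: int) -> torch.Tensor:
    # Standard RoPE frequency schedule: base ** (-2i / head_dim) for even i.
    exponents = torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim
    return base ** -exponents

print(rope_inv_freq(10_000.0, 128)[-1])     # ~1.2e-4: lowest frequency at the common default base
print(rope_inv_freq(5_000_000.0, 128)[-1])  # ~2.6e-7: far longer wavelength at base 5e6
```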

slime/backends/megatron_utils/__init__.py

Lines changed: 12 additions & 5 deletions
@@ -21,14 +21,21 @@ def new_init(self, *args, **kwargs):
     logging.warning("deep_ep is not installed, some functionalities may be limited.")
 
 try:
-    from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.text_model import Qwen3VLTextRotaryEmbedding
+    from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.text_model import (
+        Qwen3VLMoETextRotaryEmbedding,
+        Qwen3VLTextRotaryEmbedding,
+    )
 
-    _original_forward = Qwen3VLTextRotaryEmbedding.forward
+    def patch_rotary_embedding(cls):
+        _original_forward = cls.forward
 
-    def _patched_forward(self, *args, packed_seq_params=None, **kwargs):
-        return _original_forward(self, *args, **kwargs)
+        def _patched_forward(self, *args, packed_seq_params=None, **kwargs):
+            return _original_forward(self, *args, **kwargs)
 
-    Qwen3VLTextRotaryEmbedding.forward = _patched_forward
+        cls.forward = _patched_forward
+
+    patch_rotary_embedding(Qwen3VLTextRotaryEmbedding)
+    patch_rotary_embedding(Qwen3VLMoETextRotaryEmbedding)
 except ImportError:
     pass
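
The refactor above exists because each rotary-embedding class needs its own captured _original_forward; wrapping the patch in patch_rotary_embedding(cls) gives every class a separate closure instead of sharing one. The general shape of the trick, shown on a stand-in class (the real patch declares packed_seq_params explicitly; this sketch just drops it):

```python
def drop_kwarg(cls, name):
    # Capture the original per class, so patching several classes does not
    # leave them all calling the same saved method.
    original_forward = cls.forward

    def patched_forward(self, *args, **kwargs):
        kwargs.pop(name, None)  # silently discard the unsupported keyword
        return original_forward(self, *args, **kwargs)

    cls.forward = patched_forward


class Dummy:
    def forward(self, x, scale=1):
        return x * scale


drop_kwarg(Dummy, "packed_seq_params")
print(Dummy().forward(3, packed_seq_params="ignored"))  # -> 3
```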

slime/backends/megatron_utils/model_provider.py

Lines changed: 6 additions & 0 deletions
@@ -58,6 +58,12 @@ def get_model_provider_func(
 
     bridge = AutoBridge.from_hf_pretrained(args.hf_checkpoint, trust_remote_code=True)
     provider = bridge.to_megatron_provider(load_weights=False)
+    # TODO: we should not manually set this...
+    provider.tensor_model_parallel_size = args.tensor_model_parallel_size
+    provider.pipeline_model_parallel_size = args.pipeline_model_parallel_size
+    provider.expert_model_parallel_size = args.expert_model_parallel_size
+    provider.expert_tensor_parallel_size = args.expert_tensor_parallel_size
+    provider.sequence_parallel = args.sequence_parallel
     provider.finalize()
     return provider.provide
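
These assignments are the core of the commit: the bridge provider does not pick up slime's parallel layout on its own, so the sizes are mirrored from the CLI args before provider.finalize(). A compact equivalent of the same assignments (hypothetical helper, not part of slime):

```python
PARALLEL_FIELDS = (
    "tensor_model_parallel_size",
    "pipeline_model_parallel_size",
    "expert_model_parallel_size",
    "expert_tensor_parallel_size",
    "sequence_parallel",
)

def sync_parallel_config(provider, args):
    # Copy the parallel-layout settings from the parsed args onto the provider
    # so both sides describe the same TP/PP/EP/ETP/SP layout.
    for field in PARALLEL_FIELDS:
        setattr(provider, field, getattr(args, field))
```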

slime/utils/arguments.py

Lines changed: 4 additions & 0 deletions
@@ -1590,6 +1590,10 @@ def equal(x, y):
 
     errors = []
 
+    # multimodal models have different config structure
+    if hasattr(hf_config, "text_config"):
+        hf_config = hf_config.text_config
+
     for hf_config_name, megatron_config_name, compare_fn in [
         ("hidden_size", "hidden_size", equal),
         ("num_attention_heads", "num_attention_heads", equal),

tools/convert_hf_to_torch_dist.py

Lines changed: 1 addition & 1 deletion
@@ -110,7 +110,7 @@ def main():
 
     # Load model
     hf_model_path = args.hf_checkpoint
-    bridge = AutoBridge.from_pretrained(hf_model_path, trust_remote_code=True)
+    bridge = AutoBridge.from_hf_pretrained(hf_model_path, trust_remote_code=True)
     bridge.load_weights(model, hf_model_path, memory_efficient=True)
     print(f"Model loaded: {hf_model_path}")
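
from_pretrained is replaced with from_hf_pretrained, matching the call already used in model_provider.py. Pieced together from the diffs in this commit, the bridge flow looks roughly like this (the import path and checkpoint path are assumptions, and the parallel size is an example value):

```python
from megatron.bridge import AutoBridge  # import path assumed

bridge = AutoBridge.from_hf_pretrained("/path/to/hf_checkpoint", trust_remote_code=True)

# Build a Megatron model provider and mirror the launch-time parallel layout
# onto it before finalizing, as model_provider.py now does.
provider = bridge.to_megatron_provider(load_weights=False)
provider.tensor_model_parallel_size = 2  # example value; use your launch args
provider.finalize()
model_provider_func = provider.provide
```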

0 commit comments