
Commit 45ee0a4

yuzhongw-nvidia authored and Phlip79 committed
[main] feat(moe): Support moe shared expert gate for Qwen3-Next (2/4) (NVIDIA#2751)
Co-authored-by: Philip Petrakian <pgpetrak@gmail.com>
1 parent 11de188 · commit 45ee0a4

File tree

5 files changed (+18, -4 lines)


megatron/core/models/gpt/moe_module_specs.py

Lines changed: 1 addition & 1 deletion
@@ -57,7 +57,7 @@ def get_moe_module_spec_for_backend(
     experts = ModuleSpec(module=expert_module, submodules=expert_submodule)
 
     # shared experts spec
-    shared_experts = ModuleSpec(module=SharedExpertMLP, params={"gate": False}, submodules=mlp)
+    shared_experts = ModuleSpec(module=SharedExpertMLP, submodules=mlp)
 
     # MoE module spec
     moe_module_spec = ModuleSpec(
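
The removed params={"gate": False} pinned the gate off at the spec level; with it gone, whether the shared expert builds a gate is decided by the caller when the module is built (see the MoE layer change below). A minimal sketch of that general pattern, using simplified stand-ins for ModuleSpec and build_module (the real megatron.core helpers handle submodules and more; the names here are illustrative, not the actual API):

    from dataclasses import dataclass, field

    @dataclass
    class SpecSketch:
        # Simplified stand-in for megatron.core's ModuleSpec.
        module: type
        params: dict = field(default_factory=dict)

    def build_module_sketch(spec, **kwargs):
        # Spec-level params and call-site kwargs are merged into the constructor;
        # since "gate" is no longer pinned in params, the call site controls it.
        return spec.module(**spec.params, **kwargs)

    class DemoSharedExpert:
        def __init__(self, gate=False):
            self.gate = gate

    spec = SpecSketch(module=DemoSharedExpert)
    expert = build_module_sketch(spec, gate=True)  # gate chosen at build time
    assert expert.gate is True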

megatron/core/transformer/moe/moe_layer.py

Lines changed: 4 additions & 1 deletion
@@ -189,7 +189,10 @@ def __init__(
         # Initialize shared experts
         if self.use_shared_expert:
             self.shared_experts = build_module(
-                self.submodules.shared_experts, config=self.config, pg_collection=pg_collection
+                self.submodules.shared_experts,
+                config=self.config,
+                pg_collection=pg_collection,
+                gate=self.config.moe_shared_expert_gate,
             )
             if self.shared_expert_overlap:
                 self.token_dispatcher.set_shared_experts(self.shared_experts)
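
The new gate kwarg forwards config.moe_shared_expert_gate into SharedExpertMLP at build time. The diff does not show how SharedExpertMLP uses it; the sketch below is an assumption of a typical Qwen-style shared-expert gate, consistent with the +12 parameter delta in the test further down (one (1, hidden_size) weight producing a per-token sigmoid scale). The class and attribute names are hypothetical, not the actual SharedExpertMLP implementation:

    import torch

    class GatedSharedExpertSketch(torch.nn.Module):
        """Hypothetical gated shared expert: MLP output scaled by a sigmoid gate."""

        def __init__(self, hidden_size: int, mlp: torch.nn.Module, gate: bool = False):
            super().__init__()
            self.mlp = mlp
            self.use_gate = gate
            if gate:
                # One gate logit per token from a (1, hidden_size) weight,
                # i.e. hidden_size extra parameters.
                self.gate_weight = torch.nn.Parameter(torch.empty(1, hidden_size))
                torch.nn.init.normal_(self.gate_weight, mean=0.0, std=0.02)

        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            output = self.mlp(hidden_states)
            if self.use_gate:
                gate_score = torch.sigmoid(
                    torch.nn.functional.linear(hidden_states, self.gate_weight)
                )
                output = output * gate_score  # broadcast (..., 1) over the hidden dim
            return output

    # Quick shape check with toy sizes.
    mlp = torch.nn.Sequential(torch.nn.Linear(12, 32), torch.nn.GELU(), torch.nn.Linear(32, 12))
    layer = GatedSharedExpertSketch(hidden_size=12, mlp=mlp, gate=True)
    assert layer(torch.randn(4, 12)).shape == (4, 12)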

megatron/core/transformer/transformer_config.py

Lines changed: 4 additions & 0 deletions
@@ -460,6 +460,10 @@ class TransformerConfig(ModelParallelConfig):
     different orders to the hidden_states, causing minor numerical differences
     in the hidden_states gradient."""
 
+    moe_shared_expert_gate: bool = False
+    """Enable gate for shared expert. Only effective when
+    moe-shared-expert-intermediate-size is set."""
+
     moe_shared_expert_overlap: bool = False
     """Enable overlapping between shared expert computations and dispatcher communications.
     Without this, the shared experts execute before the router."""
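
A hedged usage sketch of the new config field. The MoE-related values mirror the unit test below where visible; the remaining fields are illustrative placeholders, and the flag is only meaningful together with moe_shared_expert_intermediate_size:

    from megatron.core.transformer.transformer_config import TransformerConfig

    config = TransformerConfig(
        num_layers=1,                             # placeholder values for illustration
        hidden_size=12,
        num_attention_heads=4,
        num_moe_experts=4,
        moe_router_load_balancing_type="sinkhorn",
        moe_router_topk=1,
        add_bias_linear=False,
        moe_shared_expert_intermediate_size=32,   # required for the gate to take effect
        moe_shared_expert_gate=True,              # the new field added in this commit
    )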

megatron/training/arguments.py

Lines changed: 2 additions & 0 deletions
@@ -3153,6 +3153,8 @@ def _add_moe_args(parser):
                             'This makes the gradients from the router and the shared experts added in '
                             'different orders to the hidden_states, causing minor numerical differences '
                             'in the hidden_states gradient.')
+    group.add_argument('--moe-shared-expert-gate', action='store_true',
+                       help='Enable gate for shared expert. Only effective when moe-shared-expert-intermediate-size is set.')
     group.add_argument('--moe-shared-expert-overlap', action='store_true',
                        help='Enable overlapping between shared expert computations and dispatcher communications. '
                             'Without this, the shared experts execute before the router. '
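
For illustration, a small argparse sketch showing how the new flag parses from the command line; only the two MoE flags are taken from this diff, the surrounding scaffolding (parser setup, option types, example values) is an assumption, and the real training script wires the parsed namespace into TransformerConfig:

    import argparse

    parser = argparse.ArgumentParser()
    group = parser.add_argument_group('moe')
    group.add_argument('--moe-shared-expert-intermediate-size', type=int, default=None)
    group.add_argument('--moe-shared-expert-gate', action='store_true')

    args = parser.parse_args(
        ['--moe-shared-expert-intermediate-size', '512', '--moe-shared-expert-gate']
    )
    assert args.moe_shared_expert_gate is True        # store_true flag, default False
    assert args.moe_shared_expert_intermediate_size == 512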

tests/unit_tests/transformer/moe/test_shared_experts.py

Lines changed: 7 additions & 2 deletions
@@ -20,7 +20,8 @@ def teardown_method(self, method):
 
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
     @pytest.mark.internal
-    def test_gpu_forward(self):
+    @pytest.mark.parametrize("shared_expert_gate", [False, True])
+    def test_gpu_forward(self, shared_expert_gate):
         Utils.initialize_model_parallel(1, 1)
         model_parallel_cuda_manual_seed(123)
         print("done intializing")
@@ -38,6 +39,7 @@ def test_gpu_forward(self):
             moe_router_load_balancing_type="sinkhorn",
             moe_router_topk=1,
             add_bias_linear=False,
+            moe_shared_expert_gate=shared_expert_gate,
         )
         transformer_layer_spec = get_gpt_layer_local_spec(
             num_experts=num_moe_experts, moe_grouped_gemm=False
@@ -49,7 +51,10 @@ def test_gpu_forward(self):
         assert isinstance(self.moe_layer, MoELayer)
 
         num_weights = sum([p.numel() for p in self.moe_layer.parameters()])
-        assert num_weights == 3480 + 1152
+        if shared_expert_gate:
+            assert num_weights == 3480 + 1152 + 12  # 12 is the weight of the gate
+        else:
+            assert num_weights == 3480 + 1152
         assert self.moe_layer.shared_experts is not None
         assert self.moe_layer.shared_experts.stream is None
         assert self.moe_layer.token_dispatcher.shared_experts is None
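
The +12 in the gated branch is consistent with a single (1, hidden_size) gate weight and hidden_size=12 in this test (an assumption; hidden_size is set outside the visible hunk):

    hidden_size = 12                       # assumed test value, not shown in this hunk
    gate_params = 1 * hidden_size          # one (1, hidden_size) gate weight
    assert 3480 + 1152 + gate_params == 4644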
