
Commit 873234c: "Code clean."
1 parent: df3a651

2 files changed: +73, -19 lines

megatron/core/transformer/moe/shared_experts.py (65 additions, 10 deletions)
@@ -2,11 +2,23 @@
 
 import warnings
 from copy import deepcopy
+from enum import Enum
+from functools import wraps
 from typing import Optional
 
 import torch
 import torch.nn.functional as F
 
+
+class SharedExpertState(Enum):
+    """State machine states for SharedExpertMLP overlapped forward pass."""
+
+    IDLE = 0
+    PRE_FORWARD_COMM_DONE = 1
+    FC1_FORWARD_DONE = 2
+    FC2_FORWARD_DONE = 3
+    POST_FORWARD_COMM_DONE = 4
+
 from megatron.core.dist_checkpointing.mapping import ShardedStateDict
 from megatron.core.fusions.fused_bias_geglu import bias_geglu_impl
 from megatron.core.fusions.fused_bias_gelu import bias_gelu_impl
@@ -27,6 +39,41 @@
 )
 
 
+def overlap_state_check(
+    required_state: "SharedExpertState",
+    next_state: "SharedExpertState",
+):
+    """
+    Decorator to validate overlap state and cached variables before method execution,
+    and update state after method execution.
+
+    Args:
+        required_state: The expected SharedExpertState before this method runs.
+        next_state: The SharedExpertState to transition to after method execution.
+    """
+
+    def decorator(method):
+        @wraps(method)
+        def wrapper(self, *args, **kwargs):
+            # Check overlap is enabled
+            assert self.config.moe_shared_expert_overlap, (
+                f"{method.__name__} requires --moe-shared-expert-overlap to be set"
+            )
+            # Check state machine
+            assert self._overlap_state == required_state, (
+                f"{method.__name__} must be called from {required_state.name} state, "
+                f"but current state is {self._overlap_state.name}"
+            )
+            # Execute method
+            result = method(self, *args, **kwargs)
+            # Update state after method execution
+            self._overlap_state = next_state
+            return result
+
+        return wrapper
+
+    return decorator
+
+
 class _BackwardStreamWait(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input, stream):
@@ -131,6 +178,9 @@ def __init__(
         self.cached_output = None
         self.gate_score = None
 
+        # State machine to ensure correct calling order of overlapped forward methods
+        self._overlap_state = SharedExpertState.IDLE
+
         if SharedExpertMLP.stream is None:
             SharedExpertMLP.stream = torch.cuda.Stream()
 
@@ -163,14 +213,15 @@ def wait_current_stream(self):
        """Wait for the current stream to complete."""
        self.stream.wait_stream(torch.cuda.current_stream())
 
+    @overlap_state_check(
+        SharedExpertState.IDLE, SharedExpertState.PRE_FORWARD_COMM_DONE,
+    )
     def pre_forward_comm(self, input, wait_current_stream=True):
         """
         All Gather for SP before forward.
         This function is used to overlap shared experts with the dispatcher.
         It is only useful when --moe-shared-expert-overlap is set and may be changed.
         """
-        assert self.config.moe_shared_expert_overlap
-        assert self.cached_output is None
         if wait_current_stream:
             self.wait_current_stream()
         with torch.cuda.stream(self.stream):
@@ -185,14 +236,15 @@ def pre_forward_comm(self, input, wait_current_stream=True):
                 self.cached_fc1_input = copy_to_tensor_model_parallel_region(input)
             set_tensor_grad_fn_sequence_sr(self.cached_fc1_input, torch.iinfo(torch.int).max)
 
+    @overlap_state_check(
+        SharedExpertState.PRE_FORWARD_COMM_DONE, SharedExpertState.FC1_FORWARD_DONE,
+    )
     def linear_fc1_forward_and_act(self, overlapped_comm_output=None):
         """
         Do Linear FC1 and activation function forward.
         This function is used to overlap shared experts with the dispatcher.
         It is only useful when --moe-shared-expert-overlap is set and may be changed.
         """
-        assert self.config.moe_shared_expert_overlap
-        assert self.cached_fc1_input is not None
         with torch.cuda.stream(self.stream):
             # [s, b, 4 * h/p]
             intermediate_parallel, bias_parallel = self.linear_fc1(self.cached_fc1_input)
@@ -242,29 +294,31 @@ def glu(x):
             # Make sure the shared expert fc1 backward is launched after the routed fc1 backward
             self.cached_fc2_input = _BackwardStreamWait.apply(intermediate_parallel, self.stream)
 
+    @overlap_state_check(
+        SharedExpertState.FC1_FORWARD_DONE, SharedExpertState.FC2_FORWARD_DONE,
+    )
     def linear_fc2_forward(self, overlapped_comm_output=None):
         """
         Do Linear FC2 forward.
         This function is used to overlap shared experts with the dispatcher.
         It is only useful when --moe-shared-expert-overlap is set and may be changed.
         """
-        assert self.config.moe_shared_expert_overlap
-        assert self.cached_fc2_input is not None
         if overlapped_comm_output is not None:
             set_tensor_grad_fn_sequence_sr(overlapped_comm_output, torch.iinfo(torch.int).max)
         with torch.cuda.stream(self.stream):
             # [s, b, h]
             self.cached_fc2_output, _ = self.linear_fc2(self.cached_fc2_input)
             self.cached_fc2_input = None
 
+    @overlap_state_check(
+        SharedExpertState.FC2_FORWARD_DONE, SharedExpertState.POST_FORWARD_COMM_DONE,
+    )
     def post_forward_comm(self):
         """
         Reduce scatter for SP after forward.
         This function is used to overlap shared experts with the dispatcher.
         It is only useful when --moe-shared-expert-overlap is set and may be changed.
         """
-        assert self.config.moe_shared_expert_overlap
-        assert self.cached_fc2_output is not None
        with torch.cuda.stream(self.stream):
             if self.config.sequence_parallel:
                 self.cached_output = reduce_scatter_to_sequence_parallel_region(
@@ -277,14 +331,15 @@ def post_forward_comm(self):
             self.cached_fc2_output = None
             set_tensor_grad_fn_sequence_sr(self.cached_output, torch.iinfo(torch.int).max)
 
+    @overlap_state_check(
+        SharedExpertState.POST_FORWARD_COMM_DONE, SharedExpertState.IDLE,
+    )
     def get_output(self):
         """
         Gets the module forward output.
         This function is used to overlap shared experts with the dispatcher.
         It is only useful when --moe-shared-expert-overlap is set and may be changed.
         """
-        assert self.config.moe_shared_expert_overlap
-        assert self.cached_output is not None
         with torch.cuda.stream(self.stream):
             if self.use_shared_expert_gate:
                 assert self.gate_score is not None
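
To see the decorator pattern above in isolation, here is a minimal standalone sketch of the same idea. The names State, Pipeline, step_a, and step_b are toy stand-ins, not Megatron code; the point is only that calling a method out of the expected order trips an assertion instead of proceeding silently.

from enum import Enum
from functools import wraps


class State(Enum):
    IDLE = 0
    A_DONE = 1
    B_DONE = 2


def state_check(required_state, next_state):
    """Assert the object is in required_state, run the method, then move it to next_state."""

    def decorator(method):
        @wraps(method)
        def wrapper(self, *args, **kwargs):
            assert self.enabled, f"{method.__name__} requires overlap to be enabled"
            assert self._state == required_state, (
                f"{method.__name__} must be called from {required_state.name} state, "
                f"but current state is {self._state.name}"
            )
            result = method(self, *args, **kwargs)
            self._state = next_state
            return result

        return wrapper

    return decorator


class Pipeline:
    def __init__(self):
        self.enabled = True
        self._state = State.IDLE

    @state_check(State.IDLE, State.A_DONE)
    def step_a(self):
        return "a"

    @state_check(State.A_DONE, State.B_DONE)
    def step_b(self):
        return "b"


p = Pipeline()
p.step_a()
p.step_b()    # valid transition: A_DONE -> B_DONE
# p.step_a()  # would raise AssertionError: step_a must be called from IDLE state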

megatron/core/transformer/moe/token_dispatcher.py (8 additions, 9 deletions)
@@ -624,7 +624,7 @@ def token_dispatch(self, permutated_local_input_tokens, permuted_probs):
         Returns:
             A tuple of tokens and probabilities after All-to-All.
         """
-        # Make sure the shared experts fc1 is not launched before dispatch.
+        # Make sure the shared experts fc1 is not launched too early when CUDA_DEVICE_MAX_CONNECTIONS>1.
         if self.shared_experts is not None:
             self.shared_experts.wait_current_stream()
         # Perform expert parallel AlltoAll communication
@@ -784,7 +784,8 @@ def token_combine(
         Returns:
             Tokens after the All-to-All communication for combining.
         """
-        # Make sure the shared experts fc2 is not launched before combine.
+        # Make sure the shared experts fc2 is not overlapped with routed experts fc1
+        # when CUDA_DEVICE_MAX_CONNECTIONS>1.
         if self.shared_experts is not None:
             self.shared_experts.wait_current_stream()
         # Perform expert parallel AlltoAll communication
@@ -796,6 +797,9 @@ def token_combine(
             self.output_splits,
             use_nccl_stream=True,
         )
+        if self.shared_experts is not None:
+            self.shared_experts.linear_fc2_forward(permutated_local_input_tokens)
+            self.shared_experts.post_forward_comm()
         return permutated_local_input_tokens
 
     def combine_postprocess(self, permutated_local_input_tokens):
@@ -811,9 +815,6 @@ def combine_postprocess(self, permutated_local_input_tokens):
         Returns:
             The final MoE layer output reshaped to its original dimensions.
         """
-        if self.shared_experts is not None:
-            self.shared_experts.linear_fc2_forward(permutated_local_input_tokens)
-            self.shared_experts.post_forward_comm()
 
         # Unpermutation 1: AlltoAll output to output
         output = unpermute(
@@ -1418,8 +1419,6 @@ def dispatch_preprocess(
         # Initialize metadata
         routing_map, probs = self._initialize_metadata(routing_map, probs)
 
-        if self.shared_experts is not None:
-            self.shared_experts.wait_current_stream()
         self._comm_manager.setup_metadata(routing_map, probs)
         return hidden_states, self._comm_manager.token_probs
 
@@ -1447,7 +1446,6 @@ def token_dispatch(
         Returns:
             A tuple of dispatched tokens and probabilities.
         """
-        # Make sure the shared experts fc1 is not launched before dispatch.
         if self.shared_experts is not None:
             self.shared_experts.wait_current_stream()
         dispatched_hidden_states = self._comm_manager.dispatch(
@@ -1505,7 +1503,8 @@ def token_combine(
         Returns:
             Combined tokens after fused un-permutation and communication.
         """
-        # Make sure the shared experts fc2 is not launched before combine.
+        # Make sure the shared experts fc2 is not overlapped with routed experts GEMM
+        # when CUDA_DEVICE_MAX_CONNECTIONS>1.
         if self.shared_experts is not None:
             self.shared_experts.wait_current_stream()
         return self._comm_manager.combine(hidden_states, async_finish, allocate_on_comm_stream)
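
The net effect of the token_dispatcher.py changes is that the shared-expert fc2 and post-forward communication are now issued inside token_combine, right after the combine All-to-All, instead of in combine_postprocess. A rough standalone sketch of that ordering follows; all_to_all_combine and SharedExpertsStub are illustrative stand-ins, not the real dispatcher.

def all_to_all_combine(tokens):
    # Stand-in for the expert-parallel All-to-All combine communication.
    return tokens


class SharedExpertsStub:
    # Records the order in which the shared-expert steps are issued.
    def __init__(self):
        self.calls = []

    def wait_current_stream(self):
        self.calls.append("wait_current_stream")

    def linear_fc2_forward(self, tokens):
        self.calls.append("linear_fc2_forward")

    def post_forward_comm(self):
        self.calls.append("post_forward_comm")


def token_combine(tokens, shared_experts=None):
    if shared_experts is not None:
        # Keep shared-expert fc2 from overlapping the routed experts' GEMMs
        # when CUDA_DEVICE_MAX_CONNECTIONS > 1.
        shared_experts.wait_current_stream()
    tokens = all_to_all_combine(tokens)
    if shared_experts is not None:
        # With this commit, fc2 and the post-forward communication run here,
        # immediately after the combine All-to-All, not in combine_postprocess.
        shared_experts.linear_fc2_forward(tokens)
        shared_experts.post_forward_comm()
    return tokens


experts = SharedExpertsStub()
token_combine(["token"], experts)
print(experts.calls)  # ['wait_current_stream', 'linear_fc2_forward', 'post_forward_comm']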
