 import warnings
 from copy import deepcopy
+from enum import Enum
+from functools import wraps
 from typing import Optional

 import torch
 import torch.nn.functional as F

+
+class SharedExpertState(Enum):
+    """State machine states for SharedExpertMLP overlapped forward pass."""
+
+    IDLE = 0
+    PRE_FORWARD_COMM_DONE = 1
+    FC1_FORWARD_DONE = 2
+    FC2_FORWARD_DONE = 3
+    POST_FORWARD_COMM_DONE = 4
+
 from megatron.core.dist_checkpointing.mapping import ShardedStateDict
 from megatron.core.fusions.fused_bias_geglu import bias_geglu_impl
 from megatron.core.fusions.fused_bias_gelu import bias_gelu_impl
 )


+def overlap_state_check(
+    required_state: "SharedExpertState",
+    next_state: "SharedExpertState",
+):
+    """
+    Decorator that checks overlap is enabled and the state machine is in the required
+    state before the wrapped method runs, and advances to the next state afterwards.
+
+    Args:
+        required_state: The expected SharedExpertState before this method runs.
+        next_state: The SharedExpertState to transition to after method execution.
+    """
+
+    def decorator(method):
+        @wraps(method)
+        def wrapper(self, *args, **kwargs):
+            # Check that shared-expert overlap is enabled.
+            assert self.config.moe_shared_expert_overlap, (
+                f"{method.__name__} requires --moe-shared-expert-overlap to be set"
+            )
+            # Check the state machine.
+            assert self._overlap_state == required_state, (
+                f"{method.__name__} must be called from the {required_state.name} state, "
+                f"but the current state is {self._overlap_state.name}"
+            )
+            # Execute the wrapped method.
+            result = method(self, *args, **kwargs)
+            # Transition to the next state.
+            self._overlap_state = next_state
+            return result
+
+        return wrapper
+
+    return decorator
+
 class _BackwardStreamWait(torch.autograd.Function):
     @staticmethod
     def forward(ctx, input, stream):
@@ -131,6 +178,9 @@ def __init__(
         self.cached_output = None
         self.gate_score = None

+        # State machine to ensure correct calling order of overlapped forward methods
+        self._overlap_state = SharedExpertState.IDLE
+
         if SharedExpertMLP.stream is None:
             SharedExpertMLP.stream = torch.cuda.Stream()

@@ -163,14 +213,15 @@ def wait_current_stream(self):
         """Wait for the current stream to complete."""
         self.stream.wait_stream(torch.cuda.current_stream())

+    @overlap_state_check(
+        SharedExpertState.IDLE, SharedExpertState.PRE_FORWARD_COMM_DONE,
+    )
     def pre_forward_comm(self, input, wait_current_stream=True):
         """
         All Gather for SP before forward.
         This function is used to overlap shared experts with the dispatcher.
         It is only useful when --moe-shared-expert-overlap is set and may be changed.
         """
-        assert self.config.moe_shared_expert_overlap
-        assert self.cached_output is None
         if wait_current_stream:
             self.wait_current_stream()
         with torch.cuda.stream(self.stream):
@@ -185,14 +236,15 @@ def pre_forward_comm(self, input, wait_current_stream=True):
                 self.cached_fc1_input = copy_to_tensor_model_parallel_region(input)
             set_tensor_grad_fn_sequence_sr(self.cached_fc1_input, torch.iinfo(torch.int).max)

+    @overlap_state_check(
+        SharedExpertState.PRE_FORWARD_COMM_DONE, SharedExpertState.FC1_FORWARD_DONE,
+    )
     def linear_fc1_forward_and_act(self, overlapped_comm_output=None):
         """
         Do Linear FC1 and activation function forward.
         This function is used to overlap shared experts with the dispatcher.
         It is only useful when --moe-shared-expert-overlap is set and may be changed.
         """
-        assert self.config.moe_shared_expert_overlap
-        assert self.cached_fc1_input is not None
         with torch.cuda.stream(self.stream):
             # [s, b, 4 * h/p]
             intermediate_parallel, bias_parallel = self.linear_fc1(self.cached_fc1_input)
@@ -242,29 +294,31 @@ def glu(x):
             # Make sure the shared expert fc1 backward is launched after the routed fc1 backward
             self.cached_fc2_input = _BackwardStreamWait.apply(intermediate_parallel, self.stream)

+    @overlap_state_check(
+        SharedExpertState.FC1_FORWARD_DONE, SharedExpertState.FC2_FORWARD_DONE,
+    )
     def linear_fc2_forward(self, overlapped_comm_output=None):
         """
         Do Linear FC2 forward.
         This function is used to overlap shared experts with the dispatcher.
         It is only useful when --moe-shared-expert-overlap is set and may be changed.
         """
-        assert self.config.moe_shared_expert_overlap
-        assert self.cached_fc2_input is not None
         if overlapped_comm_output is not None:
             set_tensor_grad_fn_sequence_sr(overlapped_comm_output, torch.iinfo(torch.int).max)
         with torch.cuda.stream(self.stream):
             # [s, b, h]
             self.cached_fc2_output, _ = self.linear_fc2(self.cached_fc2_input)
             self.cached_fc2_input = None

+    @overlap_state_check(
+        SharedExpertState.FC2_FORWARD_DONE, SharedExpertState.POST_FORWARD_COMM_DONE,
+    )
     def post_forward_comm(self):
         """
         Reduce scatter for SP after forward.
         This function is used to overlap shared experts with the dispatcher.
         It is only useful when --moe-shared-expert-overlap is set and may be changed.
         """
-        assert self.config.moe_shared_expert_overlap
-        assert self.cached_fc2_output is not None
         with torch.cuda.stream(self.stream):
             if self.config.sequence_parallel:
                 self.cached_output = reduce_scatter_to_sequence_parallel_region(
@@ -277,14 +331,15 @@ def post_forward_comm(self):
             self.cached_fc2_output = None
             set_tensor_grad_fn_sequence_sr(self.cached_output, torch.iinfo(torch.int).max)

+    @overlap_state_check(
+        SharedExpertState.POST_FORWARD_COMM_DONE, SharedExpertState.IDLE,
+    )
     def get_output(self):
         """
         Gets the module forward output.
         This function is used to overlap shared experts with the dispatcher.
         It is only useful when --moe-shared-expert-overlap is set and may be changed.
         """
-        assert self.config.moe_shared_expert_overlap
-        assert self.cached_output is not None
         with torch.cuda.stream(self.stream):
             if self.use_shared_expert_gate:
                 assert self.gate_score is not None
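
For reference, below is a minimal sketch of the call order the new state machine enforces during one overlapped forward pass. It is not part of the diff: the function name, shared_mlp, and hidden_states are placeholders, shared_mlp is assumed to be a SharedExpertMLP built with moe_shared_expert_overlap enabled, and the token-dispatcher and routed-expert work that normally runs between these calls on the default stream is elided.

# Hedged usage sketch; only the SharedExpertMLP method names come from the diff above.
def overlapped_shared_expert_forward(shared_mlp, hidden_states):
    shared_mlp.pre_forward_comm(hidden_states)   # IDLE -> PRE_FORWARD_COMM_DONE
    # ... dispatcher communication / permutation runs here on the default stream ...
    shared_mlp.linear_fc1_forward_and_act()      # -> FC1_FORWARD_DONE
    # ... routed experts run here ...
    shared_mlp.linear_fc2_forward()              # -> FC2_FORWARD_DONE
    shared_mlp.post_forward_comm()               # -> POST_FORWARD_COMM_DONE
    return shared_mlp.get_output()               # -> IDLE, ready for the next pass

# Calling a step out of order now fails fast with a clear message, e.g. calling
# get_output() right after pre_forward_comm() raises:
#   AssertionError: get_output must be called from the POST_FORWARD_COMM_DONE state,
#   but the current state is PRE_FORWARD_COMM_DONE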