Skip to content

Commit 97e36aa

Browse files
authored
[Main][feat] Support overlapping A2A Combine backprop with wgrad GEMM (NVIDIA#3795)
1 parent 15f14fc commit 97e36aa

File tree

12 files changed

+419
-19
lines changed

12 files changed

+419
-19
lines changed

megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,34 @@ class TrainingState(Enum):
7373
IDLE = auto()
7474

7575

76+
def setup_delayed_wgrad_acc_hook(module, grad_acc_func):
    """Configure delayed wgrad gradient processing for MoE expert parameters.

    When ``overlap_dispatch_backward_with_experts_wgrad`` is enabled on a
    TransformerLayer, this function:
    1. Skips modules that report no pending delayed wgrad pass.
    2. Attaches a ``post_wgrad_grad_acc_hook`` callback to every parameter
       marked with ``skip_backward_post_hook`` so FSDP's gradient
       reduce-scatter runs only after the delayed wgrad computation completes.

    Args:
        module: The module being processed in the forward pre-hook. Only
            modules exposing a truthy ``need_backward_dw()`` are affected;
            all other modules are no-ops.
        grad_acc_func: The FSDP gradient processing function
            (``_process_post_backward_gradients``) to be called after the
            delayed wgrad computation finishes.
    """
    from functools import partial

    # Modules without delayed wgrad support (or with nothing pending) are no-ops.
    need_backward_dw = getattr(module, "need_backward_dw", lambda: False)
    if not need_backward_dw():
        return

    for param in module.parameters():
        if getattr(param, "skip_backward_post_hook", False):
            # Bind the single-parameter list now; the MoE layer invokes this
            # hook once the delayed wgrad for `param` has been computed.
            param.post_wgrad_grad_acc_hook = partial(grad_acc_func, [param])
76104
class MegatronFSDP(torch.nn.Module):
77105
"""Fully Sharded Data Parallel training.
78106
@@ -662,6 +690,23 @@ def _process_post_backward_gradients(param_list):
662690
"""
663691
# Filter out shared parameters whose gradients are handled by the root hook.
664692
param_list = [p for p in param_list if not getattr(p, "_is_shared", False)]
693+
694+
# Filter out parameters whose gradient processing is deferred to a delayed
695+
# wgrad accumulation hook (post_wgrad_grad_acc_hook). If skip_backward_post_hook
696+
# is set but the delayed hook was never installed, process the parameter
697+
# immediately as a safety fallback to avoid silently dropping gradients.
698+
param_list = [
699+
p
700+
for p in param_list
701+
if not (
702+
getattr(p, 'skip_backward_post_hook', False)
703+
and hasattr(p, 'post_wgrad_grad_acc_hook')
704+
)
705+
]
706+
707+
if not param_list:
708+
return
709+
665710
for param in param_list:
666711
_grad_acc(param)
667712

@@ -728,6 +773,7 @@ def _pre_forward_param_unshard(
728773
prefetch=fsdp_forward_prefetch,
729774
prefetch_order=PrefetchOrder.FORWARD_PASS_ORDER,
730775
)
776+
731777
return args, kwargs
732778

733779
@torch.compiler.disable
@@ -983,6 +1029,8 @@ def _register_pre_backward_param_unshard_hook(module):
9831029

9841030
fsdp_modules = []
9851031
for name, module in root_module.named_modules():
1032+
# Set post backward hook for TE grouped gemm if enabled comm overlap
1033+
setup_delayed_wgrad_acc_hook(module, _process_post_backward_gradients)
9861034
if self.enable_fine_grained_param_gather_hook:
9871035
_register_pre_forward_param_unshard_hook(module)
9881036
_register_pre_backward_param_unshard_hook(module)

megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2632,6 +2632,13 @@ def _reset_parameters(self, old_params, new_params):
26322632
if getattr(old_param, tp_attr, None) is not None:
26332633
setattr(new_param, tp_attr, getattr(old_param, tp_attr))
26342634

2635+
# For FSDP with delayed_wgrad_compute, `skip_backward_post_hook` needs
2636+
# to be reset on new param for correct grad accumulation of wgrad computation.
2637+
setattr(
2638+
new_param,
2639+
'skip_backward_post_hook',
2640+
getattr(old_param, 'skip_backward_post_hook', False),
2641+
)
26352642
for item_id, p in enumerate(self.params):
26362643
if p in param_map:
26372644
new_p = param_map[p]

megatron/core/extensions/transformer_engine.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1707,10 +1707,14 @@ def __init__(
17071707
self.disable_parameter_transpose_cache = self.config.disable_parameter_transpose_cache
17081708

17091709
extra_kwargs = _get_extra_te_kwargs(config)
1710+
self.delay_wgrad_compute = (
1711+
self.config.delay_wgrad_compute
1712+
or self.config.overlap_dispatch_backward_with_experts_wgrad
1713+
)
17101714

1711-
if self.config.delay_wgrad_compute:
1715+
if self.delay_wgrad_compute:
17121716
if is_te_min_version("2.3.0"):
1713-
extra_kwargs["delay_wgrad_compute"] = self.config.delay_wgrad_compute
1717+
extra_kwargs["delay_wgrad_compute"] = True
17141718
else:
17151719
raise RuntimeError(
17161720
"Only TE with version >=2.3.0 supports delay_wgrad_compute now."
@@ -2040,7 +2044,7 @@ def backward_dw(self):
20402044
Compute weight gradients during the backward pass
20412045
if delay_wgrad_compute is enabled.
20422046
"""
2043-
if self.config.delay_wgrad_compute:
2047+
if self.delay_wgrad_compute:
20442048
super().backward_dw()
20452049

20462050
class TEColumnParallelGroupedLinear(TEGroupedLinear):

megatron/core/model_parallel_config.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,15 @@ class ModelParallelConfig:
261261
delay_wgrad_compute: bool = False
262262
"""Delay the weight gradient computation to improve batch-level communication overlapping"""
263263

264+
overlap_dispatch_backward_with_experts_wgrad: bool = False
265+
"""Delay the weight gradient computation for TE Grouped GEMM MoE experts.
266+
When enabled with FSDP, the expert weight gradients are computed on a separate
267+
CUDA stream after the data gradients finish, allowing overlap of wgrad compute
268+
with EP A2A communication. The FSDP gradient reduce-scatter for
269+
expert parameters is deferred until the delayed wgrad computation completes.
270+
This requires transformer_engine with GroupedLinear support (TE >= 2.3.0).
271+
"""
272+
264273
ep_overlap_early_attn_memory_release: bool = False
265274
"""Enable early memory release of attention activations during EP overlap.
266275
EP overlap can increase peak memory usage when the overlapped forward module allocates

megatron/core/transformer/moe/moe_layer.py

Lines changed: 86 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,9 @@ def __init__(
345345
self.cudagraph_tensor_store = MoECudaGraphTensorStore()
346346
self.fwd_execution_map = ["route", "expert_compute", "postprocess"]
347347

348+
# Setup events and streams for delayed wgrad computation.
349+
self.setup_delayed_wgrad_for_dispatch_backward_overlap()
350+
348351
def _setup_inference_mode(self, pg_collection):
349352
"""Set up inference-optimized token dispatcher and state.
350353
@@ -365,6 +368,16 @@ def _setup_inference_mode(self, pg_collection):
365368
pg_collection=pg_collection,
366369
)
367370

371+
def setup_delayed_wgrad_for_dispatch_backward_overlap(self):
    """Create the CUDA event/stream pair used to overlap expert weight
    gradient computation with the dispatch (A2A) backward pass.

    Both handles stay ``None`` unless
    ``overlap_dispatch_backward_with_experts_wgrad`` is enabled in the config.
    """
    # Default to "feature disabled": no event, no side stream.
    self._delayed_wgrad_event: Optional[torch.cuda.Event] = None
    self._delayed_wgrad_stream: Optional[torch.cuda.Stream] = None
    if not self.config.overlap_dispatch_backward_with_experts_wgrad:
        return
    self._delayed_wgrad_event = torch.cuda.Event()
    self._delayed_wgrad_stream = torch.cuda.Stream(device="cuda")
368381
def set_inference_cuda_graphed_iteration(self):
369382
"""Enable CUDA-graphed iteration mode on this layer, its router, and its experts.
370383
@@ -435,6 +448,8 @@ def dispatch(self, hidden_states: torch.Tensor, probs: torch.Tensor):
435448
tokens and their associated probabilities to the devices hosting their assigned
436449
experts.
437450
"""
451+
if self.config.overlap_dispatch_backward_with_experts_wgrad:
452+
hidden_states = _RegisterDelayedWgradForExperts.apply(self, hidden_states)
438453
return self.token_dispatcher.token_dispatch(hidden_states, probs)
439454

440455
@maybe_skip_or_early_return_by_cudagraph("shared_experts_compute")
@@ -473,6 +488,10 @@ def routed_experts_compute(self, hidden_states: torch.Tensor, probs: torch.Tenso
473488
for each expert. It then passes the tokens through the local experts.
474489
The output from the experts is preprocessed for the combine step.
475490
"""
491+
if self.config.overlap_dispatch_backward_with_experts_wgrad:
492+
hidden_states = _RecordExpertDgradCompletion.apply(
493+
self._delayed_wgrad_event, hidden_states
494+
)
476495
dispatched_input, tokens_per_expert, permuted_probs = (
477496
self.token_dispatcher.dispatch_postprocess(hidden_states, probs)
478497
)
@@ -618,24 +637,24 @@ def custom_forward(hidden_states, intermediate_tensors=None, padding_mask=None):
618637

619638
def backward_dw(self, routed_experts: bool = True, shared_experts: bool = False):
    """Compute weight gradients for experts and shared experts.

    Args:
        routed_experts: Run the delayed wgrad pass of the routed experts
            (and, under EP comm overlap with MoE latent projections, of
            ``fc2_latent_proj`` on the communication stream).
        shared_experts: Run the delayed wgrad pass of the shared experts
            (and, under EP comm overlap, of ``fc1_latent_proj``).
    """
    # TODO(Wohox): replace the "routed_experts" and "shared_experts" arguments with better
    # naming to better explain that they are actually from different fine-grained callables,
    # or use scanning to decide which backward_dw should be called.
    if routed_experts:
        self.experts.backward_dw()
        if self.config.moe_latent_size and self.config.overlap_moe_expert_parallel_comm:
            # TODO(Wohox): fc2_latent_proj forward and backward are executed in comm stream,
            # so we execute its backward_dw in the comm stream too. But this may harm the
            # EP overlap performance. Better to check if there is a better way to handle this.
            # Import lazily: only the EP-overlap latent-proj path needs the comm stream,
            # so non-overlap calls never pay for (or depend on) this import.
            from megatron.core.pipeline_parallel.utils import get_comm_stream

            comm_stream = get_comm_stream()
            with torch.cuda.stream(comm_stream):
                self.fc2_latent_proj.backward_dw()
    if shared_experts:
        if self.use_shared_expert and not self.shared_expert_overlap:
            self.shared_experts.backward_dw()
        if self.config.moe_latent_size and self.config.overlap_moe_expert_parallel_comm:
            self.fc1_latent_proj.backward_dw()

641660
def set_for_recompute_pre_mlp_layernorm(self):
@@ -646,3 +665,66 @@ def set_for_recompute_pre_mlp_layernorm(self):
646665
from megatron.core.extensions.transformer_engine import set_save_original_input
647666

648667
set_save_original_input(self.shared_experts.linear_fc1)
668+
669+
670+
class _RecordExpertDgradCompletion(torch.autograd.Function):
671+
"""Autograd function that records a CUDA event when expert data gradients finish.
672+
673+
Placed in the forward graph just before the expert computation so that during
674+
the backward pass, when the expert dgrad completes, we record an event. The
675+
subsequent ``_RegisterDelayedWgradForExperts`` waits on this event before
676+
launching the delayed wgrad computation on a separate CUDA stream.
677+
"""
678+
679+
@staticmethod
680+
def forward(ctx, event: torch.cuda.Event, *inputs):
681+
"""Forward pass that stores the event and passes through inputs unchanged."""
682+
ctx.event = event
683+
return inputs[0] if len(inputs) == 1 else inputs
684+
685+
@staticmethod
686+
def backward(ctx, *grad_outputs):
687+
"""Backward pass that records the event when expert dgrad completes."""
688+
ctx.event.record(torch.cuda.current_stream())
689+
ctx.event = None
690+
return (None,) + grad_outputs
691+
692+
693+
class _RegisterDelayedWgradForExperts(torch.autograd.Function):
    """Autograd function that orchestrates delayed wgrad computation for MoE experts.

    Placed in the forward graph at the dispatch boundary. During the backward pass,
    this function:
    1. Makes the dedicated wgrad stream wait on the event recorded by
       ``_RecordExpertDgradCompletion`` (i.e. until the expert dgrad is done).
    2. Executes the delayed wgrad computation on that dedicated CUDA stream.
    3. Re-records the event and makes the current (backward) stream wait for
       the wgrad computation to complete.
    4. Invokes each parameter's registered gradient processing callback
       (e.g., FSDP reduce-scatter via ``post_wgrad_grad_acc_hook``).
    """

    @staticmethod
    def forward(ctx, module: MoELayer, *inputs):
        """Forward pass that stores the MoE module and passes through inputs unchanged."""
        ctx.module = module
        return inputs[0] if len(inputs) == 1 else inputs

    @staticmethod
    def backward(ctx, *grad_outputs):
        """Backward pass that executes delayed wgrad computation on a separate stream."""
        module = ctx.module
        event = module._delayed_wgrad_event
        wgrad_stream = module._delayed_wgrad_stream

        # Do not start wgrad until the expert dgrad (whose saved tensors the
        # wgrad consumes) has been recorded as complete on the backward stream.
        wgrad_stream.wait_event(event)
        with torch.cuda.stream(wgrad_stream):
            with torch.cuda.nvtx.range("delayed_expert_wgrad"):
                module.backward_dw(routed_experts=True, shared_experts=False)
        # Reuse the same event to publish wgrad completion back to the
        # backward stream before any gradient post-processing runs.
        event.record(wgrad_stream)

        torch.cuda.current_stream().wait_event(event)

        # Fire the per-parameter FSDP callbacks (installed by
        # setup_delayed_wgrad_acc_hook) now that the expert grads are final.
        for param in module.parameters():
            if getattr(param, "post_wgrad_grad_acc_hook", None) is not None:
                param.post_wgrad_grad_acc_hook()

        # Break the reference so ctx does not keep the layer alive.
        ctx.module = None
        return (None,) + grad_outputs

megatron/core/transformer/transformer_config.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2148,6 +2148,19 @@ def __post_init__(self):
21482148
'partial cuda graph'
21492149
)
21502150

2151+
if self.overlap_dispatch_backward_with_experts_wgrad:
2152+
assert not self.overlap_moe_expert_parallel_comm, (
2153+
'overlap_moe_expert_parallel_comm must be disabled when enabling '
2154+
'overlap_dispatch_backward_with_experts_wgrad.'
2155+
)
2156+
assert is_te_min_version(
2157+
"2.3.0"
2158+
), 'TE version >= 2.3.0 is required for overlap_dispatch_backward_with_experts_wgrad'
2159+
assert not self.delay_wgrad_compute, (
2160+
'delay_wgrad_compute and overlap_dispatch_backward_with_experts_wgrad '
2161+
'are mutually exclusive; use only one'
2162+
)
2163+
21512164
if self.ep_overlap_early_attn_memory_release:
21522165
assert self.overlap_moe_expert_parallel_comm, (
21532166
'overlap_moe_expert_parallel_comm must be enabled when enabling '

0 commit comments

Comments
 (0)