
Commit 65a0651

lhb8125 authored and Wohox committed
fix(moe): Support HybridEP and reduce memory overhead for 1F1B A2A overlap (NVIDIA#2236)
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>
Signed-off-by: Pingtian Li <pingtianl@nvidia.com>
Co-authored-by: Pingtian Li <pingtianl@nvidia.com>
1 parent 132c4e0 commit 65a0651

File tree: 7 files changed, +176 / -22 lines


megatron/core/model_parallel_config.py

Lines changed: 13 additions & 0 deletions
@@ -246,6 +246,19 @@ class ModelParallelConfig:
     delay_wgrad_compute: bool = False
     """Delay the weight gradient computation to improve batch-level communication overlapping"""

+    ep_overlap_early_attn_memory_release: bool = False
+    """Enable early memory release of attention activations during EP overlap.
+    EP overlap can increase peak memory usage when the overlapped forward module allocates
+    more memory than what is freed by the backward module. This flag addresses this by
+    reordering the attention backward pass to occur earlier in the schedule.
+    Specifically:
+    - Without this flag: attn_bwd executes after moe_combine_fwd
+    - With this flag: attn_bwd executes before mlp_fwd
+    The earlier execution releases attention activations sooner, reducing peak memory.
+    Note: This may impact performance as moe_combine_fwd and moe_dispatch_bwd become
+    exposed (not overlapped with other computation).
+    """
+
     ###################
     # Pipeline Parallel
     ###################
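The new docstring describes a reordering of the 1F1B sub-steps. The sketch below is a simplified, hypothetical illustration of that ordering (the step names mirror the docstring; this is not the actual scheduler interface):

    # Simplified, hypothetical illustration of the sub-step ordering described
    # in the docstring above; step names mirror the docstring, not real APIs.
    def overlapped_substep_order(early_attn_memory_release: bool) -> list:
        if early_attn_memory_release:
            # attn_bwd moves ahead of mlp_fwd, so attention activations are
            # released earlier, but moe_combine_fwd and moe_dispatch_bwd are
            # no longer overlapped with other computation.
            return ["moe_dispatch_bwd", "post_attn_bwd", "attn_bwd",
                    "mlp_fwd", "moe_combine_fwd"]
        # Default ordering: attn_bwd executes after moe_combine_fwd.
        return ["moe_dispatch_bwd", "mlp_fwd", "moe_combine_fwd",
                "post_attn_bwd", "attn_bwd"]

    print(overlapped_substep_order(early_attn_memory_release=True))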

megatron/core/models/common/model_chunk_schedule_plan.py

Lines changed: 54 additions & 9 deletions
@@ -75,6 +75,7 @@ def __init__(self, layer, event, chunk_state, comp_stream, comm_stream, extra_ar
         """
         from megatron.core.models.gpt.fine_grained_callables import TransformerLayerState

+        self.config = layer.config
         self.layer_state = TransformerLayerState()
         self.chunk_state = chunk_state
         self.layer = layer
@@ -85,6 +86,32 @@ def __init__(self, layer, event, chunk_state, comp_stream, comm_stream, extra_ar
         # get callable nodes for transformer/mtp layer
         self._build_callable_nodes(event, comp_stream, comm_stream, extra_args)

+    def release_state(self):
+        """Release reference, this helps avoid memory leak."""
+        if hasattr(self, 'attn') and self.attn is not None:
+            del self.attn
+            self.attn = None
+        if hasattr(self, 'post_attn') and self.post_attn is not None:
+            del self.post_attn
+            self.post_attn = None
+        if hasattr(self, 'moe_dispatch') and self.moe_dispatch is not None:
+            del self.moe_dispatch
+            self.moe_dispatch = None
+        if hasattr(self, 'mlp') and self.mlp is not None:
+            del self.mlp
+            self.mlp = None
+        if hasattr(self, 'moe_combine') and self.moe_combine is not None:
+            del self.moe_combine
+            self.moe_combine = None
+        if hasattr(self, 'mtp_post_process') and self.mtp_post_process is not None:
+            del self.mtp_post_process
+            self.mtp_post_process = None
+        if hasattr(self, 'layer_state') and self.layer_state is not None:
+            del self.layer_state
+            self.layer_state = None
+        if hasattr(self, 'layer'):
+            del self.layer
+
     def _build_callable_nodes(self, event, comp_stream, comm_stream, extra_args):
         """
         Builds the callable nodes for the transformer/mtp layer:
@@ -112,7 +139,12 @@ def _build_callable_nodes(self, event, comp_stream, comm_stream, extra_args):
             self.layer.config.moe_token_dispatcher_type == "flex"
             and self.layer.config.moe_flex_dispatcher_backend == "deepep"
         )
+        enable_hybridep = (
+            self.layer.config.moe_token_dispatcher_type == "flex"
+            and self.layer.config.moe_flex_dispatcher_backend == "hybridep"
+        )
         extra_args["enable_deepep"] = enable_deepep
+        extra_args["enable_hybridep"] = enable_hybridep
         extra_args["is_moe"] = is_moe
         extra_args["delay_wgrad_compute"] = self.layer.config.delay_wgrad_compute
         extra_args["is_mtp"] = is_mtp
@@ -219,6 +251,10 @@ def run(f_layer, b_layer, f_input=None, b_grad=None, is_last_layer_in_bwd=False)
                 b_layer.mlp.backward_dw()
             b_grad = b_layer.moe_dispatch.backward(b_grad)

+        if b_layer is not None and b_layer.config.ep_overlap_early_attn_memory_release:
+            b_grad = b_layer.post_attn.backward(b_grad)
+            b_grad = b_layer.attn.backward(b_grad)
+
         if f_layer is not None:
             with f_layer.get_fp8_context():
                 f_input = f_layer.mlp.forward(f_input)
@@ -228,7 +264,7 @@ def run(f_layer, b_layer, f_input=None, b_grad=None, is_last_layer_in_bwd=False)
             f_input = f_layer.moe_combine.forward(f_input)
             f_input = f_layer.mtp_post_process.forward(f_input)

-        if b_layer is not None:
+        if b_layer is not None and not b_layer.config.ep_overlap_early_attn_memory_release:
             b_grad = b_layer.post_attn.backward(b_grad)
             b_grad = b_layer.attn.backward(b_grad)

@@ -367,6 +403,10 @@ def get_layer(self, i):
         assert i < self.num_layers()
         return self._transformer_layers[i]

+    def pop_layer(self):
+        """Pops the transformer layer in FILO order."""
+        return self._transformer_layers.pop()
+
     def num_layers(self):
         """Gets the number of transformer layers."""
         return len(self._transformer_layers)
@@ -445,27 +485,32 @@ def run(
         b_num_layers = b_schedule_plan.num_layers() if b_schedule_plan is not None else 0
         overlapped_layers = min(f_num_layers, b_num_layers)

+        f_layer = b_layer = None
         # combined forward and backward pass for overlapped layers
         for i in range(overlapped_layers):
             f_layer = f_schedule_plan.get_layer(i)
-            b_layer = b_schedule_plan.get_layer(b_num_layers - 1 - i)
-            torch.cuda.nvtx.range_push(f"layer_{i}f-layer_{b_num_layers - 1 - i}b")
+            b_layer = b_schedule_plan.pop_layer()
+            torch.cuda.nvtx.range_push(f"layer_{i}f-layer_{b_schedule_plan.num_layers()}b")
             f_input, b_grad = TransformerLayerSchedulePlan.run(
                 f_layer,
                 b_layer,
                 f_input=f_input,
                 b_grad=b_grad,
                 is_last_layer_in_bwd=(i == b_num_layers - 1),
             )
+            if i < b_num_layers - 1:
+                b_layer.release_state()
             torch.cuda.nvtx.range_pop()

         # backward pass for the remaining layers
         for i in range(overlapped_layers, b_num_layers):
-            b_layer = b_schedule_plan.get_layer(b_num_layers - 1 - i)
-            torch.cuda.nvtx.range_push(f"layer_{b_num_layers - 1 - i}b")
+            b_layer = b_schedule_plan.pop_layer()
+            torch.cuda.nvtx.range_push(f"layer_{b_schedule_plan.num_layers()}b")
             _, b_grad = TransformerLayerSchedulePlan.run(
                 None, b_layer, b_grad=b_grad, is_last_layer_in_bwd=(i == b_num_layers - 1)
             )
+            if i < b_num_layers - 1:
+                b_layer.release_state()
             torch.cuda.nvtx.range_pop()

         # forward pass for the remaining layers
@@ -491,7 +536,9 @@ def run(
         # Delay the last attn_dw in backward pass (attn_dw of the first layer)
         # for overlapping with the p2p comm
         if b_num_layers > 0:
-            b_schedule_plan.get_layer(0).attn.backward_dw()
+            assert b_layer is not None
+            b_layer.attn.backward_dw()
+            b_layer.release_state()

         # post process forward
         if f_schedule_plan is not None and f_schedule_plan.post_process is not None:
@@ -504,9 +551,7 @@ def run(
         f_schedule_plan.wait_current_stream()
         if b_schedule_plan:
             b_schedule_plan.wait_current_stream()
-
-        # Release reference as early as possible, this helps avoid memory leak.
-        if b_schedule_plan is not None:
+            # Release reference as early as possible, this helps avoid memory leak.
             b_schedule_plan.release_state()

         return f_input
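The schedule now consumes backward layers with pop_layer() and drops each layer's references via release_state() as soon as its backward work is done, instead of holding every layer until the whole plan finishes. A toy sketch of that pattern (class and attribute names here are illustrative, not the real Megatron-LM types):

    # Toy sketch of the pop-and-release pattern; not the real schedule-plan classes.
    class ToyLayerPlan:
        def __init__(self, name):
            self.name = name
            self.activation = bytearray(1 << 20)  # stands in for saved state

        def backward(self):
            print(f"backward {self.name}")

        def release_state(self):
            self.activation = None  # drop references as soon as backward is done

    class ToyChunkPlan:
        def __init__(self, n):
            self._layers = [ToyLayerPlan(f"layer_{i}") for i in range(n)]

        def pop_layer(self):
            # FILO order; popping also removes the plan's own reference.
            return self._layers.pop()

        def num_layers(self):
            return len(self._layers)

    plan = ToyChunkPlan(4)
    while plan.num_layers() > 0:
        layer = plan.pop_layer()
        layer.backward()
        layer.release_state()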

megatron/core/models/gpt/fine_grained_callables.py

Lines changed: 36 additions & 13 deletions
@@ -16,6 +16,7 @@
     get_mtp_layer_offset,
 )
 from megatron.core.transformer.transformer_layer import TransformerLayer, make_viewless_tensor
+from megatron.core.utils import internal_api


 def weak_method(method):
@@ -35,13 +36,15 @@ def wrapped_func(*args, **kwarg):
     return wrapped_func


-def should_free_input(name, is_moe, is_deepep):
+@internal_api
+def should_free_input(name, is_moe, enable_deepep, enable_hybridep):
     """Determine if the node should free its input memory.

     Args:
         name: Node name
         is_moe: Whether it's a MoE model
-        is_deepep: Whether it's a DeepEP model
+        enable_deepep: Whether to use DeepEP dispatcher
+        enable_hybridep: Whether to use HybridEP dispatcher

     Returns:
         bool: Whether to free input memory
@@ -55,12 +58,13 @@ def should_free_input(name, is_moe, is_deepep):
     # The input and output of A2A are not needed anymore after the forward pass,
     # so we can free the input memory after the forward pass.
     free_input_nodes = {
-        "mlp": True,
+        "mlp": not enable_hybridep,
         "moe_combine": True,
-        # For non-deepep mode, the input is the un-dispatched tokens and probs before dispatch A2A
-        # and it's not needed anymore after the forward pass
-        # For deepep mode, they are both needed in backward pass, so they cannot be freed.
-        "moe_dispatch": not is_deepep,
+        # For non-DeepEP and non-HybridEP dispatcher mode, the input is the un-dispatched tokens
+        # and probs before dispatch A2A and it's not needed anymore after the forward pass
+        # For DeepEP and HybridEP dispatcher mode, they are both needed in backward pass
+        # and cannot be freed.
+        "moe_dispatch": not (enable_deepep or enable_hybridep),
     }

     return free_input_nodes.get(name, False)
@@ -225,12 +229,13 @@ def __init__(
                 it's the per_batch_state_context, o.w. nullcontext
             name (str): Node name, also used to determine memory strategy
             bwd_dw_callables (list): List of weight gradient functions for the layer.
-            extra_args (dict): Extra arguments for the node: is_moe, enable_deepep.
+            extra_args (dict): Extra arguments for nodes: is_moe, enable_deepep, enable_hybridep.
         """
         # determine whether to free input memory
         is_moe = extra_args.get("is_moe", False)
         enable_deepep = extra_args.get("enable_deepep", False)
-        free_input = should_free_input(name, is_moe, enable_deepep)
+        enable_hybridep = extra_args.get("enable_hybridep", False)
+        free_input = should_free_input(name, is_moe, enable_deepep, enable_hybridep)
         self.delay_wgrad_compute = extra_args.get("delay_wgrad_compute", False)

         super().__init__(
@@ -275,7 +280,13 @@ def backward_impl(self, outputs, output_grad):
         detached_grad = tuple([e.grad for e in self.detached])
         grads = output_grad + detached_grad
         self.default_backward_func(outputs + self.before_detached, grads)
-        self._release_state()
+        # release the output grad memory after backward finishes,
+        # except when delay_wgrad_comptue is enabled, the grad should be
+        # kept until all modules' backward_dw has been invoked.
+        if self.delay_wgrad_compute:
+            self.output_grads = grads
+            self.delay_grads_release = len(self.bwd_dw_callables) > 0
+
         # return grads for record stream
         return grads

@@ -286,9 +297,17 @@ def backward_dw(self):
         with torch.cuda.nvtx.range(f"{self.name} wgrad"):
             for module in self.bwd_dw_callables:
                 module.backward_dw()
+
+            # the output grad memory is last used in wgrad compute, should be safe to release.
+            assert self.delay_grads_release, "output grad memory should be valid before wgrad."
+            if self.manual_release_grads:
+                for tensor in self.output_grads:
+                    tensor.untyped_storage().resize_(0)
+            self.output_grads = None
+
         self.bwd_dw_callables = None

-    def _release_state(self):
+    def __del__(self):
         # Release reference as early as possible, this helps avoid memory leak.
         self.before_detached = None
         self.detached = None
@@ -329,6 +348,10 @@ def build_transformer_layer_callables(layer: TransformerLayer):
         layer.config.moe_token_dispatcher_type == "flex"
         and layer.config.moe_flex_dispatcher_backend == "deepep"
     )
+    enable_hybridep = (
+        layer.config.moe_token_dispatcher_type == "flex"
+        and layer.config.moe_flex_dispatcher_backend == "hybridep"
+    )

     def submodule_attn_forward(node: ScheduleNode, hidden_states: torch.Tensor):
         """
@@ -376,7 +399,7 @@ def submodule_dispatch_forward(
         Dispatches tokens to the experts based on the router output.
         """
         token_dispatcher = layer.mlp.token_dispatcher
-        if enable_deepep:
+        if enable_deepep or enable_hybridep:
             # update token_probs to be the detached version, prevents
             # backward graph from connecting to attn submodule
             token_dispatcher._comm_manager.token_probs = probs
@@ -393,7 +416,7 @@ def submodule_moe_forward(node: ScheduleNode, dispatched_tokens: torch.Tensor):
         shared_expert_output = None
         dispatched_probs = node.layer_state.dispatched_probs
         token_dispatcher = layer.mlp.token_dispatcher
-        if enable_deepep:
+        if enable_deepep or enable_hybridep:
             # update dispatched_probs to be detached version, prevents
             # backward graph from connecting to dispatch submodule
             token_dispatcher._comm_manager.dispatched_probs = dispatched_probs
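With the new signature, the freeing policy depends on the dispatcher backend: under HybridEP (as under DeepEP) the dispatch inputs must be kept for the backward pass, and the mlp node's input is kept as well. A standalone sketch of just the dictionary logic shown in the hunk above (the non-MoE early return is an assumption, since that part of the function is not in the diff):

    # Standalone sketch of the freeing policy from the hunk above.
    def should_free_input_sketch(name, is_moe, enable_deepep, enable_hybridep):
        if not is_moe:
            return False  # assumption: the non-MoE path is not shown in the diff
        free_input_nodes = {
            "mlp": not enable_hybridep,
            "moe_combine": True,
            "moe_dispatch": not (enable_deepep or enable_hybridep),
        }
        return free_input_nodes.get(name, False)

    # With HybridEP, neither mlp nor moe_dispatch frees its input:
    print(should_free_input_sketch("mlp", True, False, True))           # False
    print(should_free_input_sketch("moe_dispatch", True, False, True))  # False
    print(should_free_input_sketch("moe_combine", True, False, True))   # True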

megatron/core/pipeline_parallel/utils.py

Lines changed: 8 additions & 0 deletions
@@ -149,6 +149,8 @@ def __init__(
         self.free_input = free_input
         self.inputs = None
         self.outputs = None
+        self.delay_grads_release = False
+        self.manual_release_grads = False

     def default_backward_func(self, outputs, output_grad):
         """Default backward function"""
@@ -230,6 +232,12 @@ def _backward(self, *output_grad):
         for g in output_grad:
             if g is not None:
                 g.record_stream(self.stream)
+                # Manually trigger the memory release of dgrad tensor
+                # to avoid delayed garbage collection. If
+                # delay_grads_release is True, dgrad is last used in
+                # wgrad compute and skip the release here.
+                if self.manual_release_grads and not self.delay_grads_release:
+                    g.untyped_storage().resize_(0)

         grads = self.get_grad()
         self._release_state()
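The manual release relies on PyTorch's storage API: resizing a tensor's untyped storage to zero frees the backing memory immediately while the Python tensor object stays alive. A minimal standalone demo (PyTorch 2.x):

    # Minimal demo of the storage-based release used above (PyTorch 2.x).
    import torch

    g = torch.randn(1024, 1024)
    print(g.untyped_storage().size())  # 4194304 bytes
    g.untyped_storage().resize_(0)     # memory freed immediately
    print(g.untyped_storage().size())  # 0; g's data must not be read after this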

megatron/core/transformer/transformer_config.py

Lines changed: 11 additions & 0 deletions
@@ -1718,6 +1718,11 @@ def __post_init__(self):
             assert (
                 self.mtp_num_layers is None or self.mtp_num_layers == 1
             ), 'MTP layernum only supports 1 when enabling overlap_moe_expert_parallel_comm.'
+            if self.mtp_num_layers == 1:
+                assert self.pipeline_model_parallel_size > 1, (
+                    'Pipeline model parallel size must be larger than 1 '
+                    'when enabling overlap_moe_expert_parallel_comm with MTP layer.'
+                )

             # Check delay_wgrad_compute compatibility
             if self.delay_wgrad_compute:
@@ -1728,6 +1733,12 @@ def __post_init__(self):
                     not self.moe_use_legacy_grouped_gemm
                 ), 'delay_wgrad_compute is not supported with legacy groupedgemm implementation'

+        if self.ep_overlap_early_attn_memory_release:
+            assert self.overlap_moe_expert_parallel_comm, (
+                'overlap_moe_expert_parallel_comm must be enabled when enabling '
+                'ep_overlap_early_attn_memory_release'
+            )
+
         if self.context_parallel_size > 1 and self.cp_comm_type is not None:
             if isinstance(self.cp_comm_type, list):
                 assert len(self.cp_comm_type) == self.num_layers, (
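The new check makes ep_overlap_early_attn_memory_release valid only when overlap_moe_expert_parallel_comm is already enabled. A reduced, hypothetical config showing the same dependent-flag validation:

    # Reduced, hypothetical config illustrating the dependency check above.
    from dataclasses import dataclass

    @dataclass
    class MiniMoEConfig:
        overlap_moe_expert_parallel_comm: bool = False
        ep_overlap_early_attn_memory_release: bool = False

        def __post_init__(self):
            if self.ep_overlap_early_attn_memory_release:
                assert self.overlap_moe_expert_parallel_comm, (
                    'overlap_moe_expert_parallel_comm must be enabled when enabling '
                    'ep_overlap_early_attn_memory_release'
                )

    MiniMoEConfig(overlap_moe_expert_parallel_comm=True,
                  ep_overlap_early_attn_memory_release=True)   # passes
    # MiniMoEConfig(ep_overlap_early_attn_memory_release=True) # AssertionError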

megatron/training/arguments.py

Lines changed: 2 additions & 0 deletions
@@ -3291,6 +3291,8 @@ def _add_moe_args(parser):
                        help='Overlap the EP A2A communication by batch-level overlapping in 1f1b stage.')
     group.add_argument('--delay-wgrad-compute', action='store_true',
                        help='Delay the wgrad compute for batch-level overlapping')
+    group.add_argument('--ep-overlap-early-attn-memory-release', action='store_true',
+                       help='Release the memory of the attention module early in EP overlap.')

     group.add_argument('--moe-upcycling-granularity', type=int, default=1,
                        help='This param sepecifics how many times smaller is the expert hidden size compared with the original dense FFN hidden size. '
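The flag is exposed as a standard store_true argument, so passing --ep-overlap-early-attn-memory-release on the command line sets args.ep_overlap_early_attn_memory_release to True. A reduced argparse sketch (only this flag, not the full Megatron-LM parser):

    # Reduced argparse sketch for the new flag; not the full Megatron-LM parser.
    import argparse

    parser = argparse.ArgumentParser()
    group = parser.add_argument_group('moe')
    group.add_argument('--ep-overlap-early-attn-memory-release', action='store_true',
                       help='Release the memory of the attention module early in EP overlap.')

    args = parser.parse_args(['--ep-overlap-early-attn-memory-release'])
    print(args.ep_overlap_early_attn_memory_release)  # True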
