
Commit c60d5c2

Authored by lhb8125, root, yanring, and Wohox
[Dev] fix(moe): Support HybridEP and reduce memory overhead for 1F1B A2A overlap (#2201)
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>
Signed-off-by: Pingtian Li <pingtianl@nvidia.com>
Co-authored-by: root <root@eos0318.eos.clusters.nvidia.com>
Co-authored-by: Zijie Yan <zijiey@nvidia.com>
Co-authored-by: Pingtian Li <pingtianl@nvidia.com>
1 parent 23e092f commit c60d5c2

File tree

7 files changed: +174, -22 lines

megatron/core/model_parallel_config.py

Lines changed: 13 additions & 0 deletions
@@ -265,6 +265,19 @@ class ModelParallelConfig:
     delay_wgrad_compute: bool = False
     """Delay the weight gradient computation to improve batch-level communication overlapping"""
 
+    ep_overlap_early_attn_memory_release: bool = False
+    """Enable early memory release of attention activations during EP overlap.
+    EP overlap can increase peak memory usage when the overlapped forward module allocates
+    more memory than what is freed by the backward module. This flag addresses this by
+    reordering the attention backward pass to occur earlier in the schedule.
+    Specifically:
+    - Without this flag: attn_bwd executes after moe_combine_fwd
+    - With this flag: attn_bwd executes before mlp_fwd
+    The earlier execution releases attention activations sooner, reducing peak memory.
+    Note: This may impact performance as moe_combine_fwd and moe_dispatch_bwd become
+    exposed (not overlapped with other computation).
+    """
+
     ###################
     # Pipeline Parallel
     ###################
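As a quick illustration of the reordering described in the docstring, here is a minimal sketch (plain Python, not Megatron code; the step names are stand-ins for the real callables) of one overlapped forward(f)/backward(b) layer pair:

def overlapped_step(early_attn_release):
    # Order of submodule calls for one overlapped fwd(f)/bwd(b) layer pair.
    steps = ["moe_dispatch_bwd(b)"]
    if early_attn_release:
        # Attention backward runs before mlp_fwd, so its activations are
        # freed earlier; moe_combine_fwd / moe_dispatch_bwd become exposed.
        steps += ["post_attn_bwd(b)", "attn_bwd(b)"]
    steps += ["mlp_fwd(f)", "moe_combine_fwd(f)"]
    if not early_attn_release:
        # Default ordering: attention backward after moe_combine_fwd.
        steps += ["post_attn_bwd(b)", "attn_bwd(b)"]
    return steps

print(overlapped_step(early_attn_release=False))
print(overlapped_step(early_attn_release=True))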

megatron/core/models/common/model_chunk_schedule_plan.py

Lines changed: 54 additions & 9 deletions
@@ -77,6 +77,7 @@ def __init__(self, layer, event, chunk_state, comp_stream, comm_stream, extra_ar
         """
         from megatron.core.models.gpt.fine_grained_callables import TransformerLayerState
 
+        self.config = layer.config
         self.layer_state = TransformerLayerState()
         self.chunk_state = chunk_state
         self.layer = layer
@@ -87,6 +88,32 @@ def __init__(self, layer, event, chunk_state, comp_stream, comm_stream, extra_ar
         # get callable nodes for transformer/mtp layer
         self._build_callable_nodes(event, comp_stream, comm_stream, extra_args)
 
+    def release_state(self):
+        """Release reference, this helps avoid memory leak."""
+        if hasattr(self, 'attn') and self.attn is not None:
+            del self.attn
+            self.attn = None
+        if hasattr(self, 'post_attn') and self.post_attn is not None:
+            del self.post_attn
+            self.post_attn = None
+        if hasattr(self, 'moe_dispatch') and self.moe_dispatch is not None:
+            del self.moe_dispatch
+            self.moe_dispatch = None
+        if hasattr(self, 'mlp') and self.mlp is not None:
+            del self.mlp
+            self.mlp = None
+        if hasattr(self, 'moe_combine') and self.moe_combine is not None:
+            del self.moe_combine
+            self.moe_combine = None
+        if hasattr(self, 'mtp_post_process') and self.mtp_post_process is not None:
+            del self.mtp_post_process
+            self.mtp_post_process = None
+        if hasattr(self, 'layer_state') and self.layer_state is not None:
+            del self.layer_state
+            self.layer_state = None
+        if hasattr(self, 'layer'):
+            del self.layer
+
     def _build_callable_nodes(self, event, comp_stream, comm_stream, extra_args):
         """
         Builds the callable nodes for the transformer/mtp layer:
@@ -114,7 +141,12 @@ def _build_callable_nodes(self, event, comp_stream, comm_stream, extra_args):
             self.layer.config.moe_token_dispatcher_type == "flex"
             and self.layer.config.moe_flex_dispatcher_backend == "deepep"
         )
+        enable_hybridep = (
+            self.layer.config.moe_token_dispatcher_type == "flex"
+            and self.layer.config.moe_flex_dispatcher_backend == "hybridep"
+        )
         extra_args["enable_deepep"] = enable_deepep
+        extra_args["enable_hybridep"] = enable_hybridep
         extra_args["is_moe"] = is_moe
         extra_args["delay_wgrad_compute"] = self.layer.config.delay_wgrad_compute
         extra_args["is_mtp"] = is_mtp
@@ -221,6 +253,10 @@ def run(f_layer, b_layer, f_input=None, b_grad=None, is_last_layer_in_bwd=False)
             b_layer.mlp.backward_dw()
             b_grad = b_layer.moe_dispatch.backward(b_grad)
 
+        if b_layer is not None and b_layer.config.ep_overlap_early_attn_memory_release:
+            b_grad = b_layer.post_attn.backward(b_grad)
+            b_grad = b_layer.attn.backward(b_grad)
+
         if f_layer is not None:
             with f_layer.get_fp8_context():
                 f_input = f_layer.mlp.forward(f_input)
@@ -230,7 +266,7 @@ def run(f_layer, b_layer, f_input=None, b_grad=None, is_last_layer_in_bwd=False)
                 f_input = f_layer.moe_combine.forward(f_input)
                 f_input = f_layer.mtp_post_process.forward(f_input)
 
-        if b_layer is not None:
+        if b_layer is not None and not b_layer.config.ep_overlap_early_attn_memory_release:
            b_grad = b_layer.post_attn.backward(b_grad)
            b_grad = b_layer.attn.backward(b_grad)
 
@@ -372,6 +408,10 @@ def get_layer(self, i):
         assert i < self.num_layers()
         return self._transformer_layers[i]
 
+    def pop_layer(self):
+        """Pops the transformer layer in FILO order."""
+        return self._transformer_layers.pop()
+
     def num_layers(self):
         """Gets the number of transformer layers."""
         return len(self._transformer_layers)
@@ -450,29 +490,34 @@ def run(
         b_num_layers = b_schedule_plan.num_layers() if b_schedule_plan is not None else 0
         overlapped_layers = min(f_num_layers, b_num_layers)
 
+        f_layer = b_layer = None
         # combined forward and backward pass for overlapped layers
        for i in range(overlapped_layers):
            f_layer = f_schedule_plan.get_layer(i)
-            b_layer = b_schedule_plan.get_layer(b_num_layers - 1 - i)
-            torch.cuda.nvtx.range_push(f"layer_{i}f-layer_{b_num_layers - 1 - i}b")
            if f_layer.layer.config.fine_grained_activation_offloading:
                fine_grained_offloading_set_last_layer(i == f_num_layers - 1)
+            b_layer = b_schedule_plan.pop_layer()
+            torch.cuda.nvtx.range_push(f"layer_{i}f-layer_{b_schedule_plan.num_layers()}b")
            f_input, b_grad = TransformerLayerSchedulePlan.run(
                f_layer,
                b_layer,
                f_input=f_input,
                b_grad=b_grad,
                is_last_layer_in_bwd=(i == b_num_layers - 1),
            )
+            if i < b_num_layers - 1:
+                b_layer.release_state()
            torch.cuda.nvtx.range_pop()
 
        # backward pass for the remaining layers
        for i in range(overlapped_layers, b_num_layers):
-            b_layer = b_schedule_plan.get_layer(b_num_layers - 1 - i)
-            torch.cuda.nvtx.range_push(f"layer_{b_num_layers - 1 - i}b")
+            b_layer = b_schedule_plan.pop_layer()
+            torch.cuda.nvtx.range_push(f"layer_{b_schedule_plan.num_layers()}b")
            _, b_grad = TransformerLayerSchedulePlan.run(
                None, b_layer, b_grad=b_grad, is_last_layer_in_bwd=(i == b_num_layers - 1)
            )
+            if i < b_num_layers - 1:
+                b_layer.release_state()
            torch.cuda.nvtx.range_pop()
 
        # forward pass for the remaining layers
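The pop-in-FILO-order plus release_state pattern above is what lets each backward layer's references be dropped as soon as its backward has run, instead of the plan pinning the whole list until the end of the schedule. A standalone sketch of the idea, with hypothetical names:

class _BackwardPlanSketch:
    """Toy stand-in for the schedule plan; not Megatron code."""

    def __init__(self, layers):
        self._layers = list(layers)

    def pop_layer(self):
        # FILO: the most recently built layer runs its backward first.
        return self._layers.pop()

    def num_layers(self):
        return len(self._layers)

plan = _BackwardPlanSketch(["layer0", "layer1", "layer2"])
while plan.num_layers() > 0:
    layer = plan.pop_layer()   # the plan no longer holds a reference
    # ... run this layer's backward here ...
    del layer                  # analogous to release_state(): drop references now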
@@ -500,7 +545,9 @@ def run(
         # Delay the last attn_dw in backward pass (attn_dw of the first layer)
         # for overlapping with the p2p comm
         if b_num_layers > 0:
-            b_schedule_plan.get_layer(0).attn.backward_dw()
+            assert b_layer is not None
+            b_layer.attn.backward_dw()
+            b_layer.release_state()
 
         # post process forward
         if f_schedule_plan is not None and f_schedule_plan.post_process is not None:
@@ -513,9 +560,7 @@ def run(
             f_schedule_plan.wait_current_stream()
         if b_schedule_plan:
             b_schedule_plan.wait_current_stream()
-
-        # Release reference as early as possible, this helps avoid memory leak.
-        if b_schedule_plan is not None:
+            # Release reference as early as possible, this helps avoid memory leak.
             b_schedule_plan.release_state()
 
         return f_input

megatron/core/models/gpt/fine_grained_callables.py

Lines changed: 35 additions & 13 deletions
@@ -21,6 +21,7 @@
     get_mtp_layer_offset,
 )
 from megatron.core.transformer.transformer_layer import TransformerLayer, make_viewless_tensor
+from megatron.core.utils import internal_api
 
 
 def weak_method(method):
@@ -40,13 +41,15 @@ def wrapped_func(*args, **kwarg):
     return wrapped_func
 
 
-def should_free_input(name, is_moe, is_deepep):
+@internal_api
+def should_free_input(name, is_moe, enable_deepep, enable_hybridep):
     """Determine if the node should free its input memory.
 
     Args:
         name: Node name
         is_moe: Whether it's a MoE model
-        is_deepep: Whether it's a DeepEP model
+        enable_deepep: Whether to use DeepEP dispatcher
+        enable_hybridep: Whether to use HybridEP dispatcher
 
     Returns:
         bool: Whether to free input memory
@@ -60,12 +63,13 @@ def should_free_input(name, is_moe, is_deepep):
     # The input and output of A2A are not needed anymore after the forward pass,
     # so we can free the input memory after the forward pass.
     free_input_nodes = {
-        "mlp": True,
+        "mlp": not enable_hybridep,
         "moe_combine": True,
-        # For non-deepep mode, the input is the un-dispatched tokens and probs before dispatch A2A
-        # and it's not needed anymore after the forward pass
-        # For deepep mode, they are both needed in backward pass, so they cannot be freed.
-        "moe_dispatch": not is_deepep,
+        # For non-DeepEP and non-HybridEP dispatcher mode, the input is the un-dispatched tokens
+        # and probs before dispatch A2A and it's not needed anymore after the forward pass
+        # For DeepEP and HybridEP dispatcher mode, they are both needed in backward pass
+        # and cannot be freed.
+        "moe_dispatch": not (enable_deepep or enable_hybridep),
     }
 
     return free_input_nodes.get(name, False)
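A small usage sketch of the decision table above (assumes a Megatron install and an MoE layer; the non-MoE early-exit path is outside this hunk), showing which nodes free their forward input under each dispatcher backend:

from megatron.core.models.gpt.fine_grained_callables import should_free_input

# Illustrative check of the free-input decisions per backend.
for backend, (deepep, hybridep) in [
    ("default", (False, False)),
    ("deepep", (True, False)),
    ("hybridep", (False, True)),
]:
    decisions = {
        name: should_free_input(name, True, deepep, hybridep)
        for name in ("mlp", "moe_dispatch", "moe_combine")
    }
    print(backend, decisions)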
@@ -223,12 +227,13 @@ def __init__(
                 it's the per_batch_state_context, o.w. nullcontext
             name (str): Node name, also used to determine memory strategy
             bwd_dw_callables (list): List of weight gradient functions for the layer.
-            extra_args (dict): Extra arguments for the node: is_moe, enable_deepep.
+            extra_args (dict): Extra arguments for nodes: is_moe, enable_deepep, enable_hybridep.
        """
        # determine whether to free input memory
        is_moe = extra_args.get("is_moe", False)
        enable_deepep = extra_args.get("enable_deepep", False)
-        free_input = should_free_input(name, is_moe, enable_deepep)
+        enable_hybridep = extra_args.get("enable_hybridep", False)
+        free_input = should_free_input(name, is_moe, enable_deepep, enable_hybridep)
        self.delay_wgrad_compute = extra_args.get("delay_wgrad_compute", False)
 
        super().__init__(
@@ -274,7 +279,13 @@ def backward_impl(self, outputs, output_grad):
         detached_grad = tuple([e.grad for e in self.detached])
         grads = output_grad + detached_grad
         self.default_backward_func(outputs + self.before_detached, grads)
-        self._release_state()
+        # release the output grad memory after backward finishes,
+        # except when delay_wgrad_compute is enabled, in which case the grad
+        # is kept until all modules' backward_dw has been invoked.
+        if self.delay_wgrad_compute:
+            self.output_grads = grads
+            self.delay_grads_release = len(self.bwd_dw_callables) > 0
+
        # return grads for record stream
        return grads
 
@@ -285,9 +296,16 @@ def backward_dw(self):
         with torch.cuda.nvtx.range(f"{self.name} wgrad"):
             for module in self.bwd_dw_callables:
                 module.backward_dw()
+
+        # the output grad memory is last used in wgrad compute, should be safe to release.
+        assert self.delay_grads_release, "output grad memory should be valid before wgrad."
+        for tensor in self.output_grads:
+            tensor.untyped_storage().resize_(0)
+        self.output_grads = None
+
         self.bwd_dw_callables = None
 
-    def _release_state(self):
+    def __del__(self):
        # Release reference as early as possible, this helps avoid memory leak.
        self.before_detached = None
        self.detached = None
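A simplified, standalone sketch (not the real ScheduleNode) of the delayed-release pattern introduced above: with delay_wgrad_compute, the incoming grads stay alive until backward_dw() has run the weight-gradient step, and only then are their storages shrunk to zero:

import torch

class _NodeSketch:
    def __init__(self, delay_wgrad_compute):
        self.delay_wgrad_compute = delay_wgrad_compute
        self.output_grads = None
        self.delay_grads_release = False

    def backward(self, grads):
        # ... dgrad compute would run here ...
        if self.delay_wgrad_compute:
            self.output_grads = grads          # still needed by wgrad
            self.delay_grads_release = True
        return grads

    def backward_dw(self):
        # ... wgrad compute would run here, reading the saved grads ...
        assert self.delay_grads_release
        for t in self.output_grads:
            t.untyped_storage().resize_(0)     # last use is over: free now
        self.output_grads = None

node = _NodeSketch(delay_wgrad_compute=True)
node.backward((torch.randn(4, 4),))
node.backward_dw()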
@@ -328,6 +346,10 @@ def build_transformer_layer_callables(layer: TransformerLayer):
         layer.config.moe_token_dispatcher_type == "flex"
         and layer.config.moe_flex_dispatcher_backend == "deepep"
     )
+    enable_hybridep = (
+        layer.config.moe_token_dispatcher_type == "flex"
+        and layer.config.moe_flex_dispatcher_backend == "hybridep"
+    )
 
     def submodule_attn_forward(node: ScheduleNode, hidden_states: torch.Tensor):
         """
@@ -379,7 +401,7 @@ def submodule_dispatch_forward(
         Dispatches tokens to the experts based on the router output.
         """
         token_dispatcher = layer.mlp.token_dispatcher
-        if enable_deepep:
+        if enable_deepep or enable_hybridep:
            # update token_probs to be the detached version, prevents
            # backward graph from connecting to attn submodule
            token_dispatcher._comm_manager.token_probs = probs
@@ -396,7 +418,7 @@ def submodule_moe_forward(node: ScheduleNode, dispatched_tokens: torch.Tensor):
         shared_expert_output = None
         dispatched_probs = node.layer_state.dispatched_probs
         token_dispatcher = layer.mlp.token_dispatcher
-        if enable_deepep:
+        if enable_deepep or enable_hybridep:
            # update dispatched_probs to be detached version, prevents
            # backward graph from connecting to dispatch submodule
            token_dispatcher._comm_manager.dispatched_probs = dispatched_probs

megatron/core/pipeline_parallel/utils.py

Lines changed: 7 additions & 0 deletions
@@ -182,6 +182,7 @@ def __init__(
         self.free_input = free_input
         self.inputs = None
         self.outputs = None
+        self.delay_grads_release = False
 
     def default_backward_func(self, outputs, output_grad):
        """Default backward function"""
@@ -263,6 +264,12 @@ def _backward(self, *output_grad):
             for g in output_grad:
                 if g is not None:
                     g.record_stream(self.stream)
+                    # Manually trigger the memory release of the dgrad tensor
+                    # to avoid delayed garbage collection. If
+                    # delay_grads_release is True, dgrad is last used in
+                    # wgrad compute, so the release is skipped here.
+                    if not self.delay_grads_release:
+                        g.untyped_storage().resize_(0)
 
         grads = self.get_grad()
         self._release_state()
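The storage trick used here is worth a standalone note: resizing a tensor's untyped storage to zero hands the backing memory back immediately, without waiting for Python to collect the tensor object. A tiny illustration:

import torch

# Shrinking a tensor's storage to zero frees its backing memory right away;
# the tensor object still exists but must not be read afterwards.
g = torch.randn(1024, 1024)
print(g.untyped_storage().size())  # number of bytes currently backing g
g.untyped_storage().resize_(0)     # release the backing memory now
print(g.untyped_storage().size())  # 0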

megatron/core/transformer/transformer_config.py

Lines changed: 11 additions & 0 deletions
@@ -1843,6 +1843,11 @@ def __post_init__(self):
             assert (
                 self.mtp_num_layers is None or self.mtp_num_layers == 1
             ), 'MTP layernum only supports 1 when enabling overlap_moe_expert_parallel_comm.'
+            if self.mtp_num_layers == 1:
+                assert self.pipeline_model_parallel_size > 1, (
+                    'Pipeline model parallel size must be larger than 1 '
+                    'when enabling overlap_moe_expert_parallel_comm with MTP layer.'
+                )
 
         # Check delay_wgrad_compute compatibility
         if self.delay_wgrad_compute:
@@ -1853,6 +1858,12 @@ def __post_init__(self):
                 not self.moe_use_legacy_grouped_gemm
             ), 'delay_wgrad_compute is not supported with legacy groupedgemm implementation'
 
+        if self.ep_overlap_early_attn_memory_release:
+            assert self.overlap_moe_expert_parallel_comm, (
+                'overlap_moe_expert_parallel_comm must be enabled when enabling '
+                'ep_overlap_early_attn_memory_release'
+            )
+
         if self.context_parallel_size > 1 and self.cp_comm_type is not None:
             if isinstance(self.cp_comm_type, list):
                 assert len(self.cp_comm_type) == self.num_layers, (
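For reference, the two new checks boil down to the following dependency rules, sketched here as a hypothetical standalone helper (the real checks live in TransformerConfig.__post_init__):

# Standalone sketch of the new validation rules: early attention-memory
# release requires EP overlap, and EP overlap with an MTP layer requires
# pipeline parallelism greater than 1.
def check_overlap_flags(overlap_moe_expert_parallel_comm,
                        ep_overlap_early_attn_memory_release,
                        mtp_num_layers,
                        pipeline_model_parallel_size):
    if ep_overlap_early_attn_memory_release:
        assert overlap_moe_expert_parallel_comm, (
            'overlap_moe_expert_parallel_comm must be enabled when enabling '
            'ep_overlap_early_attn_memory_release'
        )
    if overlap_moe_expert_parallel_comm and mtp_num_layers == 1:
        assert pipeline_model_parallel_size > 1, (
            'Pipeline model parallel size must be larger than 1 '
            'when enabling overlap_moe_expert_parallel_comm with MTP layer.'
        )

check_overlap_flags(True, True, mtp_num_layers=None, pipeline_model_parallel_size=2)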

megatron/training/arguments.py

Lines changed: 2 additions & 0 deletions
@@ -3348,6 +3348,8 @@ def _add_moe_args(parser):
                        help='Overlap the EP A2A communication by batch-level overlapping in 1f1b stage.')
     group.add_argument('--delay-wgrad-compute', action='store_true',
                        help='Delay the wgrad compute for batch-level overlapping')
+    group.add_argument('--ep-overlap-early-attn-memory-release', action='store_true',
+                       help='Release the memory of the attention module early in EP overlap.')
 
     group.add_argument('--moe-upcycling-granularity', type=int, default=1,
                        help='This param sepecifics how many times smaller is the expert hidden size compared with the original dense FFN hidden size. '
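A minimal stand-in parser (not the full Megatron argument set) showing how the new switch surfaces after parsing; argparse maps the dashes to underscores, matching the ep_overlap_early_attn_memory_release config field:

import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group('moe')
group.add_argument('--delay-wgrad-compute', action='store_true',
                   help='Delay the wgrad compute for batch-level overlapping')
group.add_argument('--ep-overlap-early-attn-memory-release', action='store_true',
                   help='Release the memory of the attention module early in EP overlap.')

args = parser.parse_args(['--ep-overlap-early-attn-memory-release'])
assert args.ep_overlap_early_attn_memory_release is True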
