
Commit d4529eb

Revert "[Dev] fix(moe): Support HybridEP and reduce memory overhead for 1F1B A2A overlap (NVIDIA#2201)"

This reverts commit c60d5c2.

1 parent 2b1fc70 · commit d4529eb

File tree: 7 files changed, +22 -174 lines

megatron/core/model_parallel_config.py

Lines changed: 0 additions & 13 deletions
```diff
@@ -265,19 +265,6 @@ class ModelParallelConfig:
     delay_wgrad_compute: bool = False
     """Delay the weight gradient computation to improve batch-level communication overlapping"""

-    ep_overlap_early_attn_memory_release: bool = False
-    """Enable early memory release of attention activations during EP overlap.
-    EP overlap can increase peak memory usage when the overlapped forward module allocates
-    more memory than what is freed by the backward module. This flag addresses this by
-    reordering the attention backward pass to occur earlier in the schedule.
-    Specifically:
-    - Without this flag: attn_bwd executes after moe_combine_fwd
-    - With this flag: attn_bwd executes before mlp_fwd
-    The earlier execution releases attention activations sooner, reducing peak memory.
-    Note: This may impact performance as moe_combine_fwd and moe_dispatch_bwd become
-    exposed (not overlapped with other computation).
-    """
-
     ###################
     # Pipeline Parallel
     ###################
```

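For reference, the ordering that the reverted flag toggled can be illustrated with a small standalone sketch (not Megatron code; `Trace` and `run_overlapped_step` are hypothetical stand-ins for the real schedule-plan nodes, and the sketch only shows ordering, not stream overlap):

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Trace:
    order: List[str] = field(default_factory=list)

    def run(self, name: str):
        self.order.append(name)


def run_overlapped_step(trace: Trace, early_attn_release: bool):
    """One combined fwd/bwd step of the 1F1B A2A overlap schedule (ordering only)."""
    trace.run("moe_dispatch_bwd")
    if early_attn_release:
        # attention backward runs before mlp_fwd, freeing attn activations sooner
        trace.run("post_attn_bwd")
        trace.run("attn_bwd")
    trace.run("mlp_fwd")
    trace.run("moe_combine_fwd")
    if not early_attn_release:
        # default order: attention backward runs after moe_combine_fwd
        trace.run("post_attn_bwd")
        trace.run("attn_bwd")


default_order, early_order = Trace(), Trace()
run_overlapped_step(default_order, early_attn_release=False)
run_overlapped_step(early_order, early_attn_release=True)
print("default:", default_order.order)
print("early  :", early_order.order)
```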
megatron/core/models/common/model_chunk_schedule_plan.py

Lines changed: 9 additions & 54 deletions
```diff
@@ -77,7 +77,6 @@ def __init__(self, layer, event, chunk_state, comp_stream, comm_stream, extra_ar
         """
         from megatron.core.models.gpt.fine_grained_callables import TransformerLayerState

-        self.config = layer.config
         self.layer_state = TransformerLayerState()
         self.chunk_state = chunk_state
         self.layer = layer
@@ -88,32 +87,6 @@ def __init__(self, layer, event, chunk_state, comp_stream, comm_stream, extra_ar
         # get callable nodes for transformer/mtp layer
         self._build_callable_nodes(event, comp_stream, comm_stream, extra_args)

-    def release_state(self):
-        """Release reference, this helps avoid memory leak."""
-        if hasattr(self, 'attn') and self.attn is not None:
-            del self.attn
-            self.attn = None
-        if hasattr(self, 'post_attn') and self.post_attn is not None:
-            del self.post_attn
-            self.post_attn = None
-        if hasattr(self, 'moe_dispatch') and self.moe_dispatch is not None:
-            del self.moe_dispatch
-            self.moe_dispatch = None
-        if hasattr(self, 'mlp') and self.mlp is not None:
-            del self.mlp
-            self.mlp = None
-        if hasattr(self, 'moe_combine') and self.moe_combine is not None:
-            del self.moe_combine
-            self.moe_combine = None
-        if hasattr(self, 'mtp_post_process') and self.mtp_post_process is not None:
-            del self.mtp_post_process
-            self.mtp_post_process = None
-        if hasattr(self, 'layer_state') and self.layer_state is not None:
-            del self.layer_state
-            self.layer_state = None
-        if hasattr(self, 'layer'):
-            del self.layer
-
     def _build_callable_nodes(self, event, comp_stream, comm_stream, extra_args):
         """
         Builds the callable nodes for the transformer/mtp layer:
@@ -141,12 +114,7 @@ def _build_callable_nodes(self, event, comp_stream, comm_stream, extra_args):
             self.layer.config.moe_token_dispatcher_type == "flex"
            and self.layer.config.moe_flex_dispatcher_backend == "deepep"
         )
-        enable_hybridep = (
-            self.layer.config.moe_token_dispatcher_type == "flex"
-            and self.layer.config.moe_flex_dispatcher_backend == "hybridep"
-        )
         extra_args["enable_deepep"] = enable_deepep
-        extra_args["enable_hybridep"] = enable_hybridep
         extra_args["is_moe"] = is_moe
         extra_args["delay_wgrad_compute"] = self.layer.config.delay_wgrad_compute
         extra_args["is_mtp"] = is_mtp
@@ -253,10 +221,6 @@ def run(f_layer, b_layer, f_input=None, b_grad=None, is_last_layer_in_bwd=False)
                 b_layer.mlp.backward_dw()
             b_grad = b_layer.moe_dispatch.backward(b_grad)

-        if b_layer is not None and b_layer.config.ep_overlap_early_attn_memory_release:
-            b_grad = b_layer.post_attn.backward(b_grad)
-            b_grad = b_layer.attn.backward(b_grad)
-
         if f_layer is not None:
             with f_layer.get_fp8_context():
                 f_input = f_layer.mlp.forward(f_input)
@@ -266,7 +230,7 @@ def run(f_layer, b_layer, f_input=None, b_grad=None, is_last_layer_in_bwd=False)
             f_input = f_layer.moe_combine.forward(f_input)
             f_input = f_layer.mtp_post_process.forward(f_input)

-        if b_layer is not None and not b_layer.config.ep_overlap_early_attn_memory_release:
+        if b_layer is not None:
             b_grad = b_layer.post_attn.backward(b_grad)
             b_grad = b_layer.attn.backward(b_grad)

@@ -408,10 +372,6 @@ def get_layer(self, i):
         assert i < self.num_layers()
         return self._transformer_layers[i]

-    def pop_layer(self):
-        """Pops the transformer layer in FILO order."""
-        return self._transformer_layers.pop()
-
     def num_layers(self):
         """Gets the number of transformer layers."""
         return len(self._transformer_layers)
@@ -490,34 +450,29 @@ def run(
         b_num_layers = b_schedule_plan.num_layers() if b_schedule_plan is not None else 0
         overlapped_layers = min(f_num_layers, b_num_layers)

-        f_layer = b_layer = None
         # combined forward and backward pass for overlapped layers
         for i in range(overlapped_layers):
             f_layer = f_schedule_plan.get_layer(i)
+            b_layer = b_schedule_plan.get_layer(b_num_layers - 1 - i)
+            torch.cuda.nvtx.range_push(f"layer_{i}f-layer_{b_num_layers - 1 - i}b")
             if f_layer.layer.config.fine_grained_activation_offloading:
                 fine_grained_offloading_set_last_layer(i == f_num_layers - 1)
-            b_layer = b_schedule_plan.pop_layer()
-            torch.cuda.nvtx.range_push(f"layer_{i}f-layer_{b_schedule_plan.num_layers()}b")
             f_input, b_grad = TransformerLayerSchedulePlan.run(
                 f_layer,
                 b_layer,
                 f_input=f_input,
                 b_grad=b_grad,
                 is_last_layer_in_bwd=(i == b_num_layers - 1),
             )
-            if i < b_num_layers - 1:
-                b_layer.release_state()
             torch.cuda.nvtx.range_pop()

         # backward pass for the remaining layers
         for i in range(overlapped_layers, b_num_layers):
-            b_layer = b_schedule_plan.pop_layer()
-            torch.cuda.nvtx.range_push(f"layer_{b_schedule_plan.num_layers()}b")
+            b_layer = b_schedule_plan.get_layer(b_num_layers - 1 - i)
+            torch.cuda.nvtx.range_push(f"layer_{b_num_layers - 1 - i}b")
             _, b_grad = TransformerLayerSchedulePlan.run(
                 None, b_layer, b_grad=b_grad, is_last_layer_in_bwd=(i == b_num_layers - 1)
             )
-            if i < b_num_layers - 1:
-                b_layer.release_state()
             torch.cuda.nvtx.range_pop()

         # forward pass for the remaining layers
@@ -545,9 +500,7 @@ def run(
         # Delay the last attn_dw in backward pass (attn_dw of the first layer)
         # for overlapping with the p2p comm
         if b_num_layers > 0:
-            assert b_layer is not None
-            b_layer.attn.backward_dw()
-            b_layer.release_state()
+            b_schedule_plan.get_layer(0).attn.backward_dw()

         # post process forward
         if f_schedule_plan is not None and f_schedule_plan.post_process is not None:
@@ -560,7 +513,9 @@ def run(
         f_schedule_plan.wait_current_stream()
         if b_schedule_plan:
             b_schedule_plan.wait_current_stream()
-            # Release reference as early as possible, this helps avoid memory leak.
+
+        # Release reference as early as possible, this helps avoid memory leak.
+        if b_schedule_plan is not None:
             b_schedule_plan.release_state()

         return f_input
```

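The restored loop pairs forward layer i with backward layer b_num_layers - 1 - i by index instead of popping layers FILO. A minimal sketch of that pairing (not Megatron code; `pairings` is a hypothetical helper, and in the real schedule each pair is handed to TransformerLayerSchedulePlan.run):

```python
def pairings(f_num_layers: int, b_num_layers: int):
    """Yield (forward_layer_idx, backward_layer_idx) pairs, mirroring the restored loop."""
    overlapped = min(f_num_layers, b_num_layers)
    # combined forward/backward for the overlapped layers
    for i in range(overlapped):
        yield i, b_num_layers - 1 - i  # backward side walks from the last layer down
    # remaining backward-only steps (indexing replaces the reverted pop_layer/FILO scheme)
    for i in range(overlapped, b_num_layers):
        yield None, b_num_layers - 1 - i


print(list(pairings(f_num_layers=4, b_num_layers=6)))
# [(0, 5), (1, 4), (2, 3), (3, 2), (None, 1), (None, 0)]
```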
megatron/core/models/gpt/fine_grained_callables.py

Lines changed: 13 additions & 35 deletions
```diff
@@ -21,7 +21,6 @@
     get_mtp_layer_offset,
 )
 from megatron.core.transformer.transformer_layer import TransformerLayer, make_viewless_tensor
-from megatron.core.utils import internal_api


 def weak_method(method):
@@ -41,15 +40,13 @@ def wrapped_func(*args, **kwarg):
     return wrapped_func


-@internal_api
-def should_free_input(name, is_moe, enable_deepep, enable_hybridep):
+def should_free_input(name, is_moe, is_deepep):
     """Determine if the node should free its input memory.

     Args:
         name: Node name
         is_moe: Whether it's a MoE model
-        enable_deepep: Whether to use DeepEP dispatcher
-        enable_hybridep: Whether to use HybridEP dispatcher
+        is_deepep: Whether it's a DeepEP model

     Returns:
         bool: Whether to free input memory
@@ -63,13 +60,12 @@ def should_free_input(name, is_moe, enable_deepep, enable_hybridep):
     # The input and output of A2A are not needed anymore after the forward pass,
     # so we can free the input memory after the forward pass.
     free_input_nodes = {
-        "mlp": not enable_hybridep,
+        "mlp": True,
         "moe_combine": True,
-        # For non-DeepEP and non-HybridEP dispatcher mode, the input is the un-dispatched tokens
-        # and probs before dispatch A2A and it's not needed anymore after the forward pass
-        # For DeepEP and HybridEP dispatcher mode, they are both needed in backward pass
-        # and cannot be freed.
-        "moe_dispatch": not (enable_deepep or enable_hybridep),
+        # For non-deepep mode, the input is the un-dispatched tokens and probs before dispatch A2A
+        # and it's not needed anymore after the forward pass
+        # For deepep mode, they are both needed in backward pass, so they cannot be freed.
+        "moe_dispatch": not is_deepep,
     }

     return free_input_nodes.get(name, False)
@@ -227,13 +223,12 @@ def __init__(
             it's the per_batch_state_context, o.w. nullcontext
             name (str): Node name, also used to determine memory strategy
             bwd_dw_callables (list): List of weight gradient functions for the layer.
-            extra_args (dict): Extra arguments for nodes: is_moe, enable_deepep, enable_hybridep.
+            extra_args (dict): Extra arguments for the node: is_moe, enable_deepep.
        """
        # determine whether to free input memory
        is_moe = extra_args.get("is_moe", False)
        enable_deepep = extra_args.get("enable_deepep", False)
-        enable_hybridep = extra_args.get("enable_hybridep", False)
-        free_input = should_free_input(name, is_moe, enable_deepep, enable_hybridep)
+        free_input = should_free_input(name, is_moe, enable_deepep)
        self.delay_wgrad_compute = extra_args.get("delay_wgrad_compute", False)

        super().__init__(
@@ -279,13 +274,7 @@ def backward_impl(self, outputs, output_grad):
        detached_grad = tuple([e.grad for e in self.detached])
        grads = output_grad + detached_grad
        self.default_backward_func(outputs + self.before_detached, grads)
-        # release the output grad memory after backward finishes,
-        # except when delay_wgrad_comptue is enabled, the grad should be
-        # kept until all modules' backward_dw has been invoked.
-        if self.delay_wgrad_compute:
-            self.output_grads = grads
-            self.delay_grads_release = len(self.bwd_dw_callables) > 0
-
+        self._release_state()
        # return grads for record stream
        return grads

@@ -296,16 +285,9 @@ def backward_dw(self):
        with torch.cuda.nvtx.range(f"{self.name} wgrad"):
            for module in self.bwd_dw_callables:
                module.backward_dw()
-
-        # the output grad memory is last used in wgrad compute, should be safe to release.
-        assert self.delay_grads_release, "output grad memory should be valid before wgrad."
-        for tensor in self.output_grads:
-            tensor.untyped_storage().resize_(0)
-        self.output_grads = None
-
        self.bwd_dw_callables = None

-    def __del__(self):
+    def _release_state(self):
        # Release reference as early as possible, this helps avoid memory leak.
        self.before_detached = None
        self.detached = None
@@ -346,10 +328,6 @@ def build_transformer_layer_callables(layer: TransformerLayer):
        layer.config.moe_token_dispatcher_type == "flex"
        and layer.config.moe_flex_dispatcher_backend == "deepep"
    )
-    enable_hybridep = (
-        layer.config.moe_token_dispatcher_type == "flex"
-        and layer.config.moe_flex_dispatcher_backend == "hybridep"
-    )

    def submodule_attn_forward(node: ScheduleNode, hidden_states: torch.Tensor):
        """
@@ -401,7 +379,7 @@ def submodule_dispatch_forward(
        Dispatches tokens to the experts based on the router output.
        """
        token_dispatcher = layer.mlp.token_dispatcher
-        if enable_deepep or enable_hybridep:
+        if enable_deepep:
            # update token_probs to be the detached version, prevents
            # backward graph from connecting to attn submodule
            token_dispatcher._comm_manager.token_probs = probs
@@ -418,7 +396,7 @@ def submodule_moe_forward(node: ScheduleNode, dispatched_tokens: torch.Tensor):
        shared_expert_output = None
        dispatched_probs = node.layer_state.dispatched_probs
        token_dispatcher = layer.mlp.token_dispatcher
-        if enable_deepep or enable_hybridep:
+        if enable_deepep:
            # update dispatched_probs to be detached version, prevents
            # backward graph from connecting to dispatch submodule
            token_dispatcher._comm_manager.dispatched_probs = dispatched_probs
```

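The restored helper can be exercised on its own; here is a standalone sketch adapted from the hunk above (the non-MoE early return is an assumption, since that branch sits outside the hunk shown):

```python
# Sketch of the restored memory-strategy helper, adapted from the diff above.
def should_free_input(name: str, is_moe: bool, is_deepep: bool) -> bool:
    # Assumption: nodes of non-MoE layers never free their input
    # (that branch is outside the hunk shown above).
    if not is_moe:
        return False
    # The input and output of the A2A are not needed after the forward pass,
    # so these nodes can free their input memory eagerly.
    free_input_nodes = {
        "mlp": True,
        "moe_combine": True,
        # Non-DeepEP: the input is the un-dispatched tokens/probs, unused after forward.
        # DeepEP: both are still needed in the backward pass and cannot be freed.
        "moe_dispatch": not is_deepep,
    }
    return free_input_nodes.get(name, False)


# With DeepEP, the dispatch node must keep its input for the backward pass.
assert should_free_input("moe_dispatch", is_moe=True, is_deepep=True) is False
assert should_free_input("moe_dispatch", is_moe=True, is_deepep=False) is True
assert should_free_input("attn", is_moe=True, is_deepep=False) is False
```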
megatron/core/pipeline_parallel/utils.py

Lines changed: 0 additions & 7 deletions
```diff
@@ -182,7 +182,6 @@ def __init__(
         self.free_input = free_input
         self.inputs = None
         self.outputs = None
-        self.delay_grads_release = False

     def default_backward_func(self, outputs, output_grad):
         """Default backward function"""
@@ -264,12 +263,6 @@ def _backward(self, *output_grad):
         for g in output_grad:
             if g is not None:
                 g.record_stream(self.stream)
-                # Manually trigger the memory release of dgrad tensor
-                # to avoid delayed garbage collection. If
-                # delay_grads_release is True, dgrad is last used in
-                # wgrad compute and skip the release here.
-                if not self.delay_grads_release:
-                    g.untyped_storage().resize_(0)

         grads = self.get_grad()
         self._release_state()
```

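The removed lines used PyTorch's storage API to hand the gradient's memory back to the allocator immediately instead of waiting for garbage collection. A minimal standalone sketch of that mechanism (not Megatron code):

```python
import torch

# Stand-in for an output gradient tensor produced in the backward pass.
g = torch.randn(1024, 1024)
print(g.untyped_storage().nbytes())  # ~4 MiB of backing storage

# Shrinking the untyped storage to zero releases the memory right away,
# while the Python tensor object itself stays alive.
g.untyped_storage().resize_(0)
print(g.untyped_storage().nbytes())  # 0
```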
megatron/core/transformer/transformer_config.py

Lines changed: 0 additions & 11 deletions
```diff
@@ -1843,11 +1843,6 @@ def __post_init__(self):
             assert (
                 self.mtp_num_layers is None or self.mtp_num_layers == 1
             ), 'MTP layernum only supports 1 when enabling overlap_moe_expert_parallel_comm.'
-            if self.mtp_num_layers == 1:
-                assert self.pipeline_model_parallel_size > 1, (
-                    'Pipeline model parallel size must be larger than 1 '
-                    'when enabling overlap_moe_expert_parallel_comm with MTP layer.'
-                )

         # Check delay_wgrad_compute compatibility
         if self.delay_wgrad_compute:
@@ -1858,12 +1853,6 @@ def __post_init__(self):
                 not self.moe_use_legacy_grouped_gemm
             ), 'delay_wgrad_compute is not supported with legacy groupedgemm implementation'

-        if self.ep_overlap_early_attn_memory_release:
-            assert self.overlap_moe_expert_parallel_comm, (
-                'overlap_moe_expert_parallel_comm must be enabled when enabling '
-                'ep_overlap_early_attn_memory_release'
-            )
-
         if self.context_parallel_size > 1 and self.cp_comm_type is not None:
             if isinstance(self.cp_comm_type, list):
                 assert len(self.cp_comm_type) == self.num_layers, (
```

megatron/training/arguments.py

Lines changed: 0 additions & 2 deletions
```diff
@@ -3302,8 +3302,6 @@ def _add_moe_args(parser):
                        help='Overlap the EP A2A communication by batch-level overlapping in 1f1b stage.')
     group.add_argument('--delay-wgrad-compute', action='store_true',
                        help='Delay the wgrad compute for batch-level overlapping')
-    group.add_argument('--ep-overlap-early-attn-memory-release', action='store_true',
-                       help='Release the memory of the attention module early in EP overlap.')

     group.add_argument('--moe-upcycling-granularity', type=int, default=1,
                        help='This param sepecifics how many times smaller is the expert hidden size compared with the original dense FFN hidden size. '
```
