Merged

Commits (47)
12afb8b
release unused memory
lhb8125 Sep 3, 2025
ab40d7b
format
lhb8125 Sep 4, 2025
0e641d3
Merge branch 'main_tot' into hongbinl/1f1b_overlap_memory_issue
lhb8125 Sep 4, 2025
1219a26
renaming golden values
lhb8125 Oct 29, 2025
ce6e661
fix bug: accuracy issue because of recomputing and offloading same module
lhb8125 Nov 4, 2025
d04d741
Merge branch 'dev' into hongbinl/activation_offloading_fix
lhb8125 Nov 4, 2025
2fe4aeb
format
lhb8125 Nov 4, 2025
fb3f7c3
update golden values
lhb8125 Nov 5, 2025
5001e2b
Merge branch 'dev' into hongbinl/activation_offloading_fix
lhb8125 Nov 5, 2025
9937890
update golden values
lhb8125 Nov 5, 2025
6c83118
update model_config and golden values
lhb8125 Nov 6, 2025
33a38f5
format
lhb8125 Nov 6, 2025
6c76b07
update golden values
lhb8125 Nov 6, 2025
e8c0eb0
support hybridep+a2a overlap
lhb8125 Nov 10, 2025
b207de3
Merge branch 'dev' into hongbinl/1f1b_hybridep
Nov 11, 2025
465f497
Merge branch 'hongbinl/1f1b_overlap_memory_issue' into hongbinl/1f1b_…
lhb8125 Nov 11, 2025
299df02
minor fix
lhb8125 Nov 11, 2025
dc0cb6c
assert PP>1 for a2a overlap with MTP layers
lhb8125 Nov 11, 2025
32fc988
Merge branch 'dev' into hongbinl/1f1b_hybridep
lhb8125 Nov 12, 2025
6102cc5
Merge branch 'dev' into hongbinl/1f1b_hybridep
yanring Nov 18, 2025
d29b634
Merge branch 'dev' into hongbinl/1f1b_hybridep
yanring Nov 20, 2025
5518940
revert the changes about memory overhead optimization
lhb8125 Nov 27, 2025
9dca28b
Merge branch 'hongbinl/1f1b_hybridep' of https://github.com/lhb8125/M…
lhb8125 Nov 27, 2025
e0e6da1
minor fix
lhb8125 Nov 27, 2025
04199ce
Merge branch 'dev' into hongbinl/1f1b_hybridep
lhb8125 Nov 27, 2025
448f035
fix back compatibility
lhb8125 Nov 27, 2025
e9e662b
Merge branch 'hongbinl/1f1b_hybridep' of https://github.com/lhb8125/M…
lhb8125 Nov 27, 2025
fe05568
format
lhb8125 Nov 27, 2025
b83deee
support early attn mem release
Wohox Dec 1, 2025
f441acf
fix mem opt ut
Wohox Dec 1, 2025
8c2e9ad
format
Wohox Dec 1, 2025
f1c886c
Merge pull request #48 from Wohox/pingtian/add_ep_overlap_switch_orde…
lhb8125 Dec 2, 2025
bc209bb
fix bugs when enabling hybridep
lhb8125 Dec 2, 2025
060f53d
format
lhb8125 Dec 2, 2025
776d224
Merge branch 'dev' into hongbinl/1f1b_hybridep
lhb8125 Dec 2, 2025
487eea9
remove unused try-except clause
lhb8125 Dec 2, 2025
c568c37
format
lhb8125 Dec 2, 2025
36648e3
Merge branch 'dev' into hongbinl/1f1b_hybridep
lhb8125 Dec 5, 2025
33d4d9c
fix comments
Wohox Dec 5, 2025
3ee932b
more explanation
Wohox Dec 5, 2025
61a75d2
Merge branch 'dev' into hongbinl/1f1b_hybridep
lhb8125 Dec 8, 2025
0708cc1
replace __del__ with explicit destructor
lhb8125 Dec 8, 2025
2cfaec1
format
lhb8125 Dec 8, 2025
0e299f6
Merge branch 'dev' into hongbinl/1f1b_hybridep
lhb8125 Dec 8, 2025
0f8663b
fix ut
lhb8125 Dec 8, 2025
97de523
Merge pull request #50 from Wohox/pingtian/fix_comments_2201
lhb8125 Dec 8, 2025
12a2a22
fix ut
lhb8125 Dec 8, 2025
13 changes: 13 additions & 0 deletions megatron/core/model_parallel_config.py
@@ -265,6 +265,19 @@ class ModelParallelConfig:
delay_wgrad_compute: bool = False
"""Delay the weight gradient computation to improve batch-level communication overlapping"""

ep_overlap_early_attn_memory_release: bool = False
"""Enable early memory release of attention activations during EP overlap.
EP overlap can increase peak memory usage when the overlapped forward module allocates
more memory than is freed by the backward module. This flag mitigates the issue by
reordering the attention backward pass to run earlier in the schedule.
Specifically:
- Without this flag: attn_bwd executes after moe_combine_fwd
- With this flag: attn_bwd executes before mlp_fwd
The earlier execution releases attention activations sooner, reducing peak memory.
Note: This may impact performance as moe_combine_fwd and moe_dispatch_bwd become
exposed (not overlapped with other computation).
"""

###################
# Pipeline Parallel
###################
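Note: the docstring above describes the reordering in prose. Below is a minimal sketch of the per-layer 1F1B step it refers to; the method names (attn_bwd, mlp_fwd, ...) are illustrative shorthands rather than the real TransformerLayerSchedulePlan API, which appears in model_chunk_schedule_plan.py further down.

```python
# Illustrative sketch only: how ep_overlap_early_attn_memory_release changes
# the interleaving of one backward layer (b) and one forward layer (f).
# Method names are hypothetical shorthands for the real schedule nodes.
def overlapped_layer_step(f, b, early_attn_release: bool):
    b.moe_combine_bwd()
    b.mlp_bwd()
    b.moe_dispatch_bwd()

    if early_attn_release:
        # Attention backward runs before the forward MLP, so its activations
        # are freed earlier and peak memory drops; per the note above,
        # moe_combine_fwd and moe_dispatch_bwd may then run exposed.
        b.post_attn_bwd()
        b.attn_bwd()

    f.mlp_fwd()
    f.moe_combine_fwd()

    if not early_attn_release:
        # Default ordering: attention backward only after moe_combine_fwd.
        b.post_attn_bwd()
        b.attn_bwd()
```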
63 changes: 54 additions & 9 deletions megatron/core/models/common/model_chunk_schedule_plan.py
@@ -77,6 +77,7 @@ def __init__(self, layer, event, chunk_state, comp_stream, comm_stream, extra_ar
"""
from megatron.core.models.gpt.fine_grained_callables import TransformerLayerState

self.config = layer.config
self.layer_state = TransformerLayerState()
self.chunk_state = chunk_state
self.layer = layer
@@ -87,6 +88,32 @@ def __init__(self, layer, event, chunk_state, comp_stream, comm_stream, extra_ar
# get callable nodes for transformer/mtp layer
self._build_callable_nodes(event, comp_stream, comm_stream, extra_args)

def release_state(self):
"""Release reference, this helps avoid memory leak."""
if hasattr(self, 'attn') and self.attn is not None:
del self.attn
self.attn = None
if hasattr(self, 'post_attn') and self.post_attn is not None:
del self.post_attn
self.post_attn = None
if hasattr(self, 'moe_dispatch') and self.moe_dispatch is not None:
del self.moe_dispatch
self.moe_dispatch = None
if hasattr(self, 'mlp') and self.mlp is not None:
del self.mlp
self.mlp = None
if hasattr(self, 'moe_combine') and self.moe_combine is not None:
del self.moe_combine
self.moe_combine = None
if hasattr(self, 'mtp_post_process') and self.mtp_post_process is not None:
del self.mtp_post_process
self.mtp_post_process = None
if hasattr(self, 'layer_state') and self.layer_state is not None:
del self.layer_state
self.layer_state = None
if hasattr(self, 'layer'):
del self.layer

def _build_callable_nodes(self, event, comp_stream, comm_stream, extra_args):
"""
Builds the callable nodes for the transformer/mtp layer:
@@ -114,7 +141,12 @@ def _build_callable_nodes(self, event, comp_stream, comm_stream, extra_args):
self.layer.config.moe_token_dispatcher_type == "flex"
and self.layer.config.moe_flex_dispatcher_backend == "deepep"
)
enable_hybridep = (
self.layer.config.moe_token_dispatcher_type == "flex"
and self.layer.config.moe_flex_dispatcher_backend == "hybridep"
)
extra_args["enable_deepep"] = enable_deepep
extra_args["enable_hybridep"] = enable_hybridep
extra_args["is_moe"] = is_moe
extra_args["delay_wgrad_compute"] = self.layer.config.delay_wgrad_compute
extra_args["is_mtp"] = is_mtp
@@ -221,6 +253,10 @@ def run(f_layer, b_layer, f_input=None, b_grad=None, is_last_layer_in_bwd=False)
b_layer.mlp.backward_dw()
b_grad = b_layer.moe_dispatch.backward(b_grad)

if b_layer is not None and b_layer.config.ep_overlap_early_attn_memory_release:
b_grad = b_layer.post_attn.backward(b_grad)
b_grad = b_layer.attn.backward(b_grad)

if f_layer is not None:
with f_layer.get_fp8_context():
f_input = f_layer.mlp.forward(f_input)
@@ -230,7 +266,7 @@ def run(f_layer, b_layer, f_input=None, b_grad=None, is_last_layer_in_bwd=False)
f_input = f_layer.moe_combine.forward(f_input)
f_input = f_layer.mtp_post_process.forward(f_input)

if b_layer is not None:
if b_layer is not None and not b_layer.config.ep_overlap_early_attn_memory_release:
b_grad = b_layer.post_attn.backward(b_grad)
b_grad = b_layer.attn.backward(b_grad)

@@ -372,6 +408,10 @@ def get_layer(self, i):
assert i < self.num_layers()
return self._transformer_layers[i]

def pop_layer(self):
"""Pops the transformer layer in FILO order."""
return self._transformer_layers.pop()

def num_layers(self):
"""Gets the number of transformer layers."""
return len(self._transformer_layers)
@@ -450,29 +490,34 @@ def run(
b_num_layers = b_schedule_plan.num_layers() if b_schedule_plan is not None else 0
overlapped_layers = min(f_num_layers, b_num_layers)

f_layer = b_layer = None
# combined forward and backward pass for overlapped layers
for i in range(overlapped_layers):
f_layer = f_schedule_plan.get_layer(i)
b_layer = b_schedule_plan.get_layer(b_num_layers - 1 - i)
torch.cuda.nvtx.range_push(f"layer_{i}f-layer_{b_num_layers - 1 - i}b")
if f_layer.layer.config.fine_grained_activation_offloading:
fine_grained_offloading_set_last_layer(i == f_num_layers - 1)
b_layer = b_schedule_plan.pop_layer()
torch.cuda.nvtx.range_push(f"layer_{i}f-layer_{b_schedule_plan.num_layers()}b")
f_input, b_grad = TransformerLayerSchedulePlan.run(
f_layer,
b_layer,
f_input=f_input,
b_grad=b_grad,
is_last_layer_in_bwd=(i == b_num_layers - 1),
)
if i < b_num_layers - 1:
b_layer.release_state()
torch.cuda.nvtx.range_pop()

# backward pass for the remaining layers
for i in range(overlapped_layers, b_num_layers):
b_layer = b_schedule_plan.get_layer(b_num_layers - 1 - i)
torch.cuda.nvtx.range_push(f"layer_{b_num_layers - 1 - i}b")
b_layer = b_schedule_plan.pop_layer()
torch.cuda.nvtx.range_push(f"layer_{b_schedule_plan.num_layers()}b")
_, b_grad = TransformerLayerSchedulePlan.run(
None, b_layer, b_grad=b_grad, is_last_layer_in_bwd=(i == b_num_layers - 1)
)
if i < b_num_layers - 1:
b_layer.release_state()
torch.cuda.nvtx.range_pop()

# forward pass for the remaining layers
@@ -500,7 +545,9 @@ def run(
# Delay the last attn_dw in backward pass (attn_dw of the first layer)
# for overlapping with the p2p comm
if b_num_layers > 0:
b_schedule_plan.get_layer(0).attn.backward_dw()
assert b_layer is not None
b_layer.attn.backward_dw()
b_layer.release_state()

# post process forward
if f_schedule_plan is not None and f_schedule_plan.post_process is not None:
@@ -513,9 +560,7 @@ def run(
f_schedule_plan.wait_current_stream()
if b_schedule_plan:
b_schedule_plan.wait_current_stream()

# Release reference as early as possible, this helps avoid memory leak.
if b_schedule_plan is not None:
# Release reference as early as possible, this helps avoid memory leak.
b_schedule_plan.release_state()

return f_input
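Note: a condensed sketch of the pop-and-release pattern introduced above, assuming only plain Python; ChunkPlanSketch and run_backward_sketch are hypothetical stand-ins, not the Megatron classes.

```python
# Backward consumes layers in FILO order via pop(), and each layer's
# references are released right after its backward finishes instead of
# waiting for the whole chunk to complete.
class ChunkPlanSketch:
    def __init__(self, layers):
        self._layers = list(layers)

    def num_layers(self):
        return len(self._layers)

    def pop_layer(self):
        # last-in, first-out: backward starts from the deepest layer
        return self._layers.pop()


def run_backward_sketch(plan, b_grad):
    last_layer = None
    n = plan.num_layers()
    for i in range(n):
        layer = plan.pop_layer()
        b_grad = layer.backward(b_grad)
        if i < n - 1:
            # release per-layer state early; the final layer (layer 0) is kept
            # alive so its delayed attention wgrad can still run afterwards
            layer.release_state()
        else:
            last_layer = layer
    return b_grad, last_layer
```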
48 changes: 35 additions & 13 deletions megatron/core/models/gpt/fine_grained_callables.py
@@ -21,6 +21,7 @@
get_mtp_layer_offset,
)
from megatron.core.transformer.transformer_layer import TransformerLayer, make_viewless_tensor
from megatron.core.utils import internal_api


def weak_method(method):
@@ -40,13 +41,15 @@ def wrapped_func(*args, **kwarg):
return wrapped_func


def should_free_input(name, is_moe, is_deepep):
@internal_api
def should_free_input(name, is_moe, enable_deepep, enable_hybridep):
"""Determine if the node should free its input memory.

Args:
name: Node name
is_moe: Whether it's a MoE model
is_deepep: Whether it's a DeepEP model
enable_deepep: Whether to use DeepEP dispatcher
enable_hybridep: Whether to use HybridEP dispatcher

Returns:
bool: Whether to free input memory
@@ -60,12 +63,13 @@ def should_free_input(name, is_moe, is_deepep):
# The input and output of A2A are not needed anymore after the forward pass,
# so we can free the input memory after the forward pass.
free_input_nodes = {
"mlp": True,
"mlp": not enable_hybridep,
"moe_combine": True,
# For non-deepep mode, the input is the un-dispatched tokens and probs before dispatch A2A
# and it's not needed anymore after the forward pass
# For deepep mode, they are both needed in backward pass, so they cannot be freed.
"moe_dispatch": not is_deepep,
# For non-DeepEP and non-HybridEP dispatcher mode, the input is the un-dispatched tokens
# and probs before dispatch A2A and it's not needed anymore after the forward pass
# For DeepEP and HybridEP dispatcher mode, they are both needed in backward pass
# and cannot be freed.
"moe_dispatch": not (enable_deepep or enable_hybridep),
}

return free_input_nodes.get(name, False)
@@ -223,12 +227,13 @@ def __init__(
it's the per_batch_state_context, o.w. nullcontext
name (str): Node name, also used to determine memory strategy
bwd_dw_callables (list): List of weight gradient functions for the layer.
extra_args (dict): Extra arguments for the node: is_moe, enable_deepep.
extra_args (dict): Extra arguments for nodes: is_moe, enable_deepep, enable_hybridep.
"""
# determine whether to free input memory
is_moe = extra_args.get("is_moe", False)
enable_deepep = extra_args.get("enable_deepep", False)
free_input = should_free_input(name, is_moe, enable_deepep)
enable_hybridep = extra_args.get("enable_hybridep", False)
free_input = should_free_input(name, is_moe, enable_deepep, enable_hybridep)
self.delay_wgrad_compute = extra_args.get("delay_wgrad_compute", False)

super().__init__(
@@ -274,7 +279,13 @@ def backward_impl(self, outputs, output_grad):
detached_grad = tuple([e.grad for e in self.detached])
grads = output_grad + detached_grad
self.default_backward_func(outputs + self.before_detached, grads)
self._release_state()
# release the output grad memory after backward finishes,
# except when delay_wgrad_compute is enabled; then the grads must be
# kept until all modules' backward_dw has been invoked.
if self.delay_wgrad_compute:
self.output_grads = grads
self.delay_grads_release = len(self.bwd_dw_callables) > 0

# return grads for record stream
return grads

@@ -285,9 +296,16 @@ def backward_dw(self):
with torch.cuda.nvtx.range(f"{self.name} wgrad"):
for module in self.bwd_dw_callables:
module.backward_dw()

# the output grad memory is last used in wgrad compute, should be safe to release.
assert self.delay_grads_release, "output grad memory should be valid before wgrad."
for tensor in self.output_grads:
tensor.untyped_storage().resize_(0)
self.output_grads = None

self.bwd_dw_callables = None

def _release_state(self):
def __del__(self):
# Release reference as early as possible, this helps avoid memory leak.
self.before_detached = None
self.detached = None
@@ -328,6 +346,10 @@ def build_transformer_layer_callables(layer: TransformerLayer):
layer.config.moe_token_dispatcher_type == "flex"
and layer.config.moe_flex_dispatcher_backend == "deepep"
)
enable_hybridep = (
layer.config.moe_token_dispatcher_type == "flex"
and layer.config.moe_flex_dispatcher_backend == "hybridep"
)

def submodule_attn_forward(node: ScheduleNode, hidden_states: torch.Tensor):
"""
@@ -379,7 +401,7 @@ def submodule_dispatch_forward(
Dispatches tokens to the experts based on the router output.
"""
token_dispatcher = layer.mlp.token_dispatcher
if enable_deepep:
if enable_deepep or enable_hybridep:
# update token_probs to be the detached version, prevents
# backward graph from connecting to attn submodule
token_dispatcher._comm_manager.token_probs = probs
@@ -396,7 +418,7 @@ def submodule_moe_forward(node: ScheduleNode, dispatched_tokens: torch.Tensor):
shared_expert_output = None
dispatched_probs = node.layer_state.dispatched_probs
token_dispatcher = layer.mlp.token_dispatcher
if enable_deepep:
if enable_deepep or enable_hybridep:
# update dispatched_probs to be detached version, prevents
# backward graph from connecting to dispatch submodule
token_dispatcher._comm_manager.dispatched_probs = dispatched_probs
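Note: a usage sketch for the updated should_free_input, assuming the post-PR signature shown above; the SimpleNamespace config is a hypothetical stand-in carrying only the two dispatcher fields the flags are derived from.

```python
# Derive the dispatcher flags the same way _build_callable_nodes does,
# then query the per-node free-input decision.
from types import SimpleNamespace

from megatron.core.models.gpt.fine_grained_callables import should_free_input

config = SimpleNamespace(
    moe_token_dispatcher_type="flex",
    moe_flex_dispatcher_backend="hybridep",
)
enable_deepep = (
    config.moe_token_dispatcher_type == "flex"
    and config.moe_flex_dispatcher_backend == "deepep"
)
enable_hybridep = (
    config.moe_token_dispatcher_type == "flex"
    and config.moe_flex_dispatcher_backend == "hybridep"
)

for name in ("attn", "mlp", "moe_dispatch", "moe_combine"):
    print(name, should_free_input(name, True, enable_deepep, enable_hybridep))
# With hybridep the table above yields: mlp False, moe_dispatch False,
# moe_combine True; names not in the table default to False unless handled
# earlier in the function.
```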
7 changes: 7 additions & 0 deletions megatron/core/pipeline_parallel/utils.py
@@ -182,6 +182,7 @@ def __init__(
self.free_input = free_input
self.inputs = None
self.outputs = None
self.delay_grads_release = False

def default_backward_func(self, outputs, output_grad):
"""Default backward function"""
Expand Down Expand Up @@ -263,6 +264,12 @@ def _backward(self, *output_grad):
for g in output_grad:
if g is not None:
g.record_stream(self.stream)
# Manually trigger the memory release of the dgrad tensor
# to avoid delayed garbage collection. If
# delay_grads_release is True, dgrad is last used in
# wgrad compute, so the release is skipped here.
if not self.delay_grads_release:
g.untyped_storage().resize_(0)
Contributor: Could you add some explanation here?

Contributor: Fixed in lhb8125#50, @lhb8125 can you help take a look~

Contributor (Author): Merged, thanks!

grads = self.get_grad()
self._release_state()
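Note: the storage-resize trick used in _backward can be illustrated with plain PyTorch; the snippet below is a standalone sketch of the behaviour, not project code.

```python
# Resizing a tensor's untyped storage to zero returns its memory to the
# allocator immediately, without waiting for Python garbage collection,
# while the tensor object itself survives for bookkeeping.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
g = torch.randn(1024, 1024, device=device)
print(g.untyped_storage().size())  # > 0: backing storage is allocated

g.untyped_storage().resize_(0)     # free the backing memory now
print(g.untyped_storage().size())  # 0: storage released

# After this point the data in g must not be read again, which is why
# delay_grads_release keeps the grads alive until backward_dw has used them.
```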
11 changes: 11 additions & 0 deletions megatron/core/transformer/transformer_config.py
@@ -1843,6 +1843,11 @@ def __post_init__(self):
assert (
self.mtp_num_layers is None or self.mtp_num_layers == 1
), 'MTP layernum only supports 1 when enabling overlap_moe_expert_parallel_comm.'
if self.mtp_num_layers == 1:
assert self.pipeline_model_parallel_size > 1, (
'Pipeline model parallel size must be larger than 1 '
'when enabling overlap_moe_expert_parallel_comm with MTP layer.'
)

# Check delay_wgrad_compute compatibility
if self.delay_wgrad_compute:
Expand All @@ -1853,6 +1858,12 @@ def __post_init__(self):
not self.moe_use_legacy_grouped_gemm
), 'delay_wgrad_compute is not supported with legacy groupedgemm implementation'

if self.ep_overlap_early_attn_memory_release:
assert self.overlap_moe_expert_parallel_comm, (
'overlap_moe_expert_parallel_comm must be enabled when enabling '
'ep_overlap_early_attn_memory_release'
)

if self.context_parallel_size > 1 and self.cp_comm_type is not None:
if isinstance(self.cp_comm_type, list):
assert len(self.cp_comm_type) == self.num_layers, (
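Note: a hedged example of flag values that satisfy the compatibility checks above; this is an illustrative subset of keyword arguments, not a complete TransformerConfig.

```python
# Flag combination consistent with the __post_init__ asserts above.
overlap_flags = dict(
    overlap_moe_expert_parallel_comm=True,      # enables 1F1B EP A2A overlap
    delay_wgrad_compute=True,
    ep_overlap_early_attn_memory_release=True,  # requires overlap_moe_expert_parallel_comm
    mtp_num_layers=1,                           # only one MTP layer is supported with EP overlap
    pipeline_model_parallel_size=2,             # PP > 1 required when an MTP layer is enabled
)
```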
2 changes: 2 additions & 0 deletions megatron/training/arguments.py
@@ -3348,6 +3348,8 @@ def _add_moe_args(parser):
help='Overlap the EP A2A communication by batch-level overlapping in 1f1b stage.')
group.add_argument('--delay-wgrad-compute', action='store_true',
help='Delay the wgrad compute for batch-level overlapping')
group.add_argument('--ep-overlap-early-attn-memory-release', action='store_true',
help='Release the memory of the attention module early in EP overlap.')

group.add_argument('--moe-upcycling-granularity', type=int, default=1,
help='This param specifies how many times smaller the expert hidden size is compared with the original dense FFN hidden size. '
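Note: a standalone argparse sketch of how the new switch surfaces on the command line and maps onto the config field; this is not Megatron's full argument parser.

```python
# The hyphenated flag added above becomes
# args.ep_overlap_early_attn_memory_release, which feeds
# config.ep_overlap_early_attn_memory_release.
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='moe')
group.add_argument('--delay-wgrad-compute', action='store_true',
                   help='Delay the wgrad compute for batch-level overlapping')
group.add_argument('--ep-overlap-early-attn-memory-release', action='store_true',
                   help='Release the memory of the attention module early in EP overlap.')

args = parser.parse_args(['--ep-overlap-early-attn-memory-release'])
assert args.ep_overlap_early_attn_memory_release
assert not args.delay_wgrad_compute
```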