@@ -75,6 +75,7 @@ def __init__(self, layer, event, chunk_state, comp_stream, comm_stream, extra_ar
         """
         from megatron.core.models.gpt.fine_grained_callables import TransformerLayerState

+        self.config = layer.config
         self.layer_state = TransformerLayerState()
         self.chunk_state = chunk_state
         self.layer = layer
@@ -85,6 +86,32 @@ def __init__(self, layer, event, chunk_state, comp_stream, comm_stream, extra_ar
         # get callable nodes for transformer/mtp layer
         self._build_callable_nodes(event, comp_stream, comm_stream, extra_args)

+    def release_state(self):
+        """Release references; this helps avoid memory leaks."""
+        if hasattr(self, 'attn') and self.attn is not None:
+            del self.attn
+            self.attn = None
+        if hasattr(self, 'post_attn') and self.post_attn is not None:
+            del self.post_attn
+            self.post_attn = None
+        if hasattr(self, 'moe_dispatch') and self.moe_dispatch is not None:
+            del self.moe_dispatch
+            self.moe_dispatch = None
+        if hasattr(self, 'mlp') and self.mlp is not None:
+            del self.mlp
+            self.mlp = None
+        if hasattr(self, 'moe_combine') and self.moe_combine is not None:
+            del self.moe_combine
+            self.moe_combine = None
+        if hasattr(self, 'mtp_post_process') and self.mtp_post_process is not None:
+            del self.mtp_post_process
+            self.mtp_post_process = None
+        if hasattr(self, 'layer_state') and self.layer_state is not None:
+            del self.layer_state
+            self.layer_state = None
+        if hasattr(self, 'layer'):
+            del self.layer
+
     def _build_callable_nodes(self, event, comp_stream, comm_stream, extra_args):
         """
         Builds the callable nodes for the transformer/mtp layer:
@@ -112,7 +139,12 @@ def _build_callable_nodes(self, event, comp_stream, comm_stream, extra_args):
             self.layer.config.moe_token_dispatcher_type == "flex"
             and self.layer.config.moe_flex_dispatcher_backend == "deepep"
         )
+        enable_hybridep = (
+            self.layer.config.moe_token_dispatcher_type == "flex"
+            and self.layer.config.moe_flex_dispatcher_backend == "hybridep"
+        )
         extra_args["enable_deepep"] = enable_deepep
+        extra_args["enable_hybridep"] = enable_hybridep
         extra_args["is_moe"] = is_moe
         extra_args["delay_wgrad_compute"] = self.layer.config.delay_wgrad_compute
         extra_args["is_mtp"] = is_mtp
@@ -219,6 +251,10 @@ def run(f_layer, b_layer, f_input=None, b_grad=None, is_last_layer_in_bwd=False)
             b_layer.mlp.backward_dw()
             b_grad = b_layer.moe_dispatch.backward(b_grad)

+        if b_layer is not None and b_layer.config.ep_overlap_early_attn_memory_release:
+            b_grad = b_layer.post_attn.backward(b_grad)
+            b_grad = b_layer.attn.backward(b_grad)
+
         if f_layer is not None:
             with f_layer.get_fp8_context():
                 f_input = f_layer.mlp.forward(f_input)
@@ -228,7 +264,7 @@ def run(f_layer, b_layer, f_input=None, b_grad=None, is_last_layer_in_bwd=False)
             f_input = f_layer.moe_combine.forward(f_input)
             f_input = f_layer.mtp_post_process.forward(f_input)

-        if b_layer is not None:
+        if b_layer is not None and not b_layer.config.ep_overlap_early_attn_memory_release:
             b_grad = b_layer.post_attn.backward(b_grad)
             b_grad = b_layer.attn.backward(b_grad)

@@ -367,6 +403,10 @@ def get_layer(self, i):
         assert i < self.num_layers()
         return self._transformer_layers[i]

+    def pop_layer(self):
+        """Pops the transformer layer in FILO order."""
+        return self._transformer_layers.pop()
+
     def num_layers(self):
         """Gets the number of transformer layers."""
         return len(self._transformer_layers)
@@ -445,27 +485,32 @@ def run(
         b_num_layers = b_schedule_plan.num_layers() if b_schedule_plan is not None else 0
         overlapped_layers = min(f_num_layers, b_num_layers)

+        f_layer = b_layer = None
         # combined forward and backward pass for overlapped layers
         for i in range(overlapped_layers):
             f_layer = f_schedule_plan.get_layer(i)
-            b_layer = b_schedule_plan.get_layer(b_num_layers - 1 - i)
-            torch.cuda.nvtx.range_push(f"layer_{i}f-layer_{b_num_layers - 1 - i}b")
+            b_layer = b_schedule_plan.pop_layer()
+            torch.cuda.nvtx.range_push(f"layer_{i}f-layer_{b_schedule_plan.num_layers()}b")
             f_input, b_grad = TransformerLayerSchedulePlan.run(
                 f_layer,
                 b_layer,
                 f_input=f_input,
                 b_grad=b_grad,
                 is_last_layer_in_bwd=(i == b_num_layers - 1),
             )
+            if i < b_num_layers - 1:
+                b_layer.release_state()
             torch.cuda.nvtx.range_pop()

         # backward pass for the remaining layers
         for i in range(overlapped_layers, b_num_layers):
-            b_layer = b_schedule_plan.get_layer(b_num_layers - 1 - i)
-            torch.cuda.nvtx.range_push(f"layer_{b_num_layers - 1 - i}b")
+            b_layer = b_schedule_plan.pop_layer()
+            torch.cuda.nvtx.range_push(f"layer_{b_schedule_plan.num_layers()}b")
             _, b_grad = TransformerLayerSchedulePlan.run(
                 None, b_layer, b_grad=b_grad, is_last_layer_in_bwd=(i == b_num_layers - 1)
             )
+            if i < b_num_layers - 1:
+                b_layer.release_state()
             torch.cuda.nvtx.range_pop()

         # forward pass for the remaining layers
@@ -491,7 +536,9 @@ def run(
         # Delay the last attn_dw in backward pass (attn_dw of the first layer)
         # for overlapping with the p2p comm
         if b_num_layers > 0:
-            b_schedule_plan.get_layer(0).attn.backward_dw()
+            assert b_layer is not None
+            b_layer.attn.backward_dw()
+            b_layer.release_state()

         # post process forward
         if f_schedule_plan is not None and f_schedule_plan.post_process is not None:
@@ -504,9 +551,7 @@ def run(
             f_schedule_plan.wait_current_stream()
         if b_schedule_plan:
             b_schedule_plan.wait_current_stream()
-
-        # Release reference as early as possible, this helps avoid memory leak.
-        if b_schedule_plan is not None:
+            # Release reference as early as possible, this helps avoid memory leak.
             b_schedule_plan.release_state()

         return f_input
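For readers unfamiliar with the pattern this change introduces, the sketch below isolates the pop-and-release idea: the backward loop pops layers in FILO order and releases each finished layer's references immediately, so its cached state can be freed before the loop ends, while the last popped layer is kept alive for delayed weight-gradient work. This is a minimal, hypothetical example; `_Layer`, `_Plan`, and `run_backward` are made-up names and do not reflect the actual Megatron-Core classes or this PR's code.

```python
"""Illustrative sketch only; names here are assumptions, not Megatron-Core APIs."""


class _Layer:
    """Stand-in for a per-layer schedule plan holding cached state."""

    def __init__(self, idx):
        self.idx = idx
        self.activation_cache = bytearray(1 << 20)  # placeholder for saved tensors

    def backward(self, grad):
        # Pretend to consume the cached activations for this layer.
        return grad

    def release_state(self):
        # Drop references so the per-layer buffers become collectible right away.
        self.activation_cache = None


class _Plan:
    """Stand-in for a chunk schedule plan that owns its layers."""

    def __init__(self, num_layers):
        self._layers = [_Layer(i) for i in range(num_layers)]

    def pop_layer(self):
        # Backward visits layers last-to-first, so popping from the end (FILO)
        # hands out layers in exactly the order backward needs them.
        return self._layers.pop()

    def num_layers(self):
        return len(self._layers)


def run_backward(plan, grad):
    num = plan.num_layers()
    last = None
    for i in range(num):
        layer = plan.pop_layer()  # FILO: deepest layer first
        grad = layer.backward(grad)
        if i < num - 1:
            # Release every finished layer except the last one popped, whose
            # delayed weight-gradient work may still need to run after the loop.
            layer.release_state()
        last = layer
    return grad, last


if __name__ == "__main__":
    plan = _Plan(4)
    grad, first_layer = run_backward(plan, grad=object())
    # ... delayed work for the remaining layer would go here ...
    first_layer.release_state()
```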