@@ -77,6 +77,7 @@ def __init__(self, layer, event, chunk_state, comp_stream, comm_stream, extra_ar
         """
         from megatron.core.models.gpt.fine_grained_callables import TransformerLayerState

+        self.config = layer.config
         self.layer_state = TransformerLayerState()
         self.chunk_state = chunk_state
         self.layer = layer
@@ -87,6 +88,32 @@ def __init__(self, layer, event, chunk_state, comp_stream, comm_stream, extra_ar
         # get callable nodes for transformer/mtp layer
         self._build_callable_nodes(event, comp_stream, comm_stream, extra_args)

+    def release_state(self):
+        """Release references; this helps avoid memory leaks."""
+        if hasattr(self, 'attn') and self.attn is not None:
+            del self.attn
+            self.attn = None
+        if hasattr(self, 'post_attn') and self.post_attn is not None:
+            del self.post_attn
+            self.post_attn = None
+        if hasattr(self, 'moe_dispatch') and self.moe_dispatch is not None:
+            del self.moe_dispatch
+            self.moe_dispatch = None
+        if hasattr(self, 'mlp') and self.mlp is not None:
+            del self.mlp
+            self.mlp = None
+        if hasattr(self, 'moe_combine') and self.moe_combine is not None:
+            del self.moe_combine
+            self.moe_combine = None
+        if hasattr(self, 'mtp_post_process') and self.mtp_post_process is not None:
+            del self.mtp_post_process
+            self.mtp_post_process = None
+        if hasattr(self, 'layer_state') and self.layer_state is not None:
+            del self.layer_state
+            self.layer_state = None
+        if hasattr(self, 'layer'):
+            del self.layer
+
     def _build_callable_nodes(self, event, comp_stream, comm_stream, extra_args):
         """
         Builds the callable nodes for the transformer/mtp layer:
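The intent of release_state() is that once a layer's backward work has finished, the schedule plan stops holding references to the layer's callable nodes and to the transformer layer itself, so their saved activations can be reclaimed. A minimal sketch of the resulting object state, assuming `layer_plan` is a TransformerLayerSchedulePlan whose backward pass has already run (the variable name is illustrative):

layer_plan.release_state()

# Callable-node attributes are cleared to None ...
assert layer_plan.attn is None and layer_plan.mlp is None
assert layer_plan.layer_state is None
# ... and the back-reference to the transformer layer is removed entirely,
# so the schedule plan no longer keeps the layer's activations alive.
assert not hasattr(layer_plan, 'layer')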
@@ -114,7 +141,12 @@ def _build_callable_nodes(self, event, comp_stream, comm_stream, extra_args):
             self.layer.config.moe_token_dispatcher_type == "flex"
             and self.layer.config.moe_flex_dispatcher_backend == "deepep"
         )
+        enable_hybridep = (
+            self.layer.config.moe_token_dispatcher_type == "flex"
+            and self.layer.config.moe_flex_dispatcher_backend == "hybridep"
+        )
         extra_args["enable_deepep"] = enable_deepep
+        extra_args["enable_hybridep"] = enable_hybridep
         extra_args["is_moe"] = is_moe
         extra_args["delay_wgrad_compute"] = self.layer.config.delay_wgrad_compute
         extra_args["is_mtp"] = is_mtp
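The hybridep detection mirrors the existing deepep check: both require the flex token dispatcher, and the configured backend string decides which single flag is set in extra_args. A hedged sketch of the possible outcomes, using only the two config fields read above (the helper function is illustrative, not part of the patch):

def dispatcher_flags(config):
    # Returns (enable_deepep, enable_hybridep) exactly as computed above.
    is_flex = config.moe_token_dispatcher_type == "flex"
    return (
        is_flex and config.moe_flex_dispatcher_backend == "deepep",
        is_flex and config.moe_flex_dispatcher_backend == "hybridep",
    )

# flex + "deepep"            -> (True, False)
# flex + "hybridep"          -> (False, True)
# any other dispatcher type  -> (False, False), regardless of backend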
@@ -221,6 +253,10 @@ def run(f_layer, b_layer, f_input=None, b_grad=None, is_last_layer_in_bwd=False)
                 b_layer.mlp.backward_dw()
             b_grad = b_layer.moe_dispatch.backward(b_grad)

+        if b_layer is not None and b_layer.config.ep_overlap_early_attn_memory_release:
+            b_grad = b_layer.post_attn.backward(b_grad)
+            b_grad = b_layer.attn.backward(b_grad)
+
         if f_layer is not None:
             with f_layer.get_fp8_context():
                 f_input = f_layer.mlp.forward(f_input)
@@ -230,7 +266,7 @@ def run(f_layer, b_layer, f_input=None, b_grad=None, is_last_layer_in_bwd=False)
             f_input = f_layer.moe_combine.forward(f_input)
             f_input = f_layer.mtp_post_process.forward(f_input)

-        if b_layer is not None:
+        if b_layer is not None and not b_layer.config.ep_overlap_early_attn_memory_release:
             b_grad = b_layer.post_attn.backward(b_grad)
             b_grad = b_layer.attn.backward(b_grad)

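Together, these two hunks move the backward of post_attn/attn ahead of the paired layer's forward MLP and combine when ep_overlap_early_attn_memory_release is set, so attention activations can be freed earlier; otherwise the original ordering is kept. A simplified sketch of the resulting control flow inside TransformerLayerSchedulePlan.run (fp8 contexts, wgrad handling, and mtp_post_process omitted):

early = b_layer is not None and b_layer.config.ep_overlap_early_attn_memory_release

if early:
    # Early path: attention backward runs before the paired forward MLP,
    # releasing its activation memory sooner.
    b_grad = b_layer.post_attn.backward(b_grad)
    b_grad = b_layer.attn.backward(b_grad)

if f_layer is not None:
    f_input = f_layer.mlp.forward(f_input)
    f_input = f_layer.moe_combine.forward(f_input)

if b_layer is not None and not early:
    # Default path: attention backward stays after the paired forward work.
    b_grad = b_layer.post_attn.backward(b_grad)
    b_grad = b_layer.attn.backward(b_grad)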
@@ -372,6 +408,10 @@ def get_layer(self, i):
         assert i < self.num_layers()
         return self._transformer_layers[i]

+    def pop_layer(self):
+        """Pops a transformer layer in FILO order."""
+        return self._transformer_layers.pop()
+
     def num_layers(self):
         """Gets the number of transformer layers."""
         return len(self._transformer_layers)
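pop_layer() removes and returns the most recently stored layer, which is exactly the order the backward pass consumes them; unlike get_layer(), it also drops the plan's own reference so the popped layer can later be released. An illustrative equivalence, assuming layers were appended in forward order:

# On backward iteration i, popping yields the same layer that
# get_layer(b_num_layers - 1 - i) used to return:
for i in range(b_num_layers):
    b_layer = b_schedule_plan.pop_layer()
    # Equivalent to b_layer = b_schedule_plan.get_layer(b_num_layers - 1 - i),
    # except the plan no longer holds a reference to it afterwards.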
@@ -450,29 +490,34 @@ def run(
         b_num_layers = b_schedule_plan.num_layers() if b_schedule_plan is not None else 0
         overlapped_layers = min(f_num_layers, b_num_layers)

+        f_layer = b_layer = None
         # combined forward and backward pass for overlapped layers
         for i in range(overlapped_layers):
             f_layer = f_schedule_plan.get_layer(i)
-            b_layer = b_schedule_plan.get_layer(b_num_layers - 1 - i)
-            torch.cuda.nvtx.range_push(f"layer_{i}f-layer_{b_num_layers - 1 - i}b")
             if f_layer.layer.config.fine_grained_activation_offloading:
                 fine_grained_offloading_set_last_layer(i == f_num_layers - 1)
+            b_layer = b_schedule_plan.pop_layer()
+            torch.cuda.nvtx.range_push(f"layer_{i}f-layer_{b_schedule_plan.num_layers()}b")
             f_input, b_grad = TransformerLayerSchedulePlan.run(
                 f_layer,
                 b_layer,
                 f_input=f_input,
                 b_grad=b_grad,
                 is_last_layer_in_bwd=(i == b_num_layers - 1),
             )
+            if i < b_num_layers - 1:
+                b_layer.release_state()
             torch.cuda.nvtx.range_pop()

         # backward pass for the remaining layers
         for i in range(overlapped_layers, b_num_layers):
-            b_layer = b_schedule_plan.get_layer(b_num_layers - 1 - i)
-            torch.cuda.nvtx.range_push(f"layer_{b_num_layers - 1 - i}b")
+            b_layer = b_schedule_plan.pop_layer()
+            torch.cuda.nvtx.range_push(f"layer_{b_schedule_plan.num_layers()}b")
             _, b_grad = TransformerLayerSchedulePlan.run(
                 None, b_layer, b_grad=b_grad, is_last_layer_in_bwd=(i == b_num_layers - 1)
             )
+            if i < b_num_layers - 1:
+                b_layer.release_state()
             torch.cuda.nvtx.range_pop()

         # forward pass for the remaining layers
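With pop_layer() and release_state(), each backward layer in the two loops above lives only for its own iteration; only the last backward layer is kept, because its attention wgrad is deferred (see the next hunk) to overlap with the p2p communication. A simplified sketch of the per-iteration lifetime, reusing the names from the loops above:

b_layer = b_schedule_plan.pop_layer()          # the plan drops its reference first
_, b_grad = TransformerLayerSchedulePlan.run(
    None, b_layer, b_grad=b_grad, is_last_layer_in_bwd=(i == b_num_layers - 1)
)
if i < b_num_layers - 1:
    b_layer.release_state()                    # node refs gone -> activations freeable
# For i == b_num_layers - 1, b_layer survives the loop; its delayed
# attn.backward_dw() and release_state() run after the remaining forward layers.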
@@ -500,7 +545,9 @@ def run(
         # Delay the last attn_dw in backward pass (attn_dw of the first layer)
         # for overlapping with the p2p comm
         if b_num_layers > 0:
-            b_schedule_plan.get_layer(0).attn.backward_dw()
+            assert b_layer is not None
+            b_layer.attn.backward_dw()
+            b_layer.release_state()

         # post process forward
         if f_schedule_plan is not None and f_schedule_plan.post_process is not None:
@@ -513,9 +560,7 @@ def run(
             f_schedule_plan.wait_current_stream()
         if b_schedule_plan:
             b_schedule_plan.wait_current_stream()
-
-        # Release reference as early as possible, this helps avoid memory leak.
-        if b_schedule_plan is not None:
+            # Release reference as early as possible, this helps avoid memory leak.
             b_schedule_plan.release_state()

         return f_input