@@ -77,7 +77,6 @@ def __init__(self, layer, event, chunk_state, comp_stream, comm_stream, extra_ar
         """
         from megatron.core.models.gpt.fine_grained_callables import TransformerLayerState

-        self.config = layer.config
         self.layer_state = TransformerLayerState()
         self.chunk_state = chunk_state
         self.layer = layer
@@ -88,32 +87,6 @@ def __init__(self, layer, event, chunk_state, comp_stream, comm_stream, extra_ar
         # get callable nodes for transformer/mtp layer
         self._build_callable_nodes(event, comp_stream, comm_stream, extra_args)

-    def release_state(self):
-        """Release reference, this helps avoid memory leak."""
-        if hasattr(self, 'attn') and self.attn is not None:
-            del self.attn
-            self.attn = None
-        if hasattr(self, 'post_attn') and self.post_attn is not None:
-            del self.post_attn
-            self.post_attn = None
-        if hasattr(self, 'moe_dispatch') and self.moe_dispatch is not None:
-            del self.moe_dispatch
-            self.moe_dispatch = None
-        if hasattr(self, 'mlp') and self.mlp is not None:
-            del self.mlp
-            self.mlp = None
-        if hasattr(self, 'moe_combine') and self.moe_combine is not None:
-            del self.moe_combine
-            self.moe_combine = None
-        if hasattr(self, 'mtp_post_process') and self.mtp_post_process is not None:
-            del self.mtp_post_process
-            self.mtp_post_process = None
-        if hasattr(self, 'layer_state') and self.layer_state is not None:
-            del self.layer_state
-            self.layer_state = None
-        if hasattr(self, 'layer'):
-            del self.layer
-
     def _build_callable_nodes(self, event, comp_stream, comm_stream, extra_args):
         """
         Builds the callable nodes for the transformer/mtp layer:
@@ -141,12 +114,7 @@ def _build_callable_nodes(self, event, comp_stream, comm_stream, extra_args):
             self.layer.config.moe_token_dispatcher_type == "flex"
             and self.layer.config.moe_flex_dispatcher_backend == "deepep"
         )
-        enable_hybridep = (
-            self.layer.config.moe_token_dispatcher_type == "flex"
-            and self.layer.config.moe_flex_dispatcher_backend == "hybridep"
-        )
         extra_args["enable_deepep"] = enable_deepep
-        extra_args["enable_hybridep"] = enable_hybridep
         extra_args["is_moe"] = is_moe
         extra_args["delay_wgrad_compute"] = self.layer.config.delay_wgrad_compute
         extra_args["is_mtp"] = is_mtp
@@ -253,10 +221,6 @@ def run(f_layer, b_layer, f_input=None, b_grad=None, is_last_layer_in_bwd=False)
            b_layer.mlp.backward_dw()
            b_grad = b_layer.moe_dispatch.backward(b_grad)

-        if b_layer is not None and b_layer.config.ep_overlap_early_attn_memory_release:
-            b_grad = b_layer.post_attn.backward(b_grad)
-            b_grad = b_layer.attn.backward(b_grad)
-
        if f_layer is not None:
            with f_layer.get_fp8_context():
                f_input = f_layer.mlp.forward(f_input)
@@ -266,7 +230,7 @@ def run(f_layer, b_layer, f_input=None, b_grad=None, is_last_layer_in_bwd=False)
                f_input = f_layer.moe_combine.forward(f_input)
                f_input = f_layer.mtp_post_process.forward(f_input)

-        if b_layer is not None and not b_layer.config.ep_overlap_early_attn_memory_release:
+        if b_layer is not None:
            b_grad = b_layer.post_attn.backward(b_grad)
            b_grad = b_layer.attn.backward(b_grad)
@@ -408,10 +372,6 @@ def get_layer(self, i):
        assert i < self.num_layers()
        return self._transformer_layers[i]

-    def pop_layer(self):
-        """Pops the transformer layer in FILO order."""
-        return self._transformer_layers.pop()
-
    def num_layers(self):
        """Gets the number of transformer layers."""
        return len(self._transformer_layers)
@@ -490,34 +450,29 @@ def run(
        b_num_layers = b_schedule_plan.num_layers() if b_schedule_plan is not None else 0
        overlapped_layers = min(f_num_layers, b_num_layers)

-        f_layer = b_layer = None
        # combined forward and backward pass for overlapped layers
        for i in range(overlapped_layers):
            f_layer = f_schedule_plan.get_layer(i)
+            b_layer = b_schedule_plan.get_layer(b_num_layers - 1 - i)
+            torch.cuda.nvtx.range_push(f"layer_{i}f-layer_{b_num_layers - 1 - i}b")
            if f_layer.layer.config.fine_grained_activation_offloading:
                fine_grained_offloading_set_last_layer(i == f_num_layers - 1)
-            b_layer = b_schedule_plan.pop_layer()
-            torch.cuda.nvtx.range_push(f"layer_{i}f-layer_{b_schedule_plan.num_layers()}b")
            f_input, b_grad = TransformerLayerSchedulePlan.run(
                f_layer,
                b_layer,
                f_input=f_input,
                b_grad=b_grad,
                is_last_layer_in_bwd=(i == b_num_layers - 1),
            )
-            if i < b_num_layers - 1:
-                b_layer.release_state()
            torch.cuda.nvtx.range_pop()

        # backward pass for the remaining layers
        for i in range(overlapped_layers, b_num_layers):
-            b_layer = b_schedule_plan.pop_layer()
-            torch.cuda.nvtx.range_push(f"layer_{b_schedule_plan.num_layers()}b")
+            b_layer = b_schedule_plan.get_layer(b_num_layers - 1 - i)
+            torch.cuda.nvtx.range_push(f"layer_{b_num_layers - 1 - i}b")
            _, b_grad = TransformerLayerSchedulePlan.run(
                None, b_layer, b_grad=b_grad, is_last_layer_in_bwd=(i == b_num_layers - 1)
            )
-            if i < b_num_layers - 1:
-                b_layer.release_state()
            torch.cuda.nvtx.range_pop()

        # forward pass for the remaining layers
@@ -545,9 +500,7 @@ def run(
        # Delay the last attn_dw in backward pass (attn_dw of the first layer)
        # for overlapping with the p2p comm
        if b_num_layers > 0:
-            assert b_layer is not None
-            b_layer.attn.backward_dw()
-            b_layer.release_state()
+            b_schedule_plan.get_layer(0).attn.backward_dw()

        # post process forward
        if f_schedule_plan is not None and f_schedule_plan.post_process is not None:
@@ -560,7 +513,9 @@ def run(
            f_schedule_plan.wait_current_stream()
        if b_schedule_plan:
            b_schedule_plan.wait_current_stream()
-            # Release reference as early as possible, this helps avoid memory leak.
+
+        # Release reference as early as possible, this helps avoid memory leak.
+        if b_schedule_plan is not None:
            b_schedule_plan.release_state()

        return f_input
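
A minimal standalone sketch of the index-based traversal this diff switches to: the forward plan is walked front-to-back while the backward plan is read back-to-front via get_layer(b_num_layers - 1 - i), instead of destructively popping layers. All class and function names below (ToySchedulePlan, toy_run) are hypothetical illustrations, not the Megatron-LM classes; only the indexing pattern mirrors the code above.

class ToySchedulePlan:
    def __init__(self, layer_ids):
        # layers stored in a plain list, same as _transformer_layers above
        self._transformer_layers = list(layer_ids)

    def get_layer(self, i):
        assert i < self.num_layers()
        return self._transformer_layers[i]

    def num_layers(self):
        return len(self._transformer_layers)


def toy_run(f_plan, b_plan):
    f_num_layers = f_plan.num_layers()
    b_num_layers = b_plan.num_layers()
    overlapped_layers = min(f_num_layers, b_num_layers)

    # overlapped region: forward layer i is paired with backward layer (b_num_layers - 1 - i)
    for i in range(overlapped_layers):
        f_layer = f_plan.get_layer(i)
        b_layer = b_plan.get_layer(b_num_layers - 1 - i)
        print(f"layer_{i}f-layer_{b_num_layers - 1 - i}b: {f_layer} / {b_layer}")

    # remaining backward-only layers continue the same reversed index
    for i in range(overlapped_layers, b_num_layers):
        b_layer = b_plan.get_layer(b_num_layers - 1 - i)
        print(f"layer_{b_num_layers - 1 - i}b: {b_layer}")


if __name__ == "__main__":
    toy_run(ToySchedulePlan(["F0", "F1", "F2"]), ToySchedulePlan(["B0", "B1"]))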