@@ -1680,11 +1680,7 @@ def _store_forward_outputs(
             self.output_tensors[virtual_pp_rank].pop()
 
     def _forward_step_helper(
-        self,
-        micro_dataset,
-        micro_step,
-        overlap_schedule_mode=False,
-        release_input=False,
+        self, micro_dataset, micro_step, overlap_schedule_mode=False
     ):
         virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=True)
         self.set_virtual_pipeline_rank(virtual_pp_rank)
@@ -1702,10 +1698,6 @@ def _forward_step_helper(
         self._store_forward_outputs(
             virtual_pp_rank, output_tensor, schedule_chunk, loss_fn_node
         )
-
-        if release_input:
-            return output_tensor, input_tensor
-
         return output_tensor
 
     def _overlap_comm_grads(self):
@@ -3210,9 +3202,7 @@ def forward_backward_pipeline(
         # to simplify the code logic of stage 1F1B.
         for micro_step in range(startup_steps):
             self._record_stamp("F", micro_step, '"B"', forward=True)
-            output_tensor, input_t = self._forward_step_helper(
-                micro_dataset, micro_step, release_input=True
-            )
+            output_tensor = self._forward_step_helper(micro_dataset, micro_step)
             self._record_stamp("F", micro_step, '"E"', forward=True)
             next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
                 micro_step + 1, forward=True
@@ -3243,7 +3233,6 @@ def forward_backward_pipeline(
                 input_tensor
             )
             self._release_output(output_tensor)
-            self._release_output(input_t)
 
         if self.is_pipeline_first_stage(ignore_virtual=True):
             assert (
@@ -3258,8 +3247,8 @@ def forward_backward_pipeline(
             backward_micro_step_id = micro_step
 
             self._record_stamp("F", forward_micro_step_id, '"B"', forward=True)
-            output_tensor, input_t = self._forward_step_helper(
-                micro_dataset, forward_micro_step_id, release_input=True
+            output_tensor = self._forward_step_helper(
+                micro_dataset, forward_micro_step_id
             )
             self._record_stamp("F", forward_micro_step_id, '"E"', forward=True)
 
@@ -3270,10 +3259,6 @@ def forward_backward_pipeline(
                 self.is_pipeline_last_stage(ignore_virtual=True),
                 batch_p2p_comm=self._use_batch_p2p_comm,
             )
-
-            self._release_output(output_tensor)
-            self._release_output(input_t)
-
             output_tensor_grad = self._p2p_helper.recv_backward(
                 self.is_pipeline_last_stage(ignore_virtual=True),
                 batch_p2p_comm=self._use_batch_p2p_comm,