@@ -1680,7 +1680,11 @@ def _store_forward_outputs(
             self.output_tensors[virtual_pp_rank].pop()
 
     def _forward_step_helper(
-        self, micro_dataset, micro_step, overlap_schedule_mode=False
+        self,
+        micro_dataset,
+        micro_step,
+        overlap_schedule_mode=False,
+        release_input=False,
     ):
         virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=True)
         self.set_virtual_pipeline_rank(virtual_pp_rank)
@@ -1698,6 +1702,10 @@ def _forward_step_helper(
         self._store_forward_outputs(
             virtual_pp_rank, output_tensor, schedule_chunk, loss_fn_node
         )
+
+        if release_input:
+            return output_tensor, input_tensor
+
         return output_tensor
 
     def _overlap_comm_grads(self):
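
Note on the hunk above: the helper stores its input/output tensors internally for the backward pass, so without the new flag the scheduler has no handle with which to free the input's storage early. A minimal sketch of the contract, using hypothetical stand-in names rather than Paddle's actual class:

    # Sketch only: forward_fn stands in for the real forward pass.
    def forward_step_helper(forward_fn, input_tensor, release_input=False):
        output_tensor = forward_fn(input_tensor)
        if release_input:
            # Hand both references back; the caller releases them once
            # the activation has been sent to the next stage.
            return output_tensor, input_tensor
        return output_tensor
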
@@ -3202,7 +3210,9 @@ def forward_backward_pipeline(
         # to simplify the code logic of stage 1F1B.
         for micro_step in range(startup_steps):
             self._record_stamp("F", micro_step, '"B"', forward=True)
-            output_tensor = self._forward_step_helper(micro_dataset, micro_step)
+            output_tensor, input_t = self._forward_step_helper(
+                micro_dataset, micro_step, release_input=True
+            )
             self._record_stamp("F", micro_step, '"E"', forward=True)
             next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
                 micro_step + 1, forward=True
@@ -3233,6 +3243,7 @@ def forward_backward_pipeline(
                 input_tensor
             )
             self._release_output(output_tensor)
+            self._release_output(input_t)
 
         if self.is_pipeline_first_stage(ignore_virtual=True):
             assert (
@@ -3247,8 +3258,8 @@ def forward_backward_pipeline(
             backward_micro_step_id = micro_step
 
             self._record_stamp("F", forward_micro_step_id, '"B"', forward=True)
-            output_tensor = self._forward_step_helper(
-                micro_dataset, forward_micro_step_id
+            output_tensor, input_t = self._forward_step_helper(
+                micro_dataset, forward_micro_step_id, release_input=True
             )
             self._record_stamp("F", forward_micro_step_id, '"E"', forward=True)
 
@@ -3259,6 +3270,10 @@ def forward_backward_pipeline(
                 self.is_pipeline_last_stage(ignore_virtual=True),
                 batch_p2p_comm=self._use_batch_p2p_comm,
             )
+
+            self._release_output(output_tensor)
+            self._release_output(input_t)
+
             output_tensor_grad = self._p2p_helper.recv_backward(
                 self.is_pipeline_last_stage(ignore_virtual=True),
                 batch_p2p_comm=self._use_batch_p2p_comm,
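
Taken together, the scheduler hunks apply a send-then-release pattern in both the startup phase and the steady 1F1B phase: once a micro-batch's activation has been shipped to the next stage, the local output_tensor / input_t handles only pin device memory, so _release_output drops them to lower peak activation memory per stage. A runnable, framework-agnostic sketch of the ownership flow (FakeTensor and send_forward are illustrative stand-ins, not Paddle APIs):

    class FakeTensor:
        def __init__(self, data):
            self.data = data              # stands in for device storage

        def release(self):
            self.data = None              # free the storage early

    def forward_step(inp, release_input=False):
        out = FakeTensor([x * 2 for x in inp.data])   # fake forward pass
        return (out, inp) if release_input else out

    def send_forward(tensor):
        pass                              # stand-in for the p2p send

    inp = FakeTensor([1.0, 2.0, 3.0])
    out, inp_ref = forward_step(inp, release_input=True)
    send_forward(out)
    # After the send completes, neither local handle is needed:
    out.release()
    inp_ref.release()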