Skip to content

Commit 23b4a3b

Browse files
authored
release input of pipeline in FthenB (#73398) (#73508)
1 parent f2fdcc2 commit 23b4a3b

File tree

1 file changed

+19
-4
lines changed

1 file changed

+19
-4
lines changed

python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py

Lines changed: 19 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1680,7 +1680,11 @@ def _store_forward_outputs(
16801680
self.output_tensors[virtual_pp_rank].pop()
16811681

16821682
def _forward_step_helper(
1683-
self, micro_dataset, micro_step, overlap_schedule_mode=False
1683+
self,
1684+
micro_dataset,
1685+
micro_step,
1686+
overlap_schedule_mode=False,
1687+
release_input=False,
16841688
):
16851689
virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=True)
16861690
self.set_virtual_pipeline_rank(virtual_pp_rank)
@@ -1698,6 +1702,10 @@ def _forward_step_helper(
16981702
self._store_forward_outputs(
16991703
virtual_pp_rank, output_tensor, schedule_chunk, loss_fn_node
17001704
)
1705+
1706+
if release_input:
1707+
return output_tensor, input_tensor
1708+
17011709
return output_tensor
17021710

17031711
def _overlap_comm_grads(self):
@@ -3202,7 +3210,9 @@ def forward_backward_pipeline(
32023210
# to simplify the code logic of stage 1F1B.
32033211
for micro_step in range(startup_steps):
32043212
self._record_stamp("F", micro_step, '"B"', forward=True)
3205-
output_tensor = self._forward_step_helper(micro_dataset, micro_step)
3213+
output_tensor, input_t = self._forward_step_helper(
3214+
micro_dataset, micro_step, release_input=True
3215+
)
32063216
self._record_stamp("F", micro_step, '"E"', forward=True)
32073217
next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
32083218
micro_step + 1, forward=True
@@ -3233,6 +3243,7 @@ def forward_backward_pipeline(
32333243
input_tensor
32343244
)
32353245
self._release_output(output_tensor)
3246+
self._release_output(input_t)
32363247

32373248
if self.is_pipeline_first_stage(ignore_virtual=True):
32383249
assert (
@@ -3247,8 +3258,8 @@ def forward_backward_pipeline(
32473258
backward_micro_step_id = micro_step
32483259

32493260
self._record_stamp("F", forward_micro_step_id, '"B"', forward=True)
3250-
output_tensor = self._forward_step_helper(
3251-
micro_dataset, forward_micro_step_id
3261+
output_tensor, input_t = self._forward_step_helper(
3262+
micro_dataset, forward_micro_step_id, release_input=True
32523263
)
32533264
self._record_stamp("F", forward_micro_step_id, '"E"', forward=True)
32543265

@@ -3259,6 +3270,10 @@ def forward_backward_pipeline(
32593270
self.is_pipeline_last_stage(ignore_virtual=True),
32603271
batch_p2p_comm=self._use_batch_p2p_comm,
32613272
)
3273+
3274+
self._release_output(output_tensor)
3275+
self._release_output(input_t)
3276+
32623277
output_tensor_grad = self._p2p_helper.recv_backward(
32633278
self.is_pipeline_last_stage(ignore_virtual=True),
32643279
batch_p2p_comm=self._use_batch_p2p_comm,

0 commit comments

Comments (0)