Skip to content

Commit 4ab7432

Browse files
authored
test=document_fix (#73546)
1 parent 06155ab commit 4ab7432

File tree

1 file changed

+4
-19
lines changed

1 file changed

+4
-19
lines changed

python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1680,11 +1680,7 @@ def _store_forward_outputs(
16801680
self.output_tensors[virtual_pp_rank].pop()
16811681

16821682
def _forward_step_helper(
1683-
self,
1684-
micro_dataset,
1685-
micro_step,
1686-
overlap_schedule_mode=False,
1687-
release_input=False,
1683+
self, micro_dataset, micro_step, overlap_schedule_mode=False
16881684
):
16891685
virtual_pp_rank = self._get_virtual_pp_rank(micro_step, forward=True)
16901686
self.set_virtual_pipeline_rank(virtual_pp_rank)
@@ -1702,10 +1698,6 @@ def _forward_step_helper(
17021698
self._store_forward_outputs(
17031699
virtual_pp_rank, output_tensor, schedule_chunk, loss_fn_node
17041700
)
1705-
1706-
if release_input:
1707-
return output_tensor, input_tensor
1708-
17091701
return output_tensor
17101702

17111703
def _overlap_comm_grads(self):
@@ -3210,9 +3202,7 @@ def forward_backward_pipeline(
32103202
# to simplify the code logic of stage 1F1B.
32113203
for micro_step in range(startup_steps):
32123204
self._record_stamp("F", micro_step, '"B"', forward=True)
3213-
output_tensor, input_t = self._forward_step_helper(
3214-
micro_dataset, micro_step, release_input=True
3215-
)
3205+
output_tensor = self._forward_step_helper(micro_dataset, micro_step)
32163206
self._record_stamp("F", micro_step, '"E"', forward=True)
32173207
next_forward_virtual_pp_rank = self._get_virtual_pp_rank(
32183208
micro_step + 1, forward=True
@@ -3243,7 +3233,6 @@ def forward_backward_pipeline(
32433233
input_tensor
32443234
)
32453235
self._release_output(output_tensor)
3246-
self._release_output(input_t)
32473236

32483237
if self.is_pipeline_first_stage(ignore_virtual=True):
32493238
assert (
@@ -3258,8 +3247,8 @@ def forward_backward_pipeline(
32583247
backward_micro_step_id = micro_step
32593248

32603249
self._record_stamp("F", forward_micro_step_id, '"B"', forward=True)
3261-
output_tensor, input_t = self._forward_step_helper(
3262-
micro_dataset, forward_micro_step_id, release_input=True
3250+
output_tensor = self._forward_step_helper(
3251+
micro_dataset, forward_micro_step_id
32633252
)
32643253
self._record_stamp("F", forward_micro_step_id, '"E"', forward=True)
32653254

@@ -3270,10 +3259,6 @@ def forward_backward_pipeline(
32703259
self.is_pipeline_last_stage(ignore_virtual=True),
32713260
batch_p2p_comm=self._use_batch_p2p_comm,
32723261
)
3273-
3274-
self._release_output(output_tensor)
3275-
self._release_output(input_t)
3276-
32773262
output_tensor_grad = self._p2p_helper.recv_backward(
32783263
self.is_pipeline_last_stage(ignore_virtual=True),
32793264
batch_p2p_comm=self._use_batch_p2p_comm,

0 commit comments

Comments
 (0)