Skip to content

Commit 151b3e5

Browse files
[cherry-pick] [Distributed] fix eval batch & non-compute_loss in pipeline (#74170)
* [Distributed] fix eval batch & non-compute_loss in pipeline (#73479) * [Distributed] fix eval batch && codestyle in PipelineParallel (#73978) --------- Co-authored-by: Tian <[email protected]>
1 parent c1e5656 commit 151b3e5

File tree

1 file changed

+89
-34
lines changed

1 file changed

+89
-34
lines changed

python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py

Lines changed: 89 additions & 34 deletions
Original file line number · Diff line number · Diff line change
@@ -432,6 +432,7 @@ def __init__(self, layers, hcg, strategy):
432432
self.loss_fn_idx = 0
433433

434434
self._compute_loss = True
435+
self._return_host_tensor = False
435436
self.callbacks = pipeline_parallel_callbacks_
436437

437438
logger.info(
@@ -1026,13 +1027,18 @@ def train_batch(
10261027

10271028
return train_loss
10281029

1029-
def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
1030+
def eval_batch(
1031+
self, data, compute_loss=False, loss_fn_idx=0, return_host_tensor=False
1032+
):
10301033
self.user_hooks_enabled = False
10311034
# reset the virtual pp rank for each run
10321035
self.set_virtual_pipeline_rank(0)
10331036

10341037
self._layers.eval()
1038+
origin_compute_loss = self._compute_loss
10351039
self._compute_loss = compute_loss
1040+
origin_return_host_tensor = self._return_host_tensor
1041+
self._return_host_tensor = return_host_tensor
10361042

10371043
# store data id for micro_batch
10381044
self.micro_batch_id = 0
@@ -1051,7 +1057,6 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
10511057
startup_steps = min(startup_steps, self.accumulate_steps)
10521058
steady_steps = self.accumulate_steps - startup_steps
10531059

1054-
input_buffers = []
10551060
output_buffers = []
10561061

10571062
# convert to micro dataset
@@ -1072,8 +1077,11 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
10721077
skip_check_meta=True,
10731078
batch_p2p_comm=self._use_batch_p2p_comm,
10741079
)
1080+
if not self.is_pipeline_last_stage():
1081+
self._release_output(output_tensor)
1082+
else:
1083+
self._offload_tensors(output_tensor)
10751084

1076-
input_buffers.append(input_tensor)
10771085
output_buffers.append(output_tensor)
10781086

10791087
if steady_steps > 0:
@@ -1094,8 +1102,11 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
10941102
skip_check_meta=True,
10951103
batch_p2p_comm=self._use_batch_p2p_comm,
10961104
)
1105+
if not self.is_pipeline_last_stage():
1106+
self._release_output(output_tensor)
1107+
else:
1108+
self._offload_tensors(output_tensor)
10971109

1098-
input_buffers.append(input_tensor)
10991110
output_buffers.append(output_tensor)
11001111

11011112
if not last_iter:
@@ -1105,11 +1116,13 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
11051116
)
11061117

11071118
if self._compute_loss:
1108-
self.train_loss = self._broadcast_final_loss()
1119+
train_loss = self._broadcast_final_loss()
11091120
else:
1110-
self.train_loss = output_buffers
1121+
train_loss = output_buffers
11111122

1112-
return self.train_loss
1123+
self._compute_loss = origin_compute_loss
1124+
self._return_host_tensor = origin_return_host_tensor
1125+
return train_loss
11131126

11141127
def _maybe_loss_compute(
11151128
self, output_tensor, micro_dataset, overlap_schedule_mode=False
@@ -1424,6 +1437,23 @@ def _optimizer_step(self):
14241437
if self.lr_scheduler:
14251438
self.lr_scheduler.step()
14261439

1440+
def _offload_tensors(self, output_tensor):
1441+
if not self._return_host_tensor:
1442+
return
1443+
if isinstance(output_tensor, (tuple, list)):
1444+
for t in output_tensor:
1445+
host_tensor = (
1446+
t.pin_memory() if hasattr(t, "pin_memory") else t.cpu()
1447+
)
1448+
host_tensor._share_buffer_to(t)
1449+
else:
1450+
host_tensor = (
1451+
output_tensor.pin_memory()
1452+
if hasattr(output_tensor, "pin_memory")
1453+
else output_tensor.cpu()
1454+
)
1455+
host_tensor._share_buffer_to(output_tensor)
1456+
14271457
def _release_output(self, output):
14281458
def can_free(t):
14291459
return (
@@ -1694,10 +1724,12 @@ def _get_forward_input(self, virtual_pp_rank):
16941724
assert hasattr(self, 'output_tensors')
16951725
if not self._forward_only:
16961726
assert hasattr(self, 'output_tensor_grads')
1697-
assert len(self.input_tensors[virtual_pp_rank]) == (
1698-
len(self.output_tensors[virtual_pp_rank]) + 1
1699-
)
1700-
input_tensor = self.input_tensors[virtual_pp_rank][-1]
1727+
assert len(self.input_tensors[virtual_pp_rank]) == (
1728+
len(self.output_tensors[virtual_pp_rank]) + 1
1729+
)
1730+
input_tensor = self.input_tensors[virtual_pp_rank][-1]
1731+
else:
1732+
input_tensor = self.input_tensors[virtual_pp_rank].pop()
17011733
return input_tensor
17021734

17031735
def _store_forward_outputs(
@@ -1712,11 +1744,17 @@ def _store_forward_outputs(
17121744
self.schedule_chunks[virtual_pp_rank].append(schedule_chunk)
17131745
if self.is_pipeline_last_stage():
17141746
self.loss_fn_chunks.append(loss_fn_node)
1715-
1716-
if self._forward_only:
1747+
if self._forward_only:
1748+
# no need to store tensor for backward
1749+
if self._compute_loss:
1750+
self.output_tensors[virtual_pp_rank].pop()
1751+
# save output_tensors for return value of eval batch
1752+
else:
1753+
self._offload_tensors(output_tensor)
1754+
else:
17171755
# no need to store tensor for backward
1718-
self.input_tensors[virtual_pp_rank].pop()
1719-
self.output_tensors[virtual_pp_rank].pop()
1756+
if self._forward_only:
1757+
self.output_tensors[virtual_pp_rank].pop()
17201758

17211759
def _forward_step_helper(
17221760
self,
@@ -2022,7 +2060,7 @@ def forward_backward_pipeline(
20222060
# this strategy is inspired by:
20232061
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py
20242062
if not compute_loss:
2025-
assert not forward_only, (
2063+
assert forward_only, (
20262064
"compute_loss can only be set to False when forward_only is set to True"
20272065
)
20282066

@@ -2669,7 +2707,7 @@ def backward_async_comm(
26692707

26702708
# no steady steps, which only occurs when accumulate_step == num_stage
26712709
if not steady_steps:
2672-
output_tensor_grad = p2p.recv_backward(
2710+
output_tensor_grad = self._p2p_helper.recv_backward(
26732711
self.is_pipeline_last_stage(),
26742712
batch_p2p_comm=self._use_batch_p2p_comm,
26752713
)
@@ -2800,12 +2838,14 @@ def backward_async_comm(
28002838
if self._enable_timer:
28012839
self.timers("broadcast_final_loss").start()
28022840
with paddle.amp.auto_cast(enable=False):
2803-
train_loss = self._broadcast_final_loss(return_micro_batch_loss)
2841+
train_loss_or_logits = self._broadcast_final_loss(
2842+
return_micro_batch_loss
2843+
)
28042844
if self._enable_timer:
28052845
self.timers("broadcast_final_loss").stop()
28062846
else:
2807-
# else just return all intermediate output tensor for all micro steps
2808-
train_loss = self.output_tensors
2847+
# else just return logits without loss func calc
2848+
train_loss_or_logits = self.output_tensors.pop()
28092849

28102850
if self._clear_every_step_cache:
28112851
self._p2p_helper.clear_meta_cache()
@@ -2823,7 +2863,7 @@ def backward_async_comm(
28232863
), "p2p dynamic_cnt should equal to send_recv_meta_list"
28242864
self._p2p_helper._dynamic_cnt = 0
28252865

2826-
return train_loss
2866+
return train_loss_or_logits
28272867

28282868
def train_batch(
28292869
self,
@@ -2854,13 +2894,18 @@ def train_batch(
28542894

28552895
return train_loss
28562896

2857-
def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
2897+
def eval_batch(
2898+
self, data, compute_loss=False, loss_fn_idx=0, return_host_tensor=False
2899+
):
28582900
self.user_hooks_enabled = False
28592901
# reset the virtual pp rank for each run
28602902
self.set_virtual_pipeline_rank(0)
28612903

28622904
self._layers.eval()
2905+
origin_compute_loss = self._compute_loss
28632906
self._compute_loss = compute_loss
2907+
origin_return_host_tensor = self._return_host_tensor
2908+
self._return_host_tensor = return_host_tensor
28642909

28652910
# check loss_fn_idx is valid and loss_fn exists
28662911
assert (
@@ -2869,7 +2914,13 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
28692914
), f"loss function {loss_fn_idx} should exist to compute loss"
28702915
self.loss_fn_idx = loss_fn_idx
28712916

2872-
return self.forward_backward_pipeline(data, None, forward_only=True)
2917+
train_loss_or_logits = self.forward_backward_pipeline(
2918+
data, None, forward_only=True, compute_loss=compute_loss
2919+
)
2920+
self._init_buffers()
2921+
self._compute_loss = origin_compute_loss
2922+
self._return_host_tensor = origin_return_host_tensor
2923+
return train_loss_or_logits
28732924

28742925
def get_static_scheduler(self):
28752926
return self.forward_backward_pipeline(
@@ -2959,7 +3010,7 @@ def forward_backward_pipeline(
29593010
if self.processed_steps < g_profile_pipeline_details_steps:
29603011
get_sync_logger().info("start forward_backward_pipeline")
29613012
if not compute_loss:
2962-
assert not forward_only, (
3013+
assert forward_only, (
29633014
"compute_loss can only be set to False when forward_only is set to True"
29643015
)
29653016

@@ -2977,7 +3028,7 @@ def forward_backward_pipeline(
29773028

29783029
assert (
29793030
self.accumulate_steps == self.num_stages
2980-
or self.accumulate_steps % self.num_stages != 0
3031+
or self.accumulate_steps % self.num_stages == 0
29813032
), (
29823033
f"accumulate_steps({self.accumulate_steps}) and num_stages({self.num_stages}) should be a multiple or accumulate_steps % num_stages == 0"
29833034
)
@@ -3108,12 +3159,14 @@ def forward_backward_pipeline(
31083159
if self._enable_timer:
31093160
self.timers("broadcast_final_loss").start()
31103161
with paddle.amp.auto_cast(enable=False):
3111-
train_loss = self._broadcast_final_loss(return_micro_batch_loss)
3162+
train_loss_or_logits = self._broadcast_final_loss(
3163+
return_micro_batch_loss
3164+
)
31123165
if self._enable_timer:
31133166
self.timers("broadcast_final_loss").stop()
31143167
else:
3115-
# else just return all intermediate output tensor for all micro steps
3116-
train_loss = self.output_tensors
3168+
# else just return logits without loss func calc
3169+
train_loss_or_logits = self.output_tensors.pop()
31173170

31183171
if self._clear_every_step_cache:
31193172
self._p2p_helper.clear_meta_cache()
@@ -3124,7 +3177,7 @@ def forward_backward_pipeline(
31243177
get_sync_logger().info("end forward_backward_pipeline")
31253178
self.processed_steps += 1
31263179
self._check_user_hooks_status_at_step_end()
3127-
return train_loss
3180+
return train_loss_or_logits
31283181

31293182

31303183
class OffloadQueue(queue.Queue):
@@ -3187,7 +3240,7 @@ def forward_backward_pipeline(
31873240
):
31883241
self._reset_user_hooks_status()
31893242
if not compute_loss:
3190-
assert not forward_only, (
3243+
assert forward_only, (
31913244
"compute_loss can only be set to False when forward_only is set to True"
31923245
)
31933246
assert self._using_cache, (
@@ -3462,12 +3515,14 @@ def forward_backward_pipeline(
34623515
if self._enable_timer:
34633516
self.timers("broadcast_final_loss").start()
34643517
with paddle.amp.auto_cast(enable=False):
3465-
train_loss = self._broadcast_final_loss(return_micro_batch_loss)
3518+
train_loss_or_logits = self._broadcast_final_loss(
3519+
return_micro_batch_loss
3520+
)
34663521
if self._enable_timer:
34673522
self.timers("broadcast_final_loss").stop()
34683523
else:
3469-
# else just return all intermediate output tensor for all micro steps
3470-
train_loss = self.output_tensors
3524+
# else just return logits without loss func calc
3525+
train_loss_or_logits = self.output_tensors.pop()
34713526

34723527
if self._clear_every_step_cache:
34733528
self._p2p_helper.clear_meta_cache()
@@ -3478,7 +3533,7 @@ def forward_backward_pipeline(
34783533
get_sync_logger().info("end forward_backward_pipeline")
34793534
self.processed_steps += 1
34803535
self._check_user_hooks_status_at_step_end()
3481-
return train_loss
3536+
return train_loss_or_logits
34823537

34833538

34843539
def tuple_to_dict_helper(input_tensor):

0 commit comments

Comments (0)