@@ -432,6 +432,7 @@ def __init__(self, layers, hcg, strategy):
        self.loss_fn_idx = 0

        self._compute_loss = True
+        self._return_host_tensor = False
        self.callbacks = pipeline_parallel_callbacks_

        logger.info(
@@ -1026,13 +1027,18 @@ def train_batch(

        return train_loss

-    def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
+    def eval_batch(
+        self, data, compute_loss=False, loss_fn_idx=0, return_host_tensor=False
+    ):
        self.user_hooks_enabled = False
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        self._layers.eval()
+        origin_compute_loss = self._compute_loss
        self._compute_loss = compute_loss
+        origin_return_host_tensor = self._return_host_tensor
+        self._return_host_tensor = return_host_tensor

        # store data id for micro_batch
        self.micro_batch_id = 0
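Since the original `_compute_loss` and `_return_host_tensor` values are saved up front and restored before returning, an evaluation pass no longer leaks its flags into later training steps. A hedged caller-side sketch of the extended signature (`model` and `eval_loader` are placeholder names, not part of this PR):

```python
# Hypothetical usage; `model` wraps the PipelineParallel instance and
# `eval_loader` is any iterable of batches.
for batch in eval_loader:
    logits = model.eval_batch(
        batch,
        compute_loss=False,       # return raw outputs instead of a broadcast loss
        return_host_tensor=True,  # move returned tensors to (pinned) host memory
    )
```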
@@ -1051,7 +1057,6 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
        startup_steps = min(startup_steps, self.accumulate_steps)
        steady_steps = self.accumulate_steps - startup_steps

-        input_buffers = []
        output_buffers = []

        # convert to micro dataset
@@ -1072,8 +1077,11 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
                skip_check_meta=True,
                batch_p2p_comm=self._use_batch_p2p_comm,
            )
+            if not self.is_pipeline_last_stage():
+                self._release_output(output_tensor)
+            else:
+                self._offload_tensors(output_tensor)

-            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

        if steady_steps > 0:
@@ -1094,8 +1102,11 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
                skip_check_meta=True,
                batch_p2p_comm=self._use_batch_p2p_comm,
            )
+            if not self.is_pipeline_last_stage():
+                self._release_output(output_tensor)
+            else:
+                self._offload_tensors(output_tensor)

-            input_buffers.append(input_tensor)
            output_buffers.append(output_tensor)

            if not last_iter:
@@ -1105,11 +1116,13 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
                )

        if self._compute_loss:
-            self.train_loss = self._broadcast_final_loss()
+            train_loss = self._broadcast_final_loss()
        else:
-            self.train_loss = output_buffers
+            train_loss = output_buffers

-        return self.train_loss
+        self._compute_loss = origin_compute_loss
+        self._return_host_tensor = origin_return_host_tensor
+        return train_loss

    def _maybe_loss_compute(
        self, output_tensor, micro_dataset, overlap_schedule_mode=False
@@ -1424,6 +1437,23 @@ def _optimizer_step(self):
        if self.lr_scheduler:
            self.lr_scheduler.step()

+    def _offload_tensors(self, output_tensor):
+        if not self._return_host_tensor:
+            return
+        if isinstance(output_tensor, (tuple, list)):
+            for t in output_tensor:
+                host_tensor = (
+                    t.pin_memory() if hasattr(t, "pin_memory") else t.cpu()
+                )
+                host_tensor._share_buffer_to(t)
+        else:
+            host_tensor = (
+                output_tensor.pin_memory()
+                if hasattr(output_tensor, "pin_memory")
+                else output_tensor.cpu()
+            )
+            host_tensor._share_buffer_to(output_tensor)
+
    def _release_output(self, output):
        def can_free(t):
            return (
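For context, a minimal standalone sketch of the offload idiom `_offload_tensors` relies on, assuming Paddle's `Tensor.pin_memory()` (copy into page-locked host memory, which keeps host-device transfers fast) and the private `Tensor._share_buffer_to(dst)` (rebind `dst`'s storage to the source buffer). Both calls appear in the diff itself; the snippet is illustrative only:

```python
import paddle

x = paddle.randn([4, 8])  # device tensor when running on a GPU
# Prefer pinned host memory when the build exposes pin_memory(); otherwise
# fall back to a plain host copy.
host = x.pin_memory() if hasattr(x, "pin_memory") else x.cpu()
host._share_buffer_to(x)  # x now aliases the host buffer, so the device
                          # allocation can be released by the allocator
```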
@@ -1694,10 +1724,12 @@ def _get_forward_input(self, virtual_pp_rank):
        assert hasattr(self, 'output_tensors')
        if not self._forward_only:
            assert hasattr(self, 'output_tensor_grads')
-        assert len(self.input_tensors[virtual_pp_rank]) == (
-            len(self.output_tensors[virtual_pp_rank]) + 1
-        )
-        input_tensor = self.input_tensors[virtual_pp_rank][-1]
+            assert len(self.input_tensors[virtual_pp_rank]) == (
+                len(self.output_tensors[virtual_pp_rank]) + 1
+            )
+            input_tensor = self.input_tensors[virtual_pp_rank][-1]
+        else:
+            input_tensor = self.input_tensors[virtual_pp_rank].pop()
        return input_tensor

    def _store_forward_outputs(
@@ -1712,11 +1744,17 @@ def _store_forward_outputs(
        self.schedule_chunks[virtual_pp_rank].append(schedule_chunk)
        if self.is_pipeline_last_stage():
            self.loss_fn_chunks.append(loss_fn_node)
-
-        if self._forward_only:
+            if self._forward_only:
+                # no need to store tensor for backward
+                if self._compute_loss:
+                    self.output_tensors[virtual_pp_rank].pop()
+                # save output_tensors for return value of eval batch
+                else:
+                    self._offload_tensors(output_tensor)
+        else:
            # no need to store tensor for backward
-            self.input_tensors[virtual_pp_rank].pop()
-            self.output_tensors[virtual_pp_rank].pop()
+            if self._forward_only:
+                self.output_tensors[virtual_pp_rank].pop()

    def _forward_step_helper(
        self,
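The nested branches above are easy to misread; this restates the forward-only bookkeeping as I read it from the diff (a summary, not an authoritative spec):

```python
# Forward-only policy in _store_forward_outputs (reading of the diff above):
#   last stage,  compute_loss=True  -> loss is recorded separately; pop the output
#   last stage,  compute_loss=False -> keep the output as eval_batch's return
#                                      value, offloading it to host if requested
#   other stages, forward-only      -> output was already sent downstream; pop it
```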
@@ -2022,7 +2060,7 @@ def forward_backward_pipeline(
        # this strategy is inspired by:
        # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py
        if not compute_loss:
-            assert not forward_only, (
+            assert forward_only, (
                "compute_loss can only be set to False when forward_only is set to True"
            )

@@ -2669,7 +2707,7 @@ def backward_async_comm(

        # no steady steps, which only occurs when accumulate_step == num_stage
        if not steady_steps:
-            output_tensor_grad = p2p.recv_backward(
+            output_tensor_grad = self._p2p_helper.recv_backward(
                self.is_pipeline_last_stage(),
                batch_p2p_comm=self._use_batch_p2p_comm,
            )
@@ -2800,12 +2838,14 @@ def backward_async_comm(
            if self._enable_timer:
                self.timers("broadcast_final_loss").start()
            with paddle.amp.auto_cast(enable=False):
-                train_loss = self._broadcast_final_loss(return_micro_batch_loss)
+                train_loss_or_logits = self._broadcast_final_loss(
+                    return_micro_batch_loss
+                )
            if self._enable_timer:
                self.timers("broadcast_final_loss").stop()
        else:
-            # else just return all intermediate output tensor for all micro steps
-            train_loss = self.output_tensors
+            # else just return logits without loss func calc
+            train_loss_or_logits = self.output_tensors.pop()

        if self._clear_every_step_cache:
            self._p2p_helper.clear_meta_cache()
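One behavioral change worth flagging: with `compute_loss=False` the scheduler now returns only the popped entry of `self.output_tensors` instead of the whole buffer list. A hedged sketch of what a caller sees (assumed caller code, not from this PR):

```python
out = model.eval_batch(batch, compute_loss=False)
# Before this PR: `out` was the full list of per-micro-step output buffers.
# After this PR:  `out` is the single popped entry, i.e. the logits to return.
```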
@@ -2823,7 +2863,7 @@ def backward_async_comm(
            ), "p2p dynamic_cnt should equal to send_recv_meta_list"
            self._p2p_helper._dynamic_cnt = 0

-        return train_loss
+        return train_loss_or_logits

    def train_batch(
        self,
@@ -2854,13 +2894,18 @@ def train_batch(

        return train_loss

-    def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
+    def eval_batch(
+        self, data, compute_loss=False, loss_fn_idx=0, return_host_tensor=False
+    ):
        self.user_hooks_enabled = False
        # reset the virtual pp rank for each run
        self.set_virtual_pipeline_rank(0)

        self._layers.eval()
+        origin_compute_loss = self._compute_loss
        self._compute_loss = compute_loss
+        origin_return_host_tensor = self._return_host_tensor
+        self._return_host_tensor = return_host_tensor

        # check loss_fn_idx is valid and loss_fn exists
        assert (
@@ -2869,7 +2914,13 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
        ), f"loss function {loss_fn_idx} should exist to compute loss"
        self.loss_fn_idx = loss_fn_idx

-        return self.forward_backward_pipeline(data, None, forward_only=True)
+        train_loss_or_logits = self.forward_backward_pipeline(
+            data, None, forward_only=True, compute_loss=compute_loss
+        )
+        self._init_buffers()
+        self._compute_loss = origin_compute_loss
+        self._return_host_tensor = origin_return_host_tensor
+        return train_loss_or_logits

    def get_static_scheduler(self):
        return self.forward_backward_pipeline(
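Because only the last pipeline stage produces logits, callers typically guard on the stage before consuming the result; a sketch under that assumption (`process` is a placeholder for user code):

```python
logits = model.eval_batch(batch, compute_loss=False, return_host_tensor=True)
if model.is_pipeline_last_stage():
    process(logits)  # only the last stage holds the model outputs
```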
@@ -2959,7 +3010,7 @@ def forward_backward_pipeline(
        if self.processed_steps < g_profile_pipeline_details_steps:
            get_sync_logger().info("start forward_backward_pipeline")
        if not compute_loss:
-            assert not forward_only, (
+            assert forward_only, (
                "compute_loss can only be set to False when forward_only is set to True"
            )

@@ -2977,7 +3028,7 @@ def forward_backward_pipeline(

        assert (
            self.accumulate_steps == self.num_stages
-            or self.accumulate_steps % self.num_stages != 0
+            or self.accumulate_steps % self.num_stages == 0
        ), (
            f"accumulate_steps({self.accumulate_steps}) and num_stages({self.num_stages}) should be a multiple or accumulate_steps % num_stages == 0"
        )
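The old condition `accumulate_steps % num_stages != 0` contradicted the assertion message and accepted exactly the non-multiples it was meant to reject; a standalone worked example of the corrected check (illustration only, not library code):

```python
def valid(accumulate_steps: int, num_stages: int) -> bool:
    # Corrected condition: equal, or an exact multiple.
    return accumulate_steps == num_stages or accumulate_steps % num_stages == 0

assert valid(4, 4)      # equal: allowed
assert valid(8, 4)      # exact multiple: allowed
assert not valid(6, 4)  # 6 % 4 == 2: previously slipped through the `!=` check
```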
@@ -3108,12 +3159,14 @@ def forward_backward_pipeline(
            if self._enable_timer:
                self.timers("broadcast_final_loss").start()
            with paddle.amp.auto_cast(enable=False):
-                train_loss = self._broadcast_final_loss(return_micro_batch_loss)
+                train_loss_or_logits = self._broadcast_final_loss(
+                    return_micro_batch_loss
+                )
            if self._enable_timer:
                self.timers("broadcast_final_loss").stop()
        else:
-            # else just return all intermediate output tensor for all micro steps
-            train_loss = self.output_tensors
+            # else just return logits without loss func calc
+            train_loss_or_logits = self.output_tensors.pop()

        if self._clear_every_step_cache:
            self._p2p_helper.clear_meta_cache()
@@ -3124,7 +3177,7 @@ def forward_backward_pipeline(
            get_sync_logger().info("end forward_backward_pipeline")
        self.processed_steps += 1
        self._check_user_hooks_status_at_step_end()
-        return train_loss
+        return train_loss_or_logits


class OffloadQueue(queue.Queue):
@@ -3187,7 +3240,7 @@ def forward_backward_pipeline(
    ):
        self._reset_user_hooks_status()
        if not compute_loss:
-            assert not forward_only, (
+            assert forward_only, (
                "compute_loss can only be set to False when forward_only is set to True"
            )
        assert self._using_cache, (
@@ -3462,12 +3515,14 @@ def forward_backward_pipeline(
            if self._enable_timer:
                self.timers("broadcast_final_loss").start()
            with paddle.amp.auto_cast(enable=False):
-                train_loss = self._broadcast_final_loss(return_micro_batch_loss)
+                train_loss_or_logits = self._broadcast_final_loss(
+                    return_micro_batch_loss
+                )
            if self._enable_timer:
                self.timers("broadcast_final_loss").stop()
        else:
-            # else just return all intermediate output tensor for all micro steps
-            train_loss = self.output_tensors
+            # else just return logits without loss func calc
+            train_loss_or_logits = self.output_tensors.pop()

        if self._clear_every_step_cache:
            self._p2p_helper.clear_meta_cache()
@@ -3478,7 +3533,7 @@ def forward_backward_pipeline(
            get_sync_logger().info("end forward_backward_pipeline")
        self.processed_steps += 1
        self._check_user_hooks_status_at_step_end()
-        return train_loss
+        return train_loss_or_logits


def tuple_to_dict_helper(input_tensor):