
Commit 89daaa3

[Trainer] Fix sharding overlap bug (#8334)
1 parent 3105c18 commit 89daaa3

File tree

1 file changed: +5 -0 lines changed


paddlenlp/trainer/training_args.py

Lines changed: 5 additions & 0 deletions
@@ -1020,6 +1020,11 @@ def __post_init__(self):
                     enable_dp_comm_overlap and enable_sharding_comm_overlap
                 ), "dp_comm_overlap and sharding_comm_overlap cannot be enabled at the same time"
 
+                if enable_sharding_comm_overlap and not self.amp_master_grad:
+                    raise ValueError(
+                        "If `enable_sharding_comm_overlap` in pipeline_parallel_configs, `amp_master_grad` must be True."
+                    )
+
                 dygraph_pp_configs = {
                     "delay_scale_loss": True if "enable_delay_scale_loss" in pipeline_parallel_config else False,
                     "dp_comm_overlap": enable_dp_comm_overlap,
