
Commit 5de7e57

Add mp delay_scale_loss function (#7713)
* add mp delay_scale_loss function
* remove useless codes
1 parent 1e9b5a8 commit 5de7e57
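For context: enable_delay_scale_loss moves the division by gradient_accumulation_steps from each micro-step's loss to the accumulated gradients just before the optimizer step. Below is a minimal sketch of the two behaviours; `run_accumulation_step`, `loss_fn`, `micro_batches`, and `opt` are illustrative names for this sketch, not PaddleNLP APIs.

    import paddle

    def run_accumulation_step(model, opt, micro_batches, loss_fn, delay_scale_loss=False):
        # Sketch only: contrast per-micro-step loss scaling with delayed gradient scaling.
        n = len(micro_batches)
        for batch in micro_batches:
            loss = loss_fn(model, batch)
            if not delay_scale_loss:
                loss = loss / n  # default: scale every micro-step's loss before backward
            loss.backward()      # gradients accumulate across micro-steps
        if delay_scale_loss:
            # delayed: backward ran on the unscaled loss, so divide the summed
            # gradients once, right before the optimizer step
            with paddle.no_grad():
                for p in model.parameters():
                    if p.grad is not None:
                        p.grad.scale_(1.0 / n)
        opt.step()
        opt.clear_grad()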

File tree

2 files changed (+18, -4 lines changed)


paddlenlp/trainer/trainer.py

Lines changed: 14 additions & 3 deletions
@@ -917,6 +917,9 @@ def train(
                     steps_in_epoch <= args.gradient_accumulation_steps
                     and (step + 1) == steps_in_epoch
                 ):
+                    if self.args.pipeline_parallel_degree <= 1 and self._enable_delay_scale_loss():
+                        tr_loss /= self.args.gradient_accumulation_steps
+
                     self.timers and self.timers("forward-backward").stop()
                     # Maunally collect gradients
                     # Case 1: Use recompute and dp

@@ -938,7 +941,6 @@ def train(
                 pipeline_parallel_config = (
                     set(args.pipeline_parallel_config.split(" ")) if args.pipeline_parallel_degree > 1 else set()
                 )
-                enable_delay_scale_loss = "enable_delay_scale_loss" in pipeline_parallel_config
                 enable_dp_comm_overlap = "enable_dp_comm_overlap" in pipeline_parallel_config
                 enable_release_grads = "enable_release_grads" in pipeline_parallel_config

@@ -957,7 +959,7 @@ def train(
                     self.timers and self.timers("all-reduce").stop()
                     self.timers and self.timers("optimizer-step").start()

-                    if args.pipeline_parallel_degree > 1 and enable_delay_scale_loss:
+                    if self.args.gradient_accumulation_steps > 1 and self._enable_delay_scale_loss():
                         for p in model._layers.parameters():
                             with paddle.no_grad():
                                 if hasattr(p, "main_grad") and p.main_grad is not None:

@@ -1901,6 +1903,15 @@ def compute_loss(self, model, inputs, return_outputs=False):

         return (loss, outputs) if return_outputs else loss

+    def _enable_delay_scale_loss(self):
+        key = "enable_delay_scale_loss"
+        if self.args.pipeline_parallel_degree > 1:
+            return key in self.args.pipeline_parallel_config.split(" ")
+        elif self.args.tensor_parallel_degree > 1:
+            return key in self.args.tensor_parallel_config.split(" ")
+        else:
+            return False
+
     def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor, Any]]) -> paddle.Tensor:
         """
         Perform a training step on a batch of inputs.

@@ -1928,7 +1939,7 @@ def training_step(self, model: nn.Layer, inputs: Dict[str, Union[paddle.Tensor,
         with self.autocast_smart_context_manager():
             loss = self.compute_loss(model, inputs)

-        if self.args.gradient_accumulation_steps > 1:
+        if self.args.gradient_accumulation_steps > 1 and not self._enable_delay_scale_loss():
             loss = loss / self.args.gradient_accumulation_steps

         if self.do_grad_scaling:
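Taken together, the trainer.py changes have one effect when the flag is on: training_step no longer divides each micro-step's loss, and instead the logged tr_loss (non-pipeline case) and the accumulated gradients are divided once at the accumulation boundary. A condensed, hedged restatement follows; the helper names are invented for illustration, and the `scale_(...)` call is an assumption, since the hunk above is truncated before the actual scaling operation.

    import paddle

    def scale_loss_per_micro_step(loss, args, delay_scale_loss_enabled):
        # mirrors the training_step change: skip the division when delayed scaling is on
        if args.gradient_accumulation_steps > 1 and not delay_scale_loss_enabled:
            loss = loss / args.gradient_accumulation_steps
        return loss

    def scale_accumulated_grads(model, args):
        # mirrors the pre-optimizer-step change: divide accumulated gradients once,
        # preferring the fp32 main_grad buffer when a parameter has one
        with paddle.no_grad():
            for p in model.parameters():
                grad = p.main_grad if hasattr(p, "main_grad") and p.main_grad is not None else p.grad
                if grad is not None:
                    grad.scale_(1.0 / args.gradient_accumulation_steps)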

paddlenlp/trainer/training_args.py

Lines changed: 4 additions & 1 deletion
@@ -235,6 +235,7 @@ class TrainingArguments:
             enable_mp_async_allreduce, it supports all_reduce(dx) overlap with matmul(dw) in ColumnParallelLinear backward when it set True, which can accelerate model parallel performance.
             enable_mp_skip_c_identity, it supports skip c_identity in ColumnParallelLinear and RowParallelLinear. It only works when set mp_async_allreduce is True. It can accelerate model parallel further.
             enable_mp_fused_linear_param_grad_add, it supports fused_linear_param_grad_add in ColumnParallelLinear (cuda >= 11.6). It only works when mp_async_allreduce is true. It can accelerate model parallel further.
+            enable_delay_scale_loss, accumulate gradients util optimizer step, all gradients div by accumute step. instead of div accumute step on loss directly.
         pipeline_parallel_config (`str`, *optional*)(
             Some additional config it highly affect the useage of pipeline parallel, we provide some option to config it.
             following config is support:

@@ -574,7 +575,8 @@ class TrainingArguments:
                 "following config is support:\n"
                 "enable_mp_async_allreduce, it supports all_reduce(dx) overlap with matmul(dw) in ColumnParallelLinear backward when it set True, which can accelerate model parallel performance. \n"
                 "enable_mp_skip_c_identity, it supports skip c_identity in ColumnParallelLinear and RowParallelLinear. It only works when set mp_async_allreduce is True. It can accelerate model parallel further.\n"
-                "enable_mp_fused_linear_param_grad_add, it supports fused_linear_param_grad_add in ColumnParallelLinear (cuda >= 11.6). It only works when mp_async_allreduce is true. It can accelerate model parallel further."
+                "enable_mp_fused_linear_param_grad_add, it supports fused_linear_param_grad_add in ColumnParallelLinear (cuda >= 11.6). It only works when mp_async_allreduce is true. It can accelerate model parallel further.\n"
+                "enable_delay_scale_loss, accumulate gradients util optimizer step, all gradients div by accumute step. instead of div accumute step on loss directly.\n"
             )
         },
     )

@@ -996,6 +998,7 @@ def __post_init__(self):
                     "enable_mp_async_allreduce",
                     "enable_mp_skip_c_identity",
                     "enable_mp_fused_linear_param_grad_add",
+                    "enable_delay_scale_loss",
                 ]:
                     raise ValueError(
                         f"Found unknown tensor parallell config {x}, "
