
Commit b976b74

[NPU] Add flatten_param_grads for Trainer to improve NPU performance (#4426)
Parent: 06de433

File tree: 4 files changed, +147 −1 lines

  docs/trainer.md
  paddlenlp/trainer/plugins/npu_plugin.py (new file)
  paddlenlp/trainer/trainer.py
  paddlenlp/trainer/training_args.py

docs/trainer.md

Lines changed: 6 additions & 1 deletion

@@ -548,10 +548,15 @@ Trainer is a simple yet feature-complete Paddle training and evaluation module, and
 Whether to resume training from a checkpoint. (optional, defaults to None)
 The path to a folder with a valid checkpoint for your
 model. (default: None)
-
+
 --skip_memory_metrics
 Whether to skip the memory profiler checks. (optional, defaults to True, i.e. skipped)
 Whether or not to skip adding of memory profiler reports
 to metrics. (default: True)
 
+--flatten_param_grads
+Whether to use the flatten_param_grads strategy in the optimizer, which flattens all parameters before passing them to the Optimizer for the update. Currently this strategy only takes effect on NPU devices. (optional, defaults to False)
+Whether to use the flatten_param_grads method in the optimizer,
+only used on NPU devices. (default: False)
+
 ```
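For orientation, here is a minimal usage sketch that is not part of this commit. It assumes the `flatten_param_grads` field added in `training_args.py` below, that `device` is a regular `TrainingArguments` field (as the checks in this commit imply), and an NPU build of Paddle.

```python
# Hedged sketch: enabling the new flag programmatically (NPU-only per the docs above).
from paddlenlp.trainer import TrainingArguments

args = TrainingArguments(
    output_dir="./outputs",
    device="npu",              # flatten_param_grads only takes effect on NPU devices
    flatten_param_grads=True,  # the flag documented above; defaults to False
)
print(args.flatten_param_grads)  # True
```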
paddlenlp/trainer/plugins/npu_plugin.py (new file)

Lines changed: 127 additions & 0 deletions

# Copyright 2020-present the HuggingFace Inc. team.
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import types

import numpy as np
import paddle
from paddle.fluid.layer_helper import LayerHelper

from ...utils.log import logger


def npu_accelerate_plugin(optimizer):
    """npu_accelerate_plugin uses the flatten_param_grads method to speed up the performance of the model on NPU devices.
    The flatten_param_grads method will be added to the `step` function of the optimizer.

    Args:
        optimizer (`paddle.optimizer.Optimizer`):
            The Optimizer whose `step` method will be modified.
    """
    optimizer.step = types.MethodType(_optimizer_step_with_flatten_param_grads, optimizer)


def _optimizer_step_with_flatten_param_grads(optimizer):
    if not isinstance(optimizer._param_groups[0], dict):
        params_grads = []
        for param in optimizer._param_groups:
            if param.stop_gradient:
                continue
            if param._grad_ivar() is not None:
                grad_var = param._grad_ivar()
                params_grads.append((param, grad_var))

        # Currently, only ClipGradByGlobalNorm (or no clipping) is supported, and only without regularization.
        if isinstance(params_grads, list) and optimizer.regularization is None:
            if optimizer._grad_clip is None or isinstance(optimizer._grad_clip, paddle.nn.ClipGradByGlobalNorm):
                params_grads = _flatten_param_grads(optimizer, params_grads)

        optimizer._apply_optimize(
            loss=None,
            startup_program=None,
            params_grads=params_grads,
            param_group_idx=0,
        )
    else:
        raise RuntimeError("flatten_param_grads is not supported when _param_groups[0] is dict.")


def _flatten_param_grads(optimizer, params_grads):
    optimizer.helper = LayerHelper(optimizer.__class__.__name__)
    need_flatten_params = []
    need_flatten_grads = []
    for p, g in params_grads:
        if g is None:
            continue
        g.persistable = True
        if getattr(p, "need_clip", True) is False or getattr(p, "regularizer", None) is not None:
            logger.warning(
                f"flatten_param_grads=True will be discarded since parameter {p.name}'s need_clip is False or "
                "the regularizer is set."
            )
            return params_grads

        need_flatten_params.append(p)
        need_flatten_grads.append(g)

    shape = [np.prod(p.shape) for p in need_flatten_params]

    flatten_param = optimizer.helper.create_global_variable(
        name="flatten_param",
        persistable=True,
        dtype=need_flatten_params[0].dtype,
        shape=[np.sum(shape)],
        belong_to_optimizer=True,
    )

    flatten_grad = optimizer.helper.create_global_variable(
        name="flatten_grad",
        persistable=True,
        dtype=need_flatten_grads[0].dtype,
        shape=[np.sum(shape)],
        belong_to_optimizer=True,
    )

    flatten_param.stop_gradient = False
    # For now, the `coalesce_tensor` op in the final state of the dynamic
    # graph does not support passing the output as an input into the op,
    # so _legacy_C_ops is used here temporarily.
    # `use_align` is set to false, which is different from the behavior
    # under static graphs. `use_align` can be set to true after calling
    # the coalesce_tensor op of the final state (_C_ops).
    paddle._legacy_C_ops.coalesce_tensor(
        need_flatten_params,
        need_flatten_params,
        flatten_param,
        "copy_data",
        True,
        "use_align",
        False,
        "dtype",
        need_flatten_params[0].dtype,
    )

    paddle._legacy_C_ops.coalesce_tensor(
        need_flatten_grads,
        need_flatten_grads,
        flatten_grad,
        "copy_data",
        True,
        "use_align",
        False,
        "dtype",
        need_flatten_grads[0].dtype,
    )
    return [(flatten_param, flatten_grad)]
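The plugin swaps in the flattened `step` implementation by rebinding the method on the optimizer instance with `types.MethodType`, so only the patched optimizer is affected. A standalone sketch of that pattern (plain Python, no Paddle needed; the class and function names here are illustrative only):

```python
import types

class FakeOptimizer:
    def step(self):
        return "original step"

def flattened_step(self):
    # `self` is the instance the function was bound to, just as
    # _optimizer_step_with_flatten_param_grads receives the optimizer above.
    return "patched step"

opt = FakeOptimizer()
opt.step = types.MethodType(flattened_step, opt)  # same rebinding as npu_accelerate_plugin
print(opt.step())              # -> patched step
print(FakeOptimizer().step())  # -> original step; other instances keep the class method
```

Inside `_flatten_param_grads`, the flattened buffers are sized as `np.sum([np.prod(p.shape) for p in need_flatten_params])`, i.e. one contiguous tensor holding every trainable parameter (and one for the gradients), which is presumably what cuts down the per-parameter work in the optimizer update and improves NPU performance.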

paddlenlp/trainer/trainer.py

Lines changed: 5 additions & 0 deletions

@@ -594,6 +594,11 @@ def train(
         self._total_loss_scalar = 0.0
         self._globalstep_last_logged = self.state.global_step
 
+        if self.args.device == "npu" and self.args.flatten_param_grads:
+            from .plugins.npu_plugin import npu_accelerate_plugin
+
+            npu_accelerate_plugin(self.optimizer)
+
         for epoch in range(epochs_trained, num_train_epochs):
             if isinstance(train_dataloader, paddle.io.DataLoader) and isinstance(
                 train_dataloader.batch_sampler, DistributedBatchSampler
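Outside of `Trainer`, the same hook could in principle be applied to a bare Paddle optimizer. The sketch below is hypothetical and not from this commit: it assumes the new module lands at `paddlenlp/trainer/plugins/npu_plugin.py` (as the relative import above suggests), an NPU runtime selected with `paddle.set_device("npu")`, and the only supported configuration (no regularizer; `ClipGradByGlobalNorm` or no clipping).

```python
import paddle
from paddlenlp.trainer.plugins.npu_plugin import npu_accelerate_plugin  # assumed module path

paddle.set_device("npu")  # the plugin is only intended for NPU devices

model = paddle.nn.Linear(16, 4)
opt = paddle.optimizer.AdamW(
    learning_rate=1e-3,
    parameters=model.parameters(),
    grad_clip=paddle.nn.ClipGradByGlobalNorm(1.0),  # supported clip type
)
npu_accelerate_plugin(opt)  # rebinds opt.step to the flattened implementation

loss = model(paddle.randn([2, 16])).mean()
loss.backward()
opt.step()        # flattens params/grads, then calls _apply_optimize once
opt.clear_grad()
```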

paddlenlp/trainer/training_args.py

Lines changed: 9 additions & 0 deletions

@@ -263,6 +263,8 @@ class TrainingArguments:
             The path to a folder with a valid checkpoint for your model. This argument is not directly used by
             [`Trainer`], it's intended to be used by your training/evaluation scripts instead. See the [example
             scripts](https://github.com/PaddlePaddle/PaddleNLP/tree/develop/examples) for more details.
+        flatten_param_grads (`bool`, *optional*):
+            Whether to use the flatten_param_grads method in the optimizer; only used on NPU devices. Default is `False`.
     """
 
     output_dir: str = field(

@@ -496,6 +498,10 @@ class TrainingArguments:
     skip_memory_metrics: bool = field(
         default=True, metadata={"help": "Whether or not to skip adding of memory profiler reports to metrics."}
     )
+    flatten_param_grads: Optional[bool] = field(
+        default=False,
+        metadata={"help": "Whether to use the flatten_param_grads method in the optimizer; only used on NPU devices."},
+    )
 
     def __post_init__(self):
         env_local_rank = int(os.environ.get("PADDLE_RANK_IN_NODE", -1))

@@ -624,6 +630,9 @@ def __post_init__(self):
             "Both warmup_ratio and warmup_steps given, warmup_steps will override any effect of warmup_ratio during training"
         )
 
+        if self.flatten_param_grads and self.device != "npu":
+            raise ValueError("flatten_param_grads can only be used on npu devices for now.")
+
     def __str__(self):
         self_as_dict = asdict(self)
         self_as_dict = {k: f"<{k.upper()}>" if k.endswith("_token") else v for k, v in self_as_dict.items()}
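The new `__post_init__` guard makes the flag fail fast on non-NPU devices. A short sketch of the expected behaviour (again assuming `device` is a regular `TrainingArguments` field and a working Paddle install):

```python
from paddlenlp.trainer import TrainingArguments

try:
    TrainingArguments(output_dir="./out", device="gpu", flatten_param_grads=True)
except ValueError as err:
    print(err)  # flatten_param_grads can only be used on npu devices ...
```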
