Skip to content

Commit 50bfe42

Browse files
authored
[cherry-pick 2.3] Add fused_multi_transformer op to optimize transformer generation performance (#42311)
* Add fused_multi_transformer op to optimize transformer generation performance (#41814)
* Fix fused_multi_transformer compile failure on CUDA arch < sm53 (#42315)
* Fix CI timeout
1 parent 765fbb5 commit 50bfe42

14 files changed

+3040
-2
lines changed

paddle/fluid/operators/fused/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ register_operators(EXCLUDES
1919
fused_attention_op
2020
fused_transformer_op
2121
fused_feedforward_op
22+
fused_multi_transformer_op
2223
resnet_unit_op
2324
fused_gemm_epilogue_op)
2425

@@ -73,6 +74,7 @@ if (WITH_GPU OR WITH_ROCM)
7374
op_library(fused_feedforward_op)
7475
# fused_attention_op
7576
op_library(fused_attention_op)
77+
op_library(fused_multi_transformer_op)
7678
endif()
7779
# resnet_unit needs cudnn 8.0 above
7880
if ((NOT WITH_ROCM) AND (NOT ${CUDNN_VERSION} VERSION_LESS 8000))
Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include <memory>
16+
#include <string>
17+
#include "paddle/fluid/framework/op_registry.h"
18+
19+
namespace paddle {
20+
namespace operators {
21+
22+
using Tensor = framework::Tensor;
23+
24+
class FusedMultiTransformerOp : public framework::OperatorWithKernel {
25+
private:
26+
static constexpr const char *OpName = "FusedMultiTransformerOp";
27+
28+
public:
29+
using framework::OperatorWithKernel::OperatorWithKernel;
30+
31+
void InferShape(framework::InferShapeContext *ctx) const override {
32+
#define CHECK_INPUT(name) \
33+
OP_INOUT_CHECK(ctx->HasInput(#name), "Input", #name, OpName)
34+
#define CHECK_INPUTS(name) \
35+
OP_INOUT_CHECK(ctx->HasInputs(#name), "Input", #name, OpName)
36+
#define CHECK_OUTPUT(name) \
37+
OP_INOUT_CHECK(ctx->HasOutput(#name), "Output", #name, OpName)
38+
#define CHECK_OUTPUTS(name) \
39+
OP_INOUT_CHECK(ctx->HasOutputs(#name), "Output", #name, OpName)
40+
41+
CHECK_INPUT(X);
42+
43+
// attention
44+
CHECK_INPUTS(QKVW);
45+
CHECK_INPUTS(OutLinearW);
46+
47+
if (ctx->HasInput("TimeStep")) {
48+
CHECK_INPUTS(CacheKV);
49+
}
50+
51+
if (ctx->HasInputs("CacheKV")) {
52+
CHECK_OUTPUTS(CacheKVOut);
53+
}
54+
55+
// ffn
56+
CHECK_INPUTS(FFN1Weight);
57+
CHECK_INPUTS(FFN2Weight);
58+
59+
CHECK_OUTPUT(Out);
60+
61+
// x: qkv's input [batch_size, seq_len, dim_embed]
62+
// y: qkv's weight: [3, num_head, dim_head, dim_embed]
63+
auto x_dim = ctx->GetInputDim("X");
64+
auto y_dim = ctx->GetInputsDim("QKVW")[0];
65+
PADDLE_ENFORCE_EQ(x_dim.size(), 3, platform::errors::InvalidArgument(
66+
"The dimensions of x must be 3"
67+
"(batch_size, seq_len, dim_embed),"
68+
"but received dimensions of"
69+
"Input is [%d]",
70+
x_dim.size()));
71+
PADDLE_ENFORCE_EQ(y_dim.size(), 4,
72+
platform::errors::InvalidArgument(
73+
"The dimensions of qkv_weight must be 4"
74+
"(3, num_head, dim_head, dim_embed),"
75+
"but received dimensions of"
76+
"Input is [%d]",
77+
y_dim.size()));
78+
PADDLE_ENFORCE_EQ(x_dim[2], y_dim[3],
79+
platform::errors::InvalidArgument(
80+
"ShapeError: the dimension of x_dim[2] and y_dim[3]"
81+
"must be equal. But received: the shape "
82+
"of input x = [%s], and the shape of "
83+
"input qkv_weight = [%s]",
84+
x_dim, y_dim));
85+
86+
if (ctx->Attrs().Get<int>("ring_id") == -1) {
87+
PADDLE_ENFORCE_EQ(y_dim[1] * y_dim[2], y_dim[3],
88+
platform::errors::InvalidArgument(
89+
"The dimensions of qkv_weight must be 4"
90+
"(3, num_head, dim_head, dim_embed),"
91+
"and must satisfy the limitations: "
92+
"(num_head * dim_head == dim_embed)"));
93+
}
94+
95+
if (ctx->HasInputs("CacheKV")) {
96+
// [2, batch_size, num_head, max_seq_len, head_size]
97+
const auto &c_dims = ctx->GetInputsDim("CacheKV");
98+
const auto &c_dim = c_dims[0];
99+
100+
PADDLE_ENFORCE_EQ(
101+
c_dim.size(), 5,
102+
paddle::platform::errors::InvalidArgument(
103+
"The CacheKV must be 5 dims, but got %d", c_dim.size()));
104+
PADDLE_ENFORCE_EQ(c_dim[0], 2,
105+
paddle::platform::errors::InvalidArgument(
106+
"The first dim of CacheKV must be 2, but got %d",
107+
c_dim[0])); // 2
108+
PADDLE_ENFORCE_EQ(c_dim[1], x_dim[0],
109+
paddle::platform::errors::InvalidArgument(
110+
"The second dim of CacheKV must be equal with "
111+
"batch size %d, but got %d",
112+
x_dim[0], c_dim[1])); // batch_size
113+
PADDLE_ENFORCE_EQ(c_dim[2], y_dim[1],
114+
paddle::platform::errors::InvalidArgument(
115+
"The third dim of CacheKV must be equal with num "
116+
"head %d, but got %d",
117+
y_dim[1], c_dim[2])); // num_head
118+
PADDLE_ENFORCE_GT(
119+
c_dim[3], 0,
120+
paddle::platform::errors::InvalidArgument(
121+
"The forth dim of CacheKV must be greater than 0, but got %d",
122+
c_dim[3])); // cache_seq_len
123+
PADDLE_ENFORCE_EQ(c_dim[4], y_dim[2],
124+
paddle::platform::errors::InvalidArgument(
125+
"The fifth dim of CacheKV must be equal with head "
126+
"size %d, but got %d",
127+
y_dim[2], c_dim[4])); // head_size
128+
}
129+
130+
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
131+
}
132+
133+
protected:
134+
framework::OpKernelType GetExpectedKernelType(
135+
const framework::ExecutionContext &ctx) const override {
136+
return framework::OpKernelType(
137+
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
138+
}
139+
140+
framework::OpKernelType GetKernelTypeForVar(
141+
const std::string &var_name, const Tensor &tensor,
142+
const framework::OpKernelType &expected_kernel_type) const override {
143+
if (var_name == "TimeStep") {
144+
VLOG(10) << "var_name:" << var_name << " need not to transform";
145+
return expected_kernel_type;
146+
}
147+
return framework::OpKernelType(expected_kernel_type.data_type_,
148+
tensor.place(), tensor.layout());
149+
}
150+
};
151+
152+
class FusedMultiTransformerOpOpMaker
153+
: public framework::OpProtoAndCheckerMaker {
154+
public:
155+
void Make() override {
156+
AddInput("X", "The input tensor.");
157+
AddInput("LnScale",
158+
"Scale is a 1-dimensional tensor of size "
159+
"H. Here, H represents the last dimension of its input tensor.")
160+
.AsDuplicable();
161+
AddInput("LnBias",
162+
"Bias is a 1-dimensional tensor of size "
163+
"H. Here, H represents the last dimension of its input tensor.")
164+
.AsDuplicable();
165+
AddInput("QKVW", "The qkv weight tensor.").AsDuplicable();
166+
AddInput("QKVBias", "The qkv bias tensor.").AsDispensable().AsDuplicable();
167+
AddInput("CacheKV", "(optional) The cached KV for generation inference.")
168+
.AsDispensable()
169+
.AsDuplicable();
170+
AddInput("TimeStep",
171+
"(optional, int) The time step for generation inference.")
172+
.AsDispensable();
173+
AddInput("SrcMask", "(optional) The attention mask tensor in fmha.")
174+
.AsDispensable();
175+
AddInput("OutLinearW", "The out_linear weight tensor.").AsDuplicable();
176+
AddInput("OutLinearBias", "The out_linear bias tensor.")
177+
.AsDispensable()
178+
.AsDuplicable();
179+
180+
AddInput("FFNLnScale", "The layer_norm scale of FusedFeedForward op")
181+
.AsDuplicable();
182+
AddInput("FFNLnBias", "The layer_norm bias of FusedFeedForward op")
183+
.AsDuplicable();
184+
AddInput("FFN1Weight", "The linear1 weight of FusedFeedForward op")
185+
.AsDuplicable();
186+
AddInput("FFN1Bias", "The linear1 bias of FusedFeedForward op")
187+
.AsDispensable()
188+
.AsDuplicable();
189+
AddInput("FFN2Weight", "The linear2 weight of FusedFeedForward op")
190+
.AsDuplicable();
191+
AddInput("FFN2Bias", "The linear2 bias input of FusedFeedForward op")
192+
.AsDispensable()
193+
.AsDuplicable();
194+
195+
AddOutput("CacheKVOut", "The updated cache KV. Inplace with CacheKV")
196+
.AsDispensable()
197+
.AsDuplicable();
198+
AddOutput("Out", "Result after multi .");
199+
200+
AddAttr<bool>("pre_layer_norm",
201+
"if true, the attention op uses pre_layer_norm architecure, "
202+
"else, uses post_layer_norm architecuture. "
203+
"[default true].")
204+
.SetDefault(true);
205+
AddAttr<float>("epsilon",
206+
"Constant for numerical stability [default 1e-5].")
207+
.SetDefault(1e-5)
208+
.AddCustomChecker([](const float &epsilon) {
209+
PADDLE_ENFORCE_EQ(epsilon >= 0.0f && epsilon <= 0.001f, true,
210+
platform::errors::InvalidArgument(
211+
"'epsilon' in Op(LayerNorm) should be between"
212+
"0.0 and 0.001, But received [%s].",
213+
epsilon));
214+
});
215+
216+
AddAttr<float>("dropout_rate", "Probability of setting units to zero.")
217+
.SetDefault(.5f)
218+
.AddCustomChecker([](const float &drop_p) {
219+
PADDLE_ENFORCE_EQ(drop_p >= 0.0f && drop_p <= 1.0f, true,
220+
platform::errors::InvalidArgument(
221+
"'dropout_rate' must be between 0.0 and 1.0."));
222+
});
223+
224+
AddAttr<bool>("dropout_is_test",
225+
"(bool, default false) Set to true for inference only, false "
226+
"for training. Some layers may run faster when this is true.")
227+
.SetDefault(false);
228+
AddAttr<std::string>(
229+
"dropout_implementation",
230+
"[\"downgrade_in_infer\"|\"upscale_in_train\"]"
231+
"The meaning is the same as 'attn_dropout_implementation'.")
232+
.SetDefault("downgrade_in_infer")
233+
.AddCustomChecker([](const std::string &type) {
234+
PADDLE_ENFORCE_EQ(
235+
type == "downgrade_in_infer" || type == "upscale_in_train", true,
236+
platform::errors::InvalidArgument(
237+
"dropout_implementation can only be downgrade_in_infer or "
238+
"upscale_in_train"));
239+
});
240+
AddAttr<std::string>("act_method", "act_method").SetDefault("gelu");
241+
242+
AddAttr<int>(
243+
"ring_id",
244+
"ring id for tensor model parallel. distributed training and inference")
245+
.SetDefault(-1);
246+
247+
AddComment(R"DOC(fused multi transformer layers op)DOC");
248+
}
249+
};
250+
251+
} // namespace operators
252+
} // namespace paddle
253+
254+
namespace ops = paddle::operators;
// fused_multi_transformer is registered as an inference-only op: both the
// static-graph (OpDesc) and dygraph (OpBase) gradient makers are empty, so
// no backward op is generated.
REGISTER_OPERATOR(
    fused_multi_transformer, ops::FusedMultiTransformerOp,
    ops::FusedMultiTransformerOpOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);

0 commit comments

Comments
 (0)