
Commit 9817379

[NPU] ATB Kernel for Ernie-4.5 NPU (#1869)
1 parent 440e9be commit 9817379

10 files changed: +2407 additions, -0 deletions
Lines changed: 203 additions & 0 deletions
@@ -0,0 +1,203 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifdef PADDLE_WITH_ATB

#include "fused_fapa_attention.h"  // NOLINT

#include <cmath>

#include "glog/logging.h"
#include "linear.h"     // NOLINT
#include "qkv_split.h"  // NOLINT

namespace atb_layers {

void CreateFaPaAttention(const FaPaAttentionParam &param,
                         atb::Operation **operation) {
  uint64_t TENSOR_ID = 0;

  uint64_t INPUT_QKV_OUT = TENSOR_ID++;

  uint64_t INPUT_COS = param.use_alibi ? 0 : TENSOR_ID++;
  uint64_t INPUT_SIN = param.use_alibi ? 0 : TENSOR_ID++;
  uint64_t INPUT_MASK = param.is_prefill || param.use_alibi ? TENSOR_ID++ : 0;
  uint64_t INPUT_CACHE_K = TENSOR_ID++;
  uint64_t INPUT_CACHE_V = TENSOR_ID++;
  uint64_t INPUT_SLOTS = TENSOR_ID++;
  uint64_t INPUT_BLOCK_TABLES = !param.is_prefill ? TENSOR_ID++ : 0;
  uint64_t INPUT_SEQLEN = TENSOR_ID++;
  uint64_t INPUT_BATCH_STATUS = !param.is_prefill ? TENSOR_ID++ : INPUT_SEQLEN;

  uint64_t OUTPUT = TENSOR_ID++;

  uint64_t INTERMEDIATE_Q = TENSOR_ID++;
  uint64_t INTERMEDIATE_K = TENSOR_ID++;
  uint64_t INTERMEDIATE_V = TENSOR_ID++;
  uint64_t INTERMEDIATE_EMB_Q = TENSOR_ID++;
  uint64_t INTERMEDIATE_EMB_K = TENSOR_ID++;

  uint64_t nodeIdx = 0;
  atb::GraphParam opGraph;
  opGraph.name = "FaPaAttentionOperation";
  opGraph.inTensorNum = INPUT_BATCH_STATUS - INPUT_QKV_OUT + 1;
  opGraph.outTensorNum = 1;
  opGraph.internalTensorNum = INTERMEDIATE_EMB_K - INTERMEDIATE_Q + 1;
  if (param.use_alibi) {
    opGraph.nodes.resize(3);
  } else {
    opGraph.nodes.resize(4);
  }

  // split q,k,v
  {
    atb::Node &opNode = opGraph.nodes.at(nodeIdx++);
    atb_layers::QKVSplitParam opParam;
    opParam.head_num = param.head_num;
    opParam.kv_head_num = param.kv_head_num;
    opParam.head_dim = param.head_dim;
    atb::CreateOperation(opParam, &opNode.operation);
    opNode.inTensorIds = {INPUT_QKV_OUT};
    opNode.outTensorIds = {INTERMEDIATE_Q, INTERMEDIATE_K, INTERMEDIATE_V};
  }

  // rope
  if (!param.use_alibi) {
    atb::Node &opNode = opGraph.nodes.at(nodeIdx++);
    atb::infer::RopeParam opParam;
    opParam.rotaryCoeff = param.rope_neox ? param.head_dim : 2;
    atb::CreateOperation(opParam, &opNode.operation);
    opNode.inTensorIds = {
        INTERMEDIATE_Q, INTERMEDIATE_K, INPUT_COS, INPUT_SIN, INPUT_SEQLEN};
    opNode.outTensorIds = {INTERMEDIATE_EMB_Q, INTERMEDIATE_EMB_K};
  }

  // write kv
  {
    atb::Node &opNode = opGraph.nodes.at(nodeIdx++);
    atb::infer::ReshapeAndCacheParam opParam;
    atb::CreateOperation(opParam, &opNode.operation);
    opNode.inTensorIds = {INTERMEDIATE_EMB_K,
                          INTERMEDIATE_V,
                          INPUT_CACHE_K,
                          INPUT_CACHE_V,
                          INPUT_SLOTS};
    opNode.outTensorIds = {INPUT_CACHE_K, INPUT_CACHE_V};  // write in place
    opNode.inTensorReshapeFuncs.resize(opNode.inTensorIds.size());
    opNode.inTensorReshapeFuncs[0] = [=](const atb::Dims &oldShape,
                                         atb::Dims &newShape) {
      newShape.dimNum = 3;
      newShape.dims[0] = oldShape.dims[0];
      newShape.dims[1] = param.kv_head_num;
      newShape.dims[2] = param.head_dim;
    };
    opNode.inTensorReshapeFuncs[1] = [=](const atb::Dims &oldShape,
                                         atb::Dims &newShape) {
      newShape.dimNum = 3;
      newShape.dims[0] = oldShape.dims[0];
      newShape.dims[1] = param.kv_head_num;
      newShape.dims[2] = param.head_dim;
    };
    opNode.inTensorReshapeFuncs[2] = [=](const atb::Dims &oldShape,
                                         atb::Dims &newShape) {
      newShape.dimNum = 4;
      newShape.dims[0] = oldShape.dims[0];
      newShape.dims[1] = oldShape.dims[2];
      newShape.dims[2] = oldShape.dims[1];
      newShape.dims[3] = oldShape.dims[3];
    };
    opNode.inTensorReshapeFuncs[3] = [=](const atb::Dims &oldShape,
                                         atb::Dims &newShape) {
      newShape.dimNum = 4;
      newShape.dims[0] = oldShape.dims[0];
      newShape.dims[1] = oldShape.dims[2];
      newShape.dims[2] = oldShape.dims[1];
      newShape.dims[3] = oldShape.dims[3];
    };
  }

  if (param.is_prefill) {
    atb::Node &opNode = opGraph.nodes.at(nodeIdx++);
    atb::infer::SelfAttentionParam opParam;
    opParam.headNum = param.head_num;
    opParam.kvHeadNum = param.kv_head_num;
    opParam.qkScale = 1.0f / sqrt(param.head_dim);
    opParam.calcType = atb::infer::SelfAttentionParam::CalcType::PA_ENCODER;
    opParam.maskType = atb::infer::SelfAttentionParam::MASK_TYPE_NORM;
    if (param.use_alibi) {
      opParam.isTriuMask = 0;
      opParam.maskType =
          atb::infer::SelfAttentionParam::MaskType::MASK_TYPE_ALIBI;
    } else {
      opParam.isTriuMask = 1;
    }
    atb::CreateOperation(opParam, &opNode.operation);
    opNode.inTensorIds = {INTERMEDIATE_EMB_Q,
                          INTERMEDIATE_EMB_K,
                          INTERMEDIATE_V,
                          INPUT_MASK,
                          INPUT_SEQLEN};
    LOG(INFO) << "OUTPUT fa **************" << OUTPUT;
    opNode.outTensorIds = {OUTPUT};
    opNode.inTensorReshapeFuncs.resize(opNode.inTensorIds.size());
  } else {
    atb::Node &opNode = opGraph.nodes.at(nodeIdx++);
    atb::infer::PagedAttentionParam opParam;
    opParam.headNum = param.head_num;
    opParam.qkScale = 1.0f / sqrt(param.head_dim);
    opParam.kvHeadNum = param.kv_head_num;
    if (param.use_alibi) {
      opParam.maskType =
          atb::infer::PagedAttentionParam::MaskType::MASK_TYPE_ALIBI;
    } else {
      opParam.maskType = atb::infer::PagedAttentionParam::MaskType::UNDEFINED;
    }
    opParam.batchRunStatusEnable = true;

    atb::CreateOperation(opParam, &opNode.operation);

    if (param.use_alibi) {
      opNode.inTensorIds = {INTERMEDIATE_EMB_Q,
                            INPUT_CACHE_K,
                            INPUT_CACHE_V,
                            INPUT_BLOCK_TABLES,
                            INPUT_SEQLEN,
                            INPUT_MASK,
                            INPUT_BATCH_STATUS};
    } else {
      opNode.inTensorIds = {INTERMEDIATE_EMB_Q,
                            INPUT_CACHE_K,
                            INPUT_CACHE_V,
                            INPUT_BLOCK_TABLES,
                            INPUT_SEQLEN,
                            INPUT_BATCH_STATUS};
    }

    opNode.outTensorIds = {OUTPUT};
    opNode.inTensorReshapeFuncs.resize(opNode.inTensorIds.size());
    opNode.inTensorReshapeFuncs[0] = [=](const atb::Dims &oldShape,
                                         atb::Dims &newShape) {
      newShape.dimNum = 3;
      newShape.dims[0] = oldShape.dims[0];
      newShape.dims[1] = param.head_num;
      newShape.dims[2] = param.head_dim;
    };
  }

  atb::CreateOperation(opGraph, operation);
}

}  // namespace atb_layers

#endif
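
Note (derived from the code above, not part of the commit): the conditional TENSOR_ID counters fix the input-tensor binding order. For the rope (non-alibi) configurations it works out to:

  is_prefill = true  (8 inputs): INPUT_QKV_OUT, INPUT_COS, INPUT_SIN, INPUT_MASK, INPUT_CACHE_K, INPUT_CACHE_V, INPUT_SLOTS, INPUT_SEQLEN
  is_prefill = false (9 inputs): INPUT_QKV_OUT, INPUT_COS, INPUT_SIN, INPUT_CACHE_K, INPUT_CACHE_V, INPUT_SLOTS, INPUT_BLOCK_TABLES, INPUT_SEQLEN, INPUT_BATCH_STATUS

In the prefill case INPUT_BATCH_STATUS aliases INPUT_SEQLEN, so inTensorNum = INPUT_BATCH_STATUS - INPUT_QKV_OUT + 1 still equals the number of distinct inputs.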
Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#ifdef PADDLE_WITH_ATB

#include "atb/atb_infer.h"

namespace atb_layers {

struct FaPaAttentionParam {
  int64_t head_num;
  int64_t kv_head_num;
  int64_t head_dim;
  bool use_alibi{false};
  bool rope_neox{false};
  bool is_prefill;
};

void CreateFaPaAttention(const FaPaAttentionParam& param,
                         atb::Operation** operation);

}  // namespace atb_layers

namespace atb {
template <>
inline Status CreateOperation(const atb_layers::FaPaAttentionParam& opParam,
                              Operation** operation) {
  atb_layers::CreateFaPaAttention(opParam, operation);
  return ErrorType::NO_ERROR;
}
}  // namespace atb

#endif
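
A minimal caller-side sketch of this header, assuming illustrative head counts and an ATB runtime set up elsewhere; only FaPaAttentionParam and the CreateOperation specialization come from the header, everything else is hypothetical:

  // Sketch only — parameter values are assumptions, not taken from the commit.
  atb_layers::FaPaAttentionParam param;
  param.head_num = 64;      // assumed number of query heads
  param.kv_head_num = 8;    // assumed number of KV heads (GQA)
  param.head_dim = 128;     // assumed head dimension
  param.use_alibi = false;  // take the rope path
  param.rope_neox = false;
  param.is_prefill = true;  // build the PA_ENCODER (prefill) graph

  atb::Operation *op = nullptr;
  atb::Status st = atb::CreateOperation(param, &op);  // dispatches to CreateFaPaAttention
  // The caller then binds input/output tensors in the order established by the
  // TENSOR_ID counters in the graph builder above and runs the operation
  // through the ATB runtime.

The template specialization exists so this fused layer can be created through the same atb::CreateOperation entry point as built-in ATB operations.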
Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifdef PADDLE_WITH_ATB

#include "fused_rms_norm.h"  // NOLINT

namespace atb_layers {

void CreateRmsNorm(const RmsNormParam &param, atb::Operation **operation) {
  uint64_t TENSOR_ID = 0;
  uint64_t INPUT = TENSOR_ID++;
  uint64_t INPUT_WEIGHT = TENSOR_ID++;
  uint64_t INPUT_RESIDUAL = param.has_residual ? TENSOR_ID++ : INPUT_WEIGHT;
  uint64_t OUTPUT = TENSOR_ID++;
  uint64_t OUTPUT_RESIDUAL = param.has_residual ? TENSOR_ID++ : OUTPUT;

  uint64_t nodeIdx = 0;
  atb::GraphParam opGraph;
  opGraph.name = "RmsNormOperation";
  opGraph.internalTensorNum = 0;

  if (param.has_residual) {
    opGraph.inTensorNum = 3;
    opGraph.outTensorNum = 2;
    opGraph.nodes.resize(2);
  } else {
    opGraph.inTensorNum = 2;
    opGraph.outTensorNum = 1;
    opGraph.nodes.resize(1);
  }

  if (param.has_residual) {
    atb::Node &opNode = opGraph.nodes.at(nodeIdx++);
    atb::infer::ElewiseParam opParam;
    opParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_ADD;
    atb::CreateOperation(opParam, &opNode.operation);
    opNode.inTensorIds = {INPUT, INPUT_RESIDUAL};
    opNode.outTensorIds = {OUTPUT_RESIDUAL};
  }

  {
    atb::Node &opNode = opGraph.nodes.at(nodeIdx++);
    atb::infer::RmsNormParam opParam;
    opParam.layerType = atb::infer::RmsNormParam::RmsNormType::RMS_NORM_NORM;
    opParam.normParam.epsilon = param.epsilon;
    atb::CreateOperation(opParam, &opNode.operation);
    if (param.has_residual) {
      opNode.inTensorIds = {OUTPUT_RESIDUAL, INPUT_WEIGHT};
    } else {
      opNode.inTensorIds = {INPUT, INPUT_WEIGHT};
    }
    opNode.outTensorIds = {OUTPUT};
  }

  atb::CreateOperation(opGraph, operation);
}

}  // namespace atb_layers

#endif
Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#ifdef PADDLE_WITH_ATB

#include "atb/atb_infer.h"

namespace atb_layers {

struct RmsNormParam {
  float epsilon{1.0};
  bool has_residual{false};
};

void CreateRmsNorm(const RmsNormParam& param, atb::Operation** operation);

}  // namespace atb_layers

namespace atb {
template <>
inline Status CreateOperation(const atb_layers::RmsNormParam& opParam,
                              Operation** operation) {
  atb_layers::CreateRmsNorm(opParam, operation);
  return ErrorType::NO_ERROR;
}
}  // namespace atb

#endif
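
A similar hedged sketch for the RMSNorm wrapper; the epsilon value here is illustrative, not from the commit:

  // Sketch only — values are assumptions, not taken from the commit.
  atb_layers::RmsNormParam param;
  param.epsilon = 1e-6f;      // illustrative epsilon
  param.has_residual = true;  // fuse the residual add before the norm

  atb::Operation *op = nullptr;
  atb::Status st = atb::CreateOperation(param, &op);  // dispatches to CreateRmsNorm
  // With has_residual = true the graph expects three inputs (input, weight,
  // residual) and produces two outputs (normed result, updated residual), as
  // laid out in CreateRmsNorm above.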
