[NPU] Add Fapa atb #1763

Open · wants to merge 6 commits into develop
@@ -0,0 +1,266 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifdef PADDLE_WITH_ATB

#include "fused_fapa_attention.h" // NOLINT
#include "qkv_split.h" // NOLINT
#include "linear.h" // NOLINT
#include "glog/logging.h"

#include <cmath>

namespace atb_layers {

void CreateFaPaAttention(const FaPaAttentionParam& param,
                         atb::Operation** operation) {
  // Tensor ids are assigned sequentially. Optional inputs only consume an id
  // when the corresponding feature is enabled; otherwise they fall back to 0
  // (or, for batch_status in prefill, alias the seqlen id).
  uint64_t TENSOR_ID = 0;

  uint64_t INPUT_HIDDEN_STATES = TENSOR_ID++;
  uint64_t INPUT_QKV_WEIGHT = TENSOR_ID++;
  uint64_t INPUT_QKV_BIAS = param.has_qkv_bias ? TENSOR_ID++ : 0;
  uint64_t INPUT_QKV_DEQSCALE = param.use_matmul_int8 ? TENSOR_ID++ : 0;
  uint64_t INPUT_QKV_DEQOFFSET = param.use_matmul_int8 ? TENSOR_ID++ : 0;

  uint64_t INPUT_COS = param.use_alibi ? 0 : TENSOR_ID++;
  uint64_t INPUT_SIN = param.use_alibi ? 0 : TENSOR_ID++;
  uint64_t INPUT_MASK = param.is_prefill || param.use_alibi ? TENSOR_ID++ : 0;
  uint64_t INPUT_CACHE_K = TENSOR_ID++;
  uint64_t INPUT_CACHE_V = TENSOR_ID++;
  uint64_t INPUT_SLOTS = TENSOR_ID++;
  uint64_t INPUT_BLOCK_TABLES = !param.is_prefill ? TENSOR_ID++ : 0;
  uint64_t INPUT_SEQLEN = TENSOR_ID++;
  uint64_t INPUT_BATCH_STATUS = !param.is_prefill ? TENSOR_ID++ : INPUT_SEQLEN;

  uint64_t OUTPUT = TENSOR_ID++;

  uint64_t INTERMEDIATE_QKV_OUT = TENSOR_ID++;
  uint64_t INTERMEDIATE_Q = TENSOR_ID++;
  uint64_t INTERMEDIATE_K = TENSOR_ID++;
  uint64_t INTERMEDIATE_V = TENSOR_ID++;
  uint64_t INTERMEDIATE_EMB_Q = TENSOR_ID++;
  uint64_t INTERMEDIATE_EMB_K = TENSOR_ID++;

  uint64_t nodeIdx = 0;
  atb::GraphParam opGraph;
  opGraph.name = "FaPaAttentionOperation";
  opGraph.inTensorNum = INPUT_BATCH_STATUS - INPUT_HIDDEN_STATES + 1;
  opGraph.outTensorNum = 1;
  opGraph.internalTensorNum = INTERMEDIATE_EMB_K - INTERMEDIATE_QKV_OUT + 1;
  // The rope node is skipped when alibi is used, so that graph has one node
  // fewer.
  if (param.use_alibi) {
    opGraph.nodes.resize(4);
  } else {
    opGraph.nodes.resize(5);
  }

  // qkv
  {
    LOG(INFO) << "begin attention **************";
    atb::Node &opNode = opGraph.nodes.at(nodeIdx++);
    atb_layers::LinearParam opParam;
    opParam.trans_weight = param.trans_qkv_weight;
    opParam.has_bias = param.has_qkv_bias;
    opParam.input_quant = param.use_matmul_int8;
    opParam.input_quant_scale = param.qkv_quant_scale;
    opParam.input_quant_offset = 0;
    opParam.input_smooth_quant = false;
    opParam.has_dequant_offset = param.use_matmul_int8;
    atb::CreateOperation(opParam, &opNode.operation);
    if (param.has_qkv_bias && param.use_matmul_int8) {
      opNode.inTensorIds = {INPUT_HIDDEN_STATES,
                            INPUT_QKV_WEIGHT,
                            INPUT_QKV_BIAS,
                            INPUT_QKV_DEQSCALE,
                            INPUT_QKV_DEQOFFSET};
    } else if (param.has_qkv_bias) {
      opNode.inTensorIds = {
          INPUT_HIDDEN_STATES, INPUT_QKV_WEIGHT, INPUT_QKV_BIAS};
    } else if (param.use_matmul_int8) {
      opNode.inTensorIds = {INPUT_HIDDEN_STATES,
                            INPUT_QKV_WEIGHT,
                            INPUT_QKV_DEQSCALE,
                            INPUT_QKV_DEQOFFSET};
    } else {
      opNode.inTensorIds = {INPUT_HIDDEN_STATES, INPUT_QKV_WEIGHT};
    }
    opNode.outTensorIds = {INTERMEDIATE_QKV_OUT};
  }

  // split q,k,v
  {
    atb::Node &opNode = opGraph.nodes.at(nodeIdx++);
    atb_layers::QKVSplitParam opParam;
    opParam.head_num = param.head_num;
    opParam.kv_head_num = param.kv_head_num;
    opParam.head_dim = param.head_dim;
    atb::CreateOperation(opParam, &opNode.operation);
    opNode.inTensorIds = {INTERMEDIATE_QKV_OUT};
    opNode.outTensorIds = {INTERMEDIATE_Q, INTERMEDIATE_K, INTERMEDIATE_V};
  }

  // rope
  if (!param.use_alibi) {
    atb::Node &opNode = opGraph.nodes.at(nodeIdx++);
    atb::infer::RopeParam opParam;
    opParam.rotaryCoeff = param.rope_neox ? param.head_dim : 2;
    atb::CreateOperation(opParam, &opNode.operation);
    opNode.inTensorIds = {
        INTERMEDIATE_Q, INTERMEDIATE_K, INPUT_COS, INPUT_SIN, INPUT_SEQLEN};
    opNode.outTensorIds = {INTERMEDIATE_EMB_Q, INTERMEDIATE_EMB_K};
  }

  // write kv
  {
    atb::Node &opNode = opGraph.nodes.at(nodeIdx++);
    atb::infer::ReshapeAndCacheParam opParam;
    atb::CreateOperation(opParam, &opNode.operation);
    opNode.inTensorIds = {INTERMEDIATE_EMB_K,
                          INTERMEDIATE_V,
                          INPUT_CACHE_K,
                          INPUT_CACHE_V,
                          INPUT_SLOTS};
    opNode.outTensorIds = {INPUT_CACHE_K, INPUT_CACHE_V};  // write in place
    opNode.inTensorReshapeFuncs.resize(opNode.inTensorIds.size());
    // Report k/v as [ntokens, kv_head_num, head_dim] for the cache write.
    opNode.inTensorReshapeFuncs[0] = [=](const atb::Dims &oldShape,
                                         atb::Dims &newShape) {
      newShape.dimNum = 3;
      newShape.dims[0] = oldShape.dims[0];
      newShape.dims[1] = param.kv_head_num;
      newShape.dims[2] = param.head_dim;
    };
    opNode.inTensorReshapeFuncs[1] = [=](const atb::Dims &oldShape,
                                         atb::Dims &newShape) {
      newShape.dimNum = 3;
      newShape.dims[0] = oldShape.dims[0];
      newShape.dims[1] = param.kv_head_num;
      newShape.dims[2] = param.head_dim;
    };
    // Report the 4-D k/v caches with dims 1 and 2 swapped.
    opNode.inTensorReshapeFuncs[2] = [=](const atb::Dims &oldShape,
                                         atb::Dims &newShape) {
      newShape.dimNum = 4;
      newShape.dims[0] = oldShape.dims[0];
      newShape.dims[1] = oldShape.dims[2];
      newShape.dims[2] = oldShape.dims[1];
      newShape.dims[3] = oldShape.dims[3];
    };
    opNode.inTensorReshapeFuncs[3] = [=](const atb::Dims &oldShape,
                                         atb::Dims &newShape) {
      newShape.dimNum = 4;
      newShape.dims[0] = oldShape.dims[0];
      newShape.dims[1] = oldShape.dims[2];
      newShape.dims[2] = oldShape.dims[1];
      newShape.dims[3] = oldShape.dims[3];
    };
  }

  // attention: flash attention (PA_ENCODER) for prefill, paged attention for
  // decode
  if (param.is_prefill) {
    atb::Node &opNode = opGraph.nodes.at(nodeIdx++);
    atb::infer::SelfAttentionParam opParam;
    opParam.headNum = param.head_num;
    opParam.kvHeadNum = param.kv_head_num;
    opParam.qkScale = 1.0f / std::sqrt(param.head_dim);
    opParam.calcType = atb::infer::SelfAttentionParam::CalcType::PA_ENCODER;
    opParam.maskType = atb::infer::SelfAttentionParam::MASK_TYPE_NORM;
    if (param.use_alibi) {
      opParam.isTriuMask = 0;
      opParam.maskType =
          atb::infer::SelfAttentionParam::MaskType::MASK_TYPE_ALIBI;
    } else {
      opParam.isTriuMask = 1;
    }
    atb::CreateOperation(opParam, &opNode.operation);
    opNode.inTensorIds = {INTERMEDIATE_EMB_Q,
                          INTERMEDIATE_EMB_K,
                          INTERMEDIATE_V,
                          INPUT_MASK,
                          INPUT_SEQLEN};
    LOG(INFO) << "OUTPUT fa **************" << OUTPUT;
    opNode.outTensorIds = {OUTPUT};
    opNode.inTensorReshapeFuncs.resize(opNode.inTensorIds.size());
  } else {
    atb::Node &opNode = opGraph.nodes.at(nodeIdx++);
    atb::infer::PagedAttentionParam opParam;
    opParam.headNum = param.head_num;
    opParam.qkScale = 1.0f / std::sqrt(param.head_dim);
    opParam.kvHeadNum = param.kv_head_num;
    if (param.use_alibi) {
      opParam.maskType =
          atb::infer::PagedAttentionParam::MaskType::MASK_TYPE_ALIBI;
    } else {
      opParam.maskType = atb::infer::PagedAttentionParam::MaskType::UNDEFINED;
    }
    opParam.batchRunStatusEnable = true;

    atb::CreateOperation(opParam, &opNode.operation);

    if (param.use_alibi) {
      opNode.inTensorIds = {INTERMEDIATE_EMB_Q,
                            INPUT_CACHE_K,
                            INPUT_CACHE_V,
                            INPUT_BLOCK_TABLES,
                            INPUT_SEQLEN,
                            INPUT_MASK,
                            INPUT_BATCH_STATUS};
    } else {
      opNode.inTensorIds = {INTERMEDIATE_EMB_Q,
                            INPUT_CACHE_K,
                            INPUT_CACHE_V,
                            INPUT_BLOCK_TABLES,
                            INPUT_SEQLEN,
                            INPUT_BATCH_STATUS};
    }

    LOG(INFO) << "OUTPUT pa **************" << OUTPUT;
    opNode.outTensorIds = {OUTPUT};
    opNode.inTensorReshapeFuncs.resize(opNode.inTensorIds.size());
    opNode.inTensorReshapeFuncs[0] = [=](const atb::Dims &oldShape,
                                         atb::Dims &newShape) {
      newShape.dimNum = 3;
      newShape.dims[0] = oldShape.dims[0];
      newShape.dims[1] = param.kv_head_num;
      newShape.dims[2] = param.head_dim;
    };
    opNode.inTensorReshapeFuncs[1] = [=](const atb::Dims &oldShape,
                                         atb::Dims &newShape) {
      newShape.dimNum = 4;
      newShape.dims[0] = oldShape.dims[0];
      newShape.dims[1] = oldShape.dims[2];
      newShape.dims[2] = oldShape.dims[1];
      newShape.dims[3] = oldShape.dims[3];
    };
    opNode.inTensorReshapeFuncs[2] = [=](const atb::Dims &oldShape,
                                         atb::Dims &newShape) {
      newShape.dimNum = 4;
      newShape.dims[0] = oldShape.dims[0];
      newShape.dims[1] = oldShape.dims[2];
      newShape.dims[2] = oldShape.dims[1];
      newShape.dims[3] = oldShape.dims[3];
    };
  }

  // The fused op produces a single output with the same tensor desc as the
  // hidden_states input.
  opGraph.inferShapeFunc =
      [=](const atb::SVector<atb::TensorDesc> &inTensorDescs,
          atb::SVector<atb::TensorDesc> &outTensorDescs) {
        outTensorDescs.resize(1);
        outTensorDescs.at(0) = inTensorDescs.at(0);
        return atb::NO_ERROR;
      };

  atb::CreateOperation(opGraph, operation);
}

} // namespace atb_layers

#endif
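
For reviewers, the TENSOR_ID bookkeeping above fixes both the number and the order of the graph's input tensors. The snippet below is not part of the diff; `FaPaInputFlags` and `ExpectedInputCount` are hypothetical helpers that only restate that arithmetic from the caller's side, assuming ids are consumed in the declaration order shown in CreateFaPaAttention.

#include <cstdint>

// Hypothetical helper, not in this PR: mirrors the TENSOR_ID logic in
// CreateFaPaAttention so a caller can check how many input tensors to bind.
struct FaPaInputFlags {
  bool has_qkv_bias = false;
  bool use_matmul_int8 = false;
  bool use_alibi = false;
  bool is_prefill = false;
};

inline uint64_t ExpectedInputCount(const FaPaInputFlags& p) {
  uint64_t n = 2;                           // hidden_states, qkv_weight
  if (p.has_qkv_bias) n += 1;               // qkv_bias
  if (p.use_matmul_int8) n += 2;            // qkv_deq_scale, qkv_deq_offset
  if (!p.use_alibi) n += 2;                 // rope cos, sin
  if (p.is_prefill || p.use_alibi) n += 1;  // attention mask
  n += 3;                                   // cache_k, cache_v, slots
  if (!p.is_prefill) n += 1;                // block_tables
  n += 1;                                   // seq_len
  if (!p.is_prefill) n += 1;                // batch_status
  return n;                                 // equals opGraph.inTensorNum
}
// Example: decode without bias/int8/alibi expects 10 inputs, prefill 9.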
@@ -0,0 +1,48 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#ifdef PADDLE_WITH_ATB

#include "atb/atb_infer.h"

namespace atb_layers {

struct FaPaAttentionParam {
  int64_t head_num;
  int64_t kv_head_num;
  int64_t head_dim;
  bool trans_qkv_weight;
  bool has_qkv_bias{false};
  bool use_matmul_int8{false};
  float qkv_quant_scale{1.0f};
  bool use_alibi{false};
  bool rope_neox{false};
  bool is_prefill;
};

void CreateFaPaAttention(const FaPaAttentionParam& param, atb::Operation** operation);

} // namespace atb_layers

namespace atb {
template <>
inline Status CreateOperation(const atb_layers::FaPaAttentionParam& opParam,
                              Operation** operation) {
  atb_layers::CreateFaPaAttention(opParam, operation);
  return ErrorType::NO_ERROR;
}
} // namespace atb

#endif
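
A minimal usage sketch, not taken from the PR: how a caller might build the fused operation for a decode step through the `atb::CreateOperation` specialization declared above. The parameter values and the `BuildFaPaAttentionForDecode` helper are placeholders; tensor binding and the ATB Setup/Execute flow are omitted.

#ifdef PADDLE_WITH_ATB
#include "fused_fapa_attention.h"  // the header added in this PR

// Sketch only: construct the fused FaPa attention op for the decode path.
inline atb::Operation* BuildFaPaAttentionForDecode() {
  atb_layers::FaPaAttentionParam param;
  param.head_num = 32;         // placeholder values
  param.kv_head_num = 32;
  param.head_dim = 128;
  param.trans_qkv_weight = false;
  param.has_qkv_bias = false;
  param.use_matmul_int8 = false;
  param.use_alibi = false;
  param.rope_neox = false;
  param.is_prefill = false;    // decode -> paged attention branch

  atb::Operation* op = nullptr;
  // Dispatches to atb_layers::CreateFaPaAttention via the CreateOperation
  // specialization declared in the header.
  atb::CreateOperation(param, &op);
  return op;
}
#endif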