|
| 1 | +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +// |
| 3 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +// you may not use this file except in compliance with the License. |
| 5 | +// You may obtain a copy of the License at |
| 6 | +// |
| 7 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +// |
| 9 | +// Unless required by applicable law or agreed to in writing, software |
| 10 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +// See the License for the specific language governing permissions and |
| 13 | +// limitations under the License. |
| 14 | + |
| 15 | +#ifdef PADDLE_WITH_ATB |
| 16 | + |
| 17 | +#include "fused_fapa_attention.h" // NOLINT |
| 18 | + |
| 19 | +#include <cmath> |
| 20 | + |
| 21 | +#include "glog/logging.h" |
| 22 | +#include "linear.h" // NOLINT |
| 23 | +#include "qkv_split.h" // NOLINT |
| 24 | + |
| 25 | +namespace atb_layers { |
| 26 | + |
| 27 | +void CreateFaPaAttention(const FaPaAttentionParam ¶m, |
| 28 | + atb::Operation **operation) { |
| 29 | + uint64_t TENSOR_ID = 0; |
| 30 | + |
| 31 | + uint64_t INPUT_QKV_OUT = TENSOR_ID++; |
| 32 | + |
| 33 | + uint64_t INPUT_COS = param.use_alibi ? 0 : TENSOR_ID++; |
| 34 | + uint64_t INPUT_SIN = param.use_alibi ? 0 : TENSOR_ID++; |
| 35 | + uint64_t INPUT_MASK = param.is_prefill || param.use_alibi ? TENSOR_ID++ : 0; |
| 36 | + uint64_t INPUT_CACHE_K = TENSOR_ID++; |
| 37 | + uint64_t INPUT_CACHE_V = TENSOR_ID++; |
| 38 | + uint64_t INPUT_SLOTS = TENSOR_ID++; |
| 39 | + uint64_t INPUT_BLOCK_TABLES = !param.is_prefill ? TENSOR_ID++ : 0; |
| 40 | + uint64_t INPUT_SEQLEN = TENSOR_ID++; |
| 41 | + uint64_t INPUT_BATCH_STATUS = !param.is_prefill ? TENSOR_ID++ : INPUT_SEQLEN; |
| 42 | + |
| 43 | + uint64_t OUTPUT = TENSOR_ID++; |
| 44 | + |
| 45 | + uint64_t INTERMEDIATE_Q = TENSOR_ID++; |
| 46 | + uint64_t INTERMEDIATE_K = TENSOR_ID++; |
| 47 | + uint64_t INTERMEDIATE_V = TENSOR_ID++; |
| 48 | + uint64_t INTERMEDIATE_EMB_Q = TENSOR_ID++; |
| 49 | + uint64_t INTERMEDIATE_EMB_K = TENSOR_ID++; |
| 50 | + |
| 51 | + uint64_t nodeIdx = 0; |
| 52 | + atb::GraphParam opGraph; |
| 53 | + opGraph.name = "FaPaAttentionOperation"; |
| 54 | + opGraph.inTensorNum = INPUT_BATCH_STATUS - INPUT_QKV_OUT + 1; |
| 55 | + opGraph.outTensorNum = 1; |
| 56 | + opGraph.internalTensorNum = INTERMEDIATE_EMB_K - INTERMEDIATE_Q + 1; |
| 57 | + if (param.use_alibi) { |
| 58 | + opGraph.nodes.resize(3); |
| 59 | + } else { |
| 60 | + opGraph.nodes.resize(4); |
| 61 | + } |
| 62 | + |
| 63 | + // split q,k,v |
| 64 | + { |
| 65 | + atb::Node &opNode = opGraph.nodes.at(nodeIdx++); |
| 66 | + atb_layers::QKVSplitParam opParam; |
| 67 | + opParam.head_num = param.head_num; |
| 68 | + opParam.kv_head_num = param.kv_head_num; |
| 69 | + opParam.head_dim = param.head_dim; |
| 70 | + atb::CreateOperation(opParam, &opNode.operation); |
| 71 | + opNode.inTensorIds = {INPUT_QKV_OUT}; |
| 72 | + opNode.outTensorIds = {INTERMEDIATE_Q, INTERMEDIATE_K, INTERMEDIATE_V}; |
| 73 | + } |
| 74 | + |
| 75 | + // rope |
| 76 | + if (!param.use_alibi) { |
| 77 | + atb::Node &opNode = opGraph.nodes.at(nodeIdx++); |
| 78 | + atb::infer::RopeParam opParam; |
| 79 | + opParam.rotaryCoeff = param.rope_neox ? param.head_dim : 2; |
| 80 | + atb::CreateOperation(opParam, &opNode.operation); |
| 81 | + opNode.inTensorIds = { |
| 82 | + INTERMEDIATE_Q, INTERMEDIATE_K, INPUT_COS, INPUT_SIN, INPUT_SEQLEN}; |
| 83 | + opNode.outTensorIds = {INTERMEDIATE_EMB_Q, INTERMEDIATE_EMB_K}; |
| 84 | + } |
| 85 | + |
| 86 | + // write kv |
| 87 | + { |
| 88 | + atb::Node &opNode = opGraph.nodes.at(nodeIdx++); |
| 89 | + atb::infer::ReshapeAndCacheParam opParam; |
| 90 | + atb::CreateOperation(opParam, &opNode.operation); |
| 91 | + opNode.inTensorIds = {INTERMEDIATE_EMB_K, |
| 92 | + INTERMEDIATE_V, |
| 93 | + INPUT_CACHE_K, |
| 94 | + INPUT_CACHE_V, |
| 95 | + INPUT_SLOTS}; |
| 96 | + opNode.outTensorIds = {INPUT_CACHE_K, INPUT_CACHE_V}; // write in place |
| 97 | + opNode.inTensorReshapeFuncs.resize(opNode.inTensorIds.size()); |
| 98 | + opNode.inTensorReshapeFuncs[0] = [=](const atb::Dims &oldShape, |
| 99 | + atb::Dims &newShape) { |
| 100 | + newShape.dimNum = 3; |
| 101 | + newShape.dims[0] = oldShape.dims[0]; |
| 102 | + newShape.dims[1] = param.kv_head_num; |
| 103 | + newShape.dims[2] = param.head_dim; |
| 104 | + }; |
| 105 | + opNode.inTensorReshapeFuncs[1] = [=](const atb::Dims &oldShape, |
| 106 | + atb::Dims &newShape) { |
| 107 | + newShape.dimNum = 3; |
| 108 | + newShape.dims[0] = oldShape.dims[0]; |
| 109 | + newShape.dims[1] = param.kv_head_num; |
| 110 | + newShape.dims[2] = param.head_dim; |
| 111 | + }; |
| 112 | + opNode.inTensorReshapeFuncs[2] = [=](const atb::Dims &oldShape, |
| 113 | + atb::Dims &newShape) { |
| 114 | + newShape.dimNum = 4; |
| 115 | + newShape.dims[0] = oldShape.dims[0]; |
| 116 | + newShape.dims[1] = oldShape.dims[2]; |
| 117 | + newShape.dims[2] = oldShape.dims[1]; |
| 118 | + newShape.dims[3] = oldShape.dims[3]; |
| 119 | + }; |
| 120 | + opNode.inTensorReshapeFuncs[3] = [=](const atb::Dims &oldShape, |
| 121 | + atb::Dims &newShape) { |
| 122 | + newShape.dimNum = 4; |
| 123 | + newShape.dims[0] = oldShape.dims[0]; |
| 124 | + newShape.dims[1] = oldShape.dims[2]; |
| 125 | + newShape.dims[2] = oldShape.dims[1]; |
| 126 | + newShape.dims[3] = oldShape.dims[3]; |
| 127 | + }; |
| 128 | + } |
| 129 | + |
| 130 | + if (param.is_prefill) { |
| 131 | + atb::Node &opNode = opGraph.nodes.at(nodeIdx++); |
| 132 | + atb::infer::SelfAttentionParam opParam; |
| 133 | + opParam.headNum = param.head_num; |
| 134 | + opParam.kvHeadNum = param.kv_head_num; |
| 135 | + opParam.qkScale = 1.0f / sqrt(param.head_dim); |
| 136 | + opParam.calcType = atb::infer::SelfAttentionParam::CalcType::PA_ENCODER; |
| 137 | + opParam.maskType = atb::infer::SelfAttentionParam::MASK_TYPE_NORM; |
| 138 | + if (param.use_alibi) { |
| 139 | + opParam.isTriuMask = 0; |
| 140 | + opParam.maskType = |
| 141 | + atb::infer::SelfAttentionParam::MaskType::MASK_TYPE_ALIBI; |
| 142 | + } else { |
| 143 | + opParam.isTriuMask = 1; |
| 144 | + } |
| 145 | + atb::CreateOperation(opParam, &opNode.operation); |
| 146 | + opNode.inTensorIds = {INTERMEDIATE_EMB_Q, |
| 147 | + INTERMEDIATE_EMB_K, |
| 148 | + INTERMEDIATE_V, |
| 149 | + INPUT_MASK, |
| 150 | + INPUT_SEQLEN}; |
| 151 | + LOG(INFO) << "OUTPUT fa **************" << OUTPUT; |
| 152 | + opNode.outTensorIds = {OUTPUT}; |
| 153 | + opNode.inTensorReshapeFuncs.resize(opNode.inTensorIds.size()); |
| 154 | + } else { |
| 155 | + atb::Node &opNode = opGraph.nodes.at(nodeIdx++); |
| 156 | + atb::infer::PagedAttentionParam opParam; |
| 157 | + opParam.headNum = param.head_num; |
| 158 | + opParam.qkScale = 1.0f / sqrt(param.head_dim); |
| 159 | + opParam.kvHeadNum = param.kv_head_num; |
| 160 | + if (param.use_alibi) { |
| 161 | + opParam.maskType = |
| 162 | + atb::infer::PagedAttentionParam::MaskType::MASK_TYPE_ALIBI; |
| 163 | + } else { |
| 164 | + opParam.maskType = atb::infer::PagedAttentionParam::MaskType::UNDEFINED; |
| 165 | + } |
| 166 | + opParam.batchRunStatusEnable = true; |
| 167 | + |
| 168 | + atb::CreateOperation(opParam, &opNode.operation); |
| 169 | + |
| 170 | + if (param.use_alibi) { |
| 171 | + opNode.inTensorIds = {INTERMEDIATE_EMB_Q, |
| 172 | + INPUT_CACHE_K, |
| 173 | + INPUT_CACHE_V, |
| 174 | + INPUT_BLOCK_TABLES, |
| 175 | + INPUT_SEQLEN, |
| 176 | + INPUT_MASK, |
| 177 | + INPUT_BATCH_STATUS}; |
| 178 | + } else { |
| 179 | + opNode.inTensorIds = {INTERMEDIATE_EMB_Q, |
| 180 | + INPUT_CACHE_K, |
| 181 | + INPUT_CACHE_V, |
| 182 | + INPUT_BLOCK_TABLES, |
| 183 | + INPUT_SEQLEN, |
| 184 | + INPUT_BATCH_STATUS}; |
| 185 | + } |
| 186 | + |
| 187 | + opNode.outTensorIds = {OUTPUT}; |
| 188 | + opNode.inTensorReshapeFuncs.resize(opNode.inTensorIds.size()); |
| 189 | + opNode.inTensorReshapeFuncs[0] = [=](const atb::Dims &oldShape, |
| 190 | + atb::Dims &newShape) { |
| 191 | + newShape.dimNum = 3; |
| 192 | + newShape.dims[0] = oldShape.dims[0]; |
| 193 | + newShape.dims[1] = param.head_num; |
| 194 | + newShape.dims[2] = param.head_dim; |
| 195 | + }; |
| 196 | + } |
| 197 | + |
| 198 | + atb::CreateOperation(opGraph, operation); |
| 199 | +} |
| 200 | + |
| 201 | +} // namespace atb_layers |
| 202 | + |
| 203 | +#endif |
0 commit comments