diff --git a/backends/npu/custom_op/llama_infer/atb_ops/atb_layers/fused_fapa_attention.cc b/backends/npu/custom_op/llama_infer/atb_ops/atb_layers/fused_fapa_attention.cc new file mode 100644 index 00000000000..f8eaf7dee71 --- /dev/null +++ b/backends/npu/custom_op/llama_infer/atb_ops/atb_layers/fused_fapa_attention.cc @@ -0,0 +1,203 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifdef PADDLE_WITH_ATB + +#include "fused_fapa_attention.h" // NOLINT + +#include + +#include "glog/logging.h" +#include "linear.h" // NOLINT +#include "qkv_split.h" // NOLINT + +namespace atb_layers { + +void CreateFaPaAttention(const FaPaAttentionParam ¶m, + atb::Operation **operation) { + uint64_t TENSOR_ID = 0; + + uint64_t INPUT_QKV_OUT = TENSOR_ID++; + + uint64_t INPUT_COS = param.use_alibi ? 0 : TENSOR_ID++; + uint64_t INPUT_SIN = param.use_alibi ? 0 : TENSOR_ID++; + uint64_t INPUT_MASK = param.is_prefill || param.use_alibi ? TENSOR_ID++ : 0; + uint64_t INPUT_CACHE_K = TENSOR_ID++; + uint64_t INPUT_CACHE_V = TENSOR_ID++; + uint64_t INPUT_SLOTS = TENSOR_ID++; + uint64_t INPUT_BLOCK_TABLES = !param.is_prefill ? TENSOR_ID++ : 0; + uint64_t INPUT_SEQLEN = TENSOR_ID++; + uint64_t INPUT_BATCH_STATUS = !param.is_prefill ? TENSOR_ID++ : INPUT_SEQLEN; + + uint64_t OUTPUT = TENSOR_ID++; + + uint64_t INTERMEDIATE_Q = TENSOR_ID++; + uint64_t INTERMEDIATE_K = TENSOR_ID++; + uint64_t INTERMEDIATE_V = TENSOR_ID++; + uint64_t INTERMEDIATE_EMB_Q = TENSOR_ID++; + uint64_t INTERMEDIATE_EMB_K = TENSOR_ID++; + + uint64_t nodeIdx = 0; + atb::GraphParam opGraph; + opGraph.name = "FaPaAttentionOperation"; + opGraph.inTensorNum = INPUT_BATCH_STATUS - INPUT_QKV_OUT + 1; + opGraph.outTensorNum = 1; + opGraph.internalTensorNum = INTERMEDIATE_EMB_K - INTERMEDIATE_Q + 1; + if (param.use_alibi) { + opGraph.nodes.resize(3); + } else { + opGraph.nodes.resize(4); + } + + // split q,k,v + { + atb::Node &opNode = opGraph.nodes.at(nodeIdx++); + atb_layers::QKVSplitParam opParam; + opParam.head_num = param.head_num; + opParam.kv_head_num = param.kv_head_num; + opParam.head_dim = param.head_dim; + atb::CreateOperation(opParam, &opNode.operation); + opNode.inTensorIds = {INPUT_QKV_OUT}; + opNode.outTensorIds = {INTERMEDIATE_Q, INTERMEDIATE_K, INTERMEDIATE_V}; + } + + // rope + if (!param.use_alibi) { + atb::Node &opNode = opGraph.nodes.at(nodeIdx++); + atb::infer::RopeParam opParam; + opParam.rotaryCoeff = param.rope_neox ? 
param.head_dim : 2; + atb::CreateOperation(opParam, &opNode.operation); + opNode.inTensorIds = { + INTERMEDIATE_Q, INTERMEDIATE_K, INPUT_COS, INPUT_SIN, INPUT_SEQLEN}; + opNode.outTensorIds = {INTERMEDIATE_EMB_Q, INTERMEDIATE_EMB_K}; + } + + // write kv + { + atb::Node &opNode = opGraph.nodes.at(nodeIdx++); + atb::infer::ReshapeAndCacheParam opParam; + atb::CreateOperation(opParam, &opNode.operation); + opNode.inTensorIds = {INTERMEDIATE_EMB_K, + INTERMEDIATE_V, + INPUT_CACHE_K, + INPUT_CACHE_V, + INPUT_SLOTS}; + opNode.outTensorIds = {INPUT_CACHE_K, INPUT_CACHE_V}; // write in place + opNode.inTensorReshapeFuncs.resize(opNode.inTensorIds.size()); + opNode.inTensorReshapeFuncs[0] = [=](const atb::Dims &oldShape, + atb::Dims &newShape) { + newShape.dimNum = 3; + newShape.dims[0] = oldShape.dims[0]; + newShape.dims[1] = param.kv_head_num; + newShape.dims[2] = param.head_dim; + }; + opNode.inTensorReshapeFuncs[1] = [=](const atb::Dims &oldShape, + atb::Dims &newShape) { + newShape.dimNum = 3; + newShape.dims[0] = oldShape.dims[0]; + newShape.dims[1] = param.kv_head_num; + newShape.dims[2] = param.head_dim; + }; + opNode.inTensorReshapeFuncs[2] = [=](const atb::Dims &oldShape, + atb::Dims &newShape) { + newShape.dimNum = 4; + newShape.dims[0] = oldShape.dims[0]; + newShape.dims[1] = oldShape.dims[2]; + newShape.dims[2] = oldShape.dims[1]; + newShape.dims[3] = oldShape.dims[3]; + }; + opNode.inTensorReshapeFuncs[3] = [=](const atb::Dims &oldShape, + atb::Dims &newShape) { + newShape.dimNum = 4; + newShape.dims[0] = oldShape.dims[0]; + newShape.dims[1] = oldShape.dims[2]; + newShape.dims[2] = oldShape.dims[1]; + newShape.dims[3] = oldShape.dims[3]; + }; + } + + if (param.is_prefill) { + atb::Node &opNode = opGraph.nodes.at(nodeIdx++); + atb::infer::SelfAttentionParam opParam; + opParam.headNum = param.head_num; + opParam.kvHeadNum = param.kv_head_num; + opParam.qkScale = 1.0f / sqrt(param.head_dim); + opParam.calcType = atb::infer::SelfAttentionParam::CalcType::PA_ENCODER; + opParam.maskType = atb::infer::SelfAttentionParam::MASK_TYPE_NORM; + if (param.use_alibi) { + opParam.isTriuMask = 0; + opParam.maskType = + atb::infer::SelfAttentionParam::MaskType::MASK_TYPE_ALIBI; + } else { + opParam.isTriuMask = 1; + } + atb::CreateOperation(opParam, &opNode.operation); + opNode.inTensorIds = {INTERMEDIATE_EMB_Q, + INTERMEDIATE_EMB_K, + INTERMEDIATE_V, + INPUT_MASK, + INPUT_SEQLEN}; + LOG(INFO) << "OUTPUT fa **************" << OUTPUT; + opNode.outTensorIds = {OUTPUT}; + opNode.inTensorReshapeFuncs.resize(opNode.inTensorIds.size()); + } else { + atb::Node &opNode = opGraph.nodes.at(nodeIdx++); + atb::infer::PagedAttentionParam opParam; + opParam.headNum = param.head_num; + opParam.qkScale = 1.0f / sqrt(param.head_dim); + opParam.kvHeadNum = param.kv_head_num; + if (param.use_alibi) { + opParam.maskType = + atb::infer::PagedAttentionParam::MaskType::MASK_TYPE_ALIBI; + } else { + opParam.maskType = atb::infer::PagedAttentionParam::MaskType::UNDEFINED; + } + opParam.batchRunStatusEnable = true; + + atb::CreateOperation(opParam, &opNode.operation); + + if (param.use_alibi) { + opNode.inTensorIds = {INTERMEDIATE_EMB_Q, + INPUT_CACHE_K, + INPUT_CACHE_V, + INPUT_BLOCK_TABLES, + INPUT_SEQLEN, + INPUT_MASK, + INPUT_BATCH_STATUS}; + } else { + opNode.inTensorIds = {INTERMEDIATE_EMB_Q, + INPUT_CACHE_K, + INPUT_CACHE_V, + INPUT_BLOCK_TABLES, + INPUT_SEQLEN, + INPUT_BATCH_STATUS}; + } + + opNode.outTensorIds = {OUTPUT}; + opNode.inTensorReshapeFuncs.resize(opNode.inTensorIds.size()); + opNode.inTensorReshapeFuncs[0] 
= [=](const atb::Dims &oldShape, + atb::Dims &newShape) { + newShape.dimNum = 3; + newShape.dims[0] = oldShape.dims[0]; + newShape.dims[1] = param.head_num; + newShape.dims[2] = param.head_dim; + }; + } + + atb::CreateOperation(opGraph, operation); +} + +} // namespace atb_layers + +#endif diff --git a/backends/npu/custom_op/llama_infer/atb_ops/atb_layers/fused_fapa_attention.h b/backends/npu/custom_op/llama_infer/atb_ops/atb_layers/fused_fapa_attention.h new file mode 100644 index 00000000000..7853ce78693 --- /dev/null +++ b/backends/npu/custom_op/llama_infer/atb_ops/atb_layers/fused_fapa_attention.h @@ -0,0 +1,45 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#ifdef PADDLE_WITH_ATB + +#include "atb/atb_infer.h" + +namespace atb_layers { + +struct FaPaAttentionParam { + int64_t head_num; + int64_t kv_head_num; + int64_t head_dim; + bool use_alibi{false}; + bool rope_neox{false}; + bool is_prefill; +}; + +void CreateFaPaAttention(const FaPaAttentionParam& param, + atb::Operation** operation); + +} // namespace atb_layers + +namespace atb { +template <> +inline Status CreateOperation(const atb_layers::FaPaAttentionParam& opParam, + Operation** operation) { + atb_layers::CreateFaPaAttention(opParam, operation); + return ErrorType::NO_ERROR; +} +} // namespace atb + +#endif diff --git a/backends/npu/custom_op/llama_infer/atb_ops/atb_layers/fused_rms_norm.cc b/backends/npu/custom_op/llama_infer/atb_ops/atb_layers/fused_rms_norm.cc new file mode 100644 index 00000000000..6485af2dd3a --- /dev/null +++ b/backends/npu/custom_op/llama_infer/atb_ops/atb_layers/fused_rms_norm.cc @@ -0,0 +1,72 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifdef PADDLE_WITH_ATB + +#include "fused_rms_norm.h" // NOLINT + +namespace atb_layers { + +void CreateRmsNorm(const RmsNormParam ¶m, atb::Operation **operation) { + uint64_t TENSOR_ID = 0; + uint64_t INPUT = TENSOR_ID++; + uint64_t INPUT_WEIGHT = TENSOR_ID++; + uint64_t INPUT_RESIDUAL = param.has_residual ? TENSOR_ID++ : INPUT_WEIGHT; + uint64_t OUTPUT = TENSOR_ID++; + uint64_t OUTPUT_RESIDUAL = param.has_residual ? 
TENSOR_ID++ : OUTPUT; + + uint64_t nodeIdx = 0; + atb::GraphParam opGraph; + opGraph.name = "RmsNormOperation"; + opGraph.internalTensorNum = 0; + + if (param.has_residual) { + opGraph.inTensorNum = 3; + opGraph.outTensorNum = 2; + opGraph.nodes.resize(2); + } else { + opGraph.inTensorNum = 2; + opGraph.outTensorNum = 1; + opGraph.nodes.resize(1); + } + + if (param.has_residual) { + atb::Node &opNode = opGraph.nodes.at(nodeIdx++); + atb::infer::ElewiseParam opParam; + opParam.elewiseType = atb::infer::ElewiseParam::ElewiseType::ELEWISE_ADD; + atb::CreateOperation(opParam, &opNode.operation); + opNode.inTensorIds = {INPUT, INPUT_RESIDUAL}; + opNode.outTensorIds = {OUTPUT_RESIDUAL}; + } + + { + atb::Node &opNode = opGraph.nodes.at(nodeIdx++); + atb::infer::RmsNormParam opParam; + opParam.layerType = atb::infer::RmsNormParam::RmsNormType::RMS_NORM_NORM; + opParam.normParam.epsilon = param.epsilon; + atb::CreateOperation(opParam, &opNode.operation); + if (param.has_residual) { + opNode.inTensorIds = {OUTPUT_RESIDUAL, INPUT_WEIGHT}; + } else { + opNode.inTensorIds = {INPUT, INPUT_WEIGHT}; + } + opNode.outTensorIds = {OUTPUT}; + } + + atb::CreateOperation(opGraph, operation); +} + +} // namespace atb_layers + +#endif diff --git a/backends/npu/custom_op/llama_infer/atb_ops/atb_layers/fused_rms_norm.h b/backends/npu/custom_op/llama_infer/atb_ops/atb_layers/fused_rms_norm.h new file mode 100644 index 00000000000..de82d24e2e6 --- /dev/null +++ b/backends/npu/custom_op/llama_infer/atb_ops/atb_layers/fused_rms_norm.h @@ -0,0 +1,40 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#ifdef PADDLE_WITH_ATB + +#include "atb/atb_infer.h" + +namespace atb_layers { + +struct RmsNormParam { + float epsilon{1.0}; + bool has_residual{false}; +}; + +void CreateRmsNorm(const RmsNormParam& param, atb::Operation** operation); + +} // namespace atb_layers + +namespace atb { +template <> +inline Status CreateOperation(const atb_layers::RmsNormParam& opParam, + Operation** operation) { + atb_layers::CreateRmsNorm(opParam, operation); + return ErrorType::NO_ERROR; +} +} // namespace atb + +#endif diff --git a/backends/npu/custom_op/llama_infer/atb_ops/fused_fapa_attention_op.cc b/backends/npu/custom_op/llama_infer/atb_ops/fused_fapa_attention_op.cc new file mode 100644 index 00000000000..f46b168f881 --- /dev/null +++ b/backends/npu/custom_op/llama_infer/atb_ops/fused_fapa_attention_op.cc @@ -0,0 +1,327 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#ifdef PADDLE_WITH_ATB +#include "fused_fapa_attn_op_utils.h" // NOLINT + +constexpr int32_t kFusedFaPaAttnLayerBegin = 1; +constexpr int32_t kFusedFaPaAttnLayerEnd = 2; + +static bool first_or_second_flag = false; +static int layer_id = 0; + +void FusedFaPaAttnLayerOpPrefillStage(const phi::CustomContext &dev_ctx, + const paddle::Tensor &qkv_out, + const paddle::Tensor &cos, + const paddle::Tensor &sin, + const paddle::Tensor &cache_k, + const paddle::Tensor &cache_v, + int64_t max_seq_len, + int64_t head_num, + int64_t kv_head_num, + int64_t head_dim, + int64_t emb_dim, + int64_t ntokens, + bool use_neox_style, + bool use_alibi) { + int64_t batch_size = + FusedFaPaGlobalVar::Instance().get_seqlens_encoder()->size; + void *slots_data = FusedFaPaGlobalVar::Instance().get_slots_encoder(); + void *seqlens_dev = + FusedFaPaGlobalVar::Instance().get_seqlens_encoder()->dev_ptr; + void *seqlens_host = + FusedFaPaGlobalVar::Instance().get_seqlens_encoder()->host_ptr; + + void *first_out_data = + FusedFaPaGlobalVar::Instance().get_out_encoder()->first->data(); + void *second_out_data = + FusedFaPaGlobalVar::Instance().get_out_encoder()->second->data(); + + auto &runner = *FusedFaPaGlobalVar::Instance().get_encoder_runner(layer_id); + + if (runner.is_initialized()) { + runner.reset_variant_pack(); + } else { + atb_layers::FaPaAttentionParam param; + param.use_alibi = use_alibi; + param.rope_neox = use_neox_style; + param.head_num = head_num; + param.kv_head_num = kv_head_num; + param.head_dim = head_dim; + param.is_prefill = true; + runner.create(param); + } + + runner.bind_input(qkv_out); + + if (!use_alibi) { + runner.bind_input(cos, {ntokens, head_dim}); + runner.bind_input(sin, {ntokens, head_dim}); + } + if (!use_alibi) { + void *mask_data = FusedFaPaGlobalVar::Instance().get_casual_mask(); + runner.bind_input( + mask_data, phi::DataType::BFLOAT16, {max_seq_len, max_seq_len}); + } else { + void *mask_data = FusedFaPaGlobalVar::Instance().get_alibi_src_mask(); + // 1, head_num, max_seq_len, max_seq_len + runner.bind_input(mask_data, + phi::DataType::BFLOAT16, + {batch_size, head_num, max_seq_len, max_seq_len}); + } + + runner.bind_input(cache_k); + runner.bind_input(cache_v); + runner.bind_input(slots_data, phi::DataType::INT32, {ntokens}); + runner.bind_input( + seqlens_dev, seqlens_host, phi::DataType::INT32, {batch_size}); + if (first_or_second_flag) { + runner.bind_output( + first_out_data, phi::DataType::BFLOAT16, {ntokens, emb_dim}); + } else { + runner.bind_output( + second_out_data, phi::DataType::BFLOAT16, {ntokens, emb_dim}); + } + + runner.setup(dev_ctx); + atb_layers::TaskQueue::Instance(dev_ctx.GetPlace().GetDeviceId()) + .Commit(std::move( + std::packaged_task([&] { runner.execute(dev_ctx); }))); +} + +void FusedFaPaAttnLayerOpDecodingStage(const phi::CustomContext &dev_ctx, + const paddle::Tensor &qkv_out, + const paddle::Tensor &cos, + const paddle::Tensor &sin, + const paddle::Tensor &cache_k, + const paddle::Tensor &cache_v, + const paddle::Tensor &block_tables, + int64_t max_seq_len, + int64_t head_num, + int64_t kv_head_num, + int64_t head_dim, + int64_t emb_dim, + int64_t ntokens, + bool use_neox_style, + bool use_alibi) { + int64_t batch_size = + FusedFaPaGlobalVar::Instance().get_seqlens_decoder()->size; + void *slots_data = FusedFaPaGlobalVar::Instance().get_slots_decoder(); + void *seqlens_dev = + FusedFaPaGlobalVar::Instance().get_seqlens_decoder()->dev_ptr; + 
void *seqlens_host = + FusedFaPaGlobalVar::Instance().get_seqlens_decoder()->host_ptr; + void *batch_status_data = + FusedFaPaGlobalVar::Instance().get_batch_status()->data; + + void *first_out_data = + FusedFaPaGlobalVar::Instance().get_out_decoder()->first->data(); + void *second_out_data = + FusedFaPaGlobalVar::Instance().get_out_decoder()->second->data(); + + auto &runner = *FusedFaPaGlobalVar::Instance().get_decoder_runner(layer_id); + + if (runner.is_initialized()) { + runner.reset_variant_pack(); + } else { + atb_layers::FaPaAttentionParam param; + param.use_alibi = use_alibi; + param.rope_neox = use_neox_style; + param.head_num = head_num; + param.kv_head_num = kv_head_num; + param.head_dim = head_dim; + param.is_prefill = false; + runner.create(param); + } + + runner.bind_input(qkv_out); + + if (!use_alibi) { + runner.bind_input(cos, {ntokens, head_dim}); + runner.bind_input(sin, {ntokens, head_dim}); + } + if (use_alibi) { + void *mask_data = FusedFaPaGlobalVar::Instance().get_alibi_tgt_mask(); + // batch, head_num, 1, max_seq_len + runner.bind_input(mask_data, + phi::DataType::BFLOAT16, + {batch_size, head_num, 1, max_seq_len}); + } + runner.bind_input(cache_k); + runner.bind_input(cache_v); + runner.bind_input(slots_data, phi::DataType::INT32, {ntokens}); + runner.bind_input(block_tables); + runner.bind_input( + seqlens_dev, seqlens_host, phi::DataType::INT32, {batch_size}); + runner.bind_host_input(batch_status_data, phi::DataType::INT32, {batch_size}); + if (first_or_second_flag) { + runner.bind_output( + first_out_data, phi::DataType::BFLOAT16, {ntokens, emb_dim}); + } else { + runner.bind_output( + second_out_data, phi::DataType::BFLOAT16, {ntokens, emb_dim}); + } + + runner.setup(dev_ctx); + atb_layers::TaskQueue::Instance(dev_ctx.GetPlace().GetDeviceId()) + .Commit(std::move( + std::packaged_task([&] { runner.execute(dev_ctx); }))); +} + +std::vector FusedFaPaAttnOp( + const paddle::Tensor &qkv_out, + const paddle::Tensor &cos, + const paddle::Tensor &sin, + const paddle::Tensor &cache_k, + const paddle::Tensor &cache_v, + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &seq_lens_decoder, + const paddle::Tensor &block_tables, + int32_t head_num, + int32_t kv_head_num, + int32_t head_dim, + int32_t flag, + int32_t max_seq_len, + int32_t block_size, + bool use_neox_style, + bool use_alibi) { + const auto &cache_k_shape = cache_k.shape(); + const auto &block_tables_shape = block_tables.shape(); + uint64_t emb_dim = head_num * head_dim; + uint64_t max_block_num_per_seq = block_tables_shape[1]; + uint64_t batch_size = seq_lens_encoder.numel(); + + auto place = qkv_out.place(); + const auto &dev_ctx = *static_cast( + paddle::experimental::DeviceContextPool::Instance().Get(place)); + if (flag == kFusedFaPaAttnLayerBegin) { + FusedFaPaGlobalVar::Instance().update_block_tables(dev_ctx, block_tables); + FusedFaPaGlobalVar::Instance().update_seqlens_encoder(dev_ctx, + seq_lens_encoder); + FusedFaPaGlobalVar::Instance().update_seqlens_decoder(dev_ctx, + seq_lens_decoder); + + FusedFaPaGlobalVar::Instance().update_casual_mask(dev_ctx, max_seq_len); + + FusedFaPaGlobalVar::Instance().update_slots_encoder( + dev_ctx, block_size, max_block_num_per_seq); + FusedFaPaGlobalVar::Instance().update_slots_decoder( + dev_ctx, block_size, max_block_num_per_seq); + + FusedFaPaGlobalVar::Instance().update_in_encoder(dev_ctx, qkv_out); + FusedFaPaGlobalVar::Instance().update_in_decoder(dev_ctx, qkv_out); + + first_or_second_flag = false; + layer_id = 0; + } else { + 
first_or_second_flag = !first_or_second_flag; + layer_id++; + } + + auto ntokens_encoder = + FusedFaPaGlobalVar::Instance().get_seqlens_encoder()->ntokens; + auto ntokens_decoder = + FusedFaPaGlobalVar::Instance().get_seqlens_decoder()->ntokens; + + int32_t ntokens = -1; + if (ntokens_encoder > 0) { + FusedFaPaAttnLayerOpPrefillStage(dev_ctx, + qkv_out, + cos, + sin, + cache_k, + cache_v, + max_seq_len, + head_num, + kv_head_num, + head_dim, + emb_dim, + ntokens_encoder, + use_neox_style, + use_alibi); + ntokens = ntokens_encoder; + } + if (ntokens_decoder > 0) { + FusedFaPaAttnLayerOpDecodingStage(dev_ctx, + qkv_out, + cos, + sin, + cache_k, + cache_v, + block_tables, + max_seq_len, + head_num, + kv_head_num, + head_dim, + emb_dim, + ntokens_decoder, + use_neox_style, + use_alibi); + ntokens = ntokens_decoder; + } + + paddle::Tensor out(place); + atb_layers::TaskQueue::Instance(dev_ctx.GetPlace().GetDeviceId()).Wait(); + + fapa_layers::init_tensor( + dev_ctx, phi::DataType::BFLOAT16, {ntokens, emb_dim}, &out); + FusedFaPaGlobalVar::Instance().update_out_encoder( + dev_ctx, first_or_second_flag, &out); + FusedFaPaGlobalVar::Instance().update_out_decoder( + dev_ctx, first_or_second_flag, &out); + return {out}; +} + +std::vector> FusedFaPaAttnOpInferShape( + const std::vector &qkv_out_shape, + const std::vector &cos_shape, + const std::vector &sin_shape, + const std::vector &cache_k_shape, + const std::vector &cache_v_shape, + const std::vector &seq_lens_encoder_shape, + const std::vector &seq_lens_decoder_shape, + const std::vector &block_tables_shape, + int32_t head_num, + int32_t kv_head_num, + int32_t head_dim, + int32_t flag, + int32_t max_seq_len, + int32_t block_size, + bool use_neox_style, + bool use_alibi) { + return {{-1, qkv_out_shape[1]}}; +} + +PD_BUILD_OP(fused_fapa_attention_op) + .Inputs({"qkv_out", + "cos", + "sin", + "cache_k", + "cache_v", + "seq_lens_encoder", + "seq_lens_decoder", + "block_tables"}) + .Outputs({"attn_out"}) + .Attrs({"num_heads: int", + "kv_num_heads: int", + "head_dim: int", + "flag: int", // begin: 1, end: 2, other: 0 + "max_seq_len: int", + "block_size: int", + "use_neox_style: bool", + "use_alibi: bool"}) + .SetKernelFn(PD_KERNEL(FusedFaPaAttnOp)) + .SetInferShapeFn(PD_INFER_SHAPE(FusedFaPaAttnOpInferShape)); + +#endif diff --git a/backends/npu/custom_op/llama_infer/atb_ops/fused_fapa_attn_op_utils.cc b/backends/npu/custom_op/llama_infer/atb_ops/fused_fapa_attn_op_utils.cc new file mode 100644 index 00000000000..ce67283c6f6 --- /dev/null +++ b/backends/npu/custom_op/llama_infer/atb_ops/fused_fapa_attn_op_utils.cc @@ -0,0 +1,676 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
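+
+// Host-side helpers shared by the fused FA/PA attention op. The
+// FusedFaPaGlobalVar singleton caches per-step metadata (encoder/decoder
+// sequence lengths on host and device, block tables, batch status, cache
+// slots, causal/alibi masks) plus the paired "first/second" intermediate
+// buffers that consecutive layers alternate between, so later layers can
+// reuse the data without re-copying it from Paddle tensors.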
+ +#ifdef PADDLE_WITH_ATB + +#include "fused_fapa_attn_op_utils.h" // NOLINT + +namespace custom_kernel { +template +void FullKernel(const Context &dev_ctx, + const phi::IntArray &shape, + const phi::Scalar &val, + phi::DataType dtype, + phi::DenseTensor *out); + +template +void TrilKernel(const Context &ctx, + const phi::DenseTensor &x, + int diagonal, + phi::DenseTensor *out); + +template +void ScaleKernel(const Context &dev_ctx, + const phi::DenseTensor &x, + const phi::Scalar &in_scale, + const phi::Scalar &in_bias, + bool bias_after_scale, + phi::DenseTensor *out); +} // namespace custom_kernel + +namespace fapa_layers { + +void init_tensor(const phi::CustomContext &dev_ctx, + const phi::DataType &dtype, + const std::vector &shape, + paddle::Tensor *tensor) { + phi::DenseTensorMeta meta(dtype, phi::make_ddim(shape)); + static_cast(tensor->impl().get())->set_meta(meta); + dev_ctx.Alloc(static_cast(tensor->impl().get()), dtype); +} + +} // namespace fapa_layers + +FusedFaPaGlobalVar &FusedFaPaGlobalVar::Instance() { + static FusedFaPaGlobalVar ins; + return ins; +} + +void FusedFaPaGlobalVar::update_seqlens_encoder( + const phi::CustomContext &dev_ctx, const paddle::Tensor &seqlen) { + static void *g_seqlens_encoder_int64 = nullptr; + // init + if (!g_seqlens_encoder_int64) { + g_seqlens_encoder.size = seqlen.numel(); + ACL_CHECK(aclrtMallocHost(&g_seqlens_encoder.host_ptr, + g_seqlens_encoder.size * sizeof(int32_t))); + ACL_CHECK(aclrtMallocHost(&g_seqlens_encoder_int64, + g_seqlens_encoder.size * sizeof(int64_t))); + g_seqlens_encoder.dev_tensor.Resize({g_seqlens_encoder.size}); + dev_ctx.template Alloc(&g_seqlens_encoder.dev_tensor); + g_seqlens_encoder.dev_ptr = g_seqlens_encoder.dev_tensor.data(); + } + + // update + if (seqlen.dtype() == phi::DataType::INT32) { + ACL_CHECK( + aclrtMemcpyAsync(g_seqlens_encoder.host_ptr, + g_seqlens_encoder.size * sizeof(int32_t), + seqlen.data(), + g_seqlens_encoder.size * sizeof(int32_t), + ACL_MEMCPY_DEVICE_TO_HOST, + reinterpret_cast(dev_ctx.stream()))); + ACL_CHECK(aclrtSynchronizeStream( + reinterpret_cast(dev_ctx.stream()))); + } else { + ACL_CHECK( + aclrtMemcpyAsync(g_seqlens_encoder_int64, + g_seqlens_encoder.size * sizeof(int64_t), + seqlen.data(), + g_seqlens_encoder.size * sizeof(int64_t), + ACL_MEMCPY_DEVICE_TO_HOST, + reinterpret_cast(dev_ctx.stream()))); + ACL_CHECK(aclrtSynchronizeStream( + reinterpret_cast(dev_ctx.stream()))); + for (auto i = 0; i < g_seqlens_encoder.size; ++i) { + reinterpret_cast(g_seqlens_encoder.host_ptr)[i] = + static_cast( + reinterpret_cast(g_seqlens_encoder_int64)[i]); + } + } + + // calc ntokens + auto *g_seqlens_encoder_host = reinterpret_cast( + reinterpret_cast(g_seqlens_encoder.host_ptr)); + g_seqlens_encoder.ntokens = 0; + for (auto i = 0; i < g_seqlens_encoder.size; ++i) { + if (g_seqlens_encoder_host[i] > 0) { + g_seqlens_encoder.ntokens += g_seqlens_encoder_host[i]; + } + } + + // copy to device + ACL_CHECK(aclrtMemcpyAsync(g_seqlens_encoder.dev_ptr, + g_seqlens_encoder.size * sizeof(int32_t), + g_seqlens_encoder.host_ptr, + g_seqlens_encoder.size * sizeof(int32_t), + ACL_MEMCPY_HOST_TO_DEVICE, + reinterpret_cast(dev_ctx.stream()))); +} + +void FusedFaPaGlobalVar::update_seqlens_decoder( + const phi::CustomContext &dev_ctx, const paddle::Tensor &seqlen) { + static void *g_seqlens_decoder_int64 = nullptr; + // init + if (!g_seqlens_decoder_int64) { + g_seqlens_decoder.size = seqlen.numel(); + ACL_CHECK(aclrtMallocHost(&g_seqlens_decoder.host_ptr, + g_seqlens_decoder.size * sizeof(int32_t))); + 
g_batch_status.size = seqlen.numel(); + ACL_CHECK(aclrtMallocHost(&g_batch_status.data, + g_batch_status.size * sizeof(int32_t))); + ACL_CHECK(aclrtMallocHost(&g_seqlens_decoder_int64, + g_seqlens_decoder.size * sizeof(int64_t))); + g_seqlens_decoder.dev_tensor.Resize({g_seqlens_decoder.size}); + dev_ctx.template Alloc(&g_seqlens_decoder.dev_tensor); + g_seqlens_decoder.dev_ptr = g_seqlens_decoder.dev_tensor.data(); + } + + // update + if (seqlen.dtype() == phi::DataType::INT32) { + ACL_CHECK( + aclrtMemcpyAsync(g_seqlens_decoder.host_ptr, + g_seqlens_decoder.size * sizeof(int32_t), + seqlen.data(), + g_seqlens_decoder.size * sizeof(int32_t), + ACL_MEMCPY_DEVICE_TO_HOST, + reinterpret_cast(dev_ctx.stream()))); + ACL_CHECK(aclrtSynchronizeStream( + reinterpret_cast(dev_ctx.stream()))); + } else { + ACL_CHECK( + aclrtMemcpyAsync(g_seqlens_decoder_int64, + g_seqlens_decoder.size * sizeof(int64_t), + seqlen.data(), + g_seqlens_decoder.size * sizeof(int64_t), + ACL_MEMCPY_DEVICE_TO_HOST, + reinterpret_cast(dev_ctx.stream()))); + ACL_CHECK(aclrtSynchronizeStream( + reinterpret_cast(dev_ctx.stream()))); + for (auto i = 0; i < g_seqlens_decoder.size; ++i) { + reinterpret_cast(g_seqlens_decoder.host_ptr)[i] = + static_cast( + reinterpret_cast(g_seqlens_decoder_int64)[i]); + } + } + + // calc ntokens + auto *g_seqlens_decoder_host = reinterpret_cast( + reinterpret_cast(g_seqlens_decoder.host_ptr)); + auto *g_batch_status_host = reinterpret_cast( + reinterpret_cast(g_batch_status.data)); + g_seqlens_decoder.ntokens = 0; + for (auto i = 0; i < g_seqlens_decoder.size; ++i) { + g_batch_status_host[i] = 0; + if (g_seqlens_decoder_host[i] > 0) { + g_seqlens_decoder_host[i] += 1; + g_seqlens_decoder.ntokens += 1; + g_batch_status_host[i] = 1; + } + } + + // copy to device + ACL_CHECK(aclrtMemcpyAsync(g_seqlens_decoder.dev_ptr, + g_seqlens_decoder.size * sizeof(int32_t), + g_seqlens_decoder.host_ptr, + g_seqlens_decoder.size * sizeof(int32_t), + ACL_MEMCPY_HOST_TO_DEVICE, + reinterpret_cast(dev_ctx.stream()))); +} + +void FusedFaPaGlobalVar::update_block_tables( + const phi::CustomContext &dev_ctx, const paddle::Tensor &block_tables) { + static void *g_block_tables_int64 = nullptr; + // init + if (!g_block_tables_int64) { + g_block_tables.size = block_tables.numel(); + ACL_CHECK(aclrtMallocHost(&g_block_tables.data, + g_block_tables.size * sizeof(int32_t))); + ACL_CHECK(aclrtMallocHost(&g_block_tables_int64, + g_block_tables.size * sizeof(int64_t))); + } + + // update + if (block_tables.dtype() == phi::DataType::INT32) { + ACL_CHECK( + aclrtMemcpyAsync(g_block_tables.data, + g_block_tables.size * sizeof(int32_t), + block_tables.data(), + g_block_tables.size * sizeof(int32_t), + ACL_MEMCPY_DEVICE_TO_HOST, + reinterpret_cast(dev_ctx.stream()))); + ACL_CHECK(aclrtSynchronizeStream( + reinterpret_cast(dev_ctx.stream()))); + } else { + ACL_CHECK( + aclrtMemcpyAsync(g_block_tables_int64, + g_block_tables.size * sizeof(int64_t), + block_tables.data(), + g_block_tables.size * sizeof(int64_t), + ACL_MEMCPY_DEVICE_TO_HOST, + reinterpret_cast(dev_ctx.stream()))); + ACL_CHECK(aclrtSynchronizeStream( + reinterpret_cast(dev_ctx.stream()))); + for (auto i = 0; i < g_block_tables.size; ++i) { + reinterpret_cast(g_block_tables.data)[i] = + static_cast( + reinterpret_cast(g_block_tables_int64)[i]); + } + } +} + +void FusedFaPaGlobalVar::update_rope_encoder(const phi::CustomContext &dev_ctx, + const paddle::Tensor &rope_emb, + int64_t max_seqlen, + int64_t head_dim) { + if (g_seqlens_encoder.ntokens == 0) { + return; + } + // 
init + g_rope_emb_encoder.rope_emb_cos = std::make_shared(); + g_rope_emb_encoder.rope_emb_cos->Resize( + {g_seqlens_encoder.ntokens, head_dim}); + dev_ctx.template Alloc(g_rope_emb_encoder.rope_emb_cos.get()); + + g_rope_emb_encoder.rope_emb_sin = std::make_shared(); + g_rope_emb_encoder.rope_emb_sin->Resize( + {g_seqlens_encoder.ntokens, head_dim}); + dev_ctx.template Alloc(g_rope_emb_encoder.rope_emb_sin.get()); + + // update + C_Device_st device{dev_ctx.GetPlace().GetDeviceId()}; + C_Stream stream = + const_cast(reinterpret_cast(dev_ctx.stream())); + + void *new_cos_data = g_rope_emb_encoder.rope_emb_cos->data(); + void *new_sin_data = g_rope_emb_encoder.rope_emb_sin->data(); + void *cos_data = const_cast(rope_emb.data()); + void *sin_data = + cos_data + rope_emb.numel() / 2 * phi::SizeOf(phi::DataType::BFLOAT16); + + uint64_t out_offset = 0; + uint64_t in_offset = 0; + uint64_t numel = 0; + int32_t *seqlens = reinterpret_cast(g_seqlens_encoder.host_ptr); + uint64_t seqlens_size = g_seqlens_encoder.size; + for (auto i = 0; i < seqlens_size; ++i) { + if (seqlens[i] > 0) { + out_offset += numel; + numel = seqlens[i] * head_dim; + AsyncMemCpyD2D( + &device, + stream, + new_cos_data + out_offset * phi::SizeOf(phi::DataType::BFLOAT16), + cos_data + in_offset * phi::SizeOf(phi::DataType::BFLOAT16), + numel * phi::SizeOf(phi::DataType::BFLOAT16)); + AsyncMemCpyD2D( + &device, + stream, + new_sin_data + out_offset * phi::SizeOf(phi::DataType::BFLOAT16), + sin_data + in_offset * phi::SizeOf(phi::DataType::BFLOAT16), + numel * phi::SizeOf(phi::DataType::BFLOAT16)); + } + } +} + +void FusedFaPaGlobalVar::update_rope_decoder(const phi::CustomContext &dev_ctx, + const paddle::Tensor &rope_emb, + int64_t max_seqlen, + int64_t head_dim) { + if (g_seqlens_decoder.ntokens == 0) { + return; + } + // init + g_rope_emb_decoder.rope_emb_cos = std::make_shared(); + g_rope_emb_decoder.rope_emb_cos->Resize( + {g_seqlens_decoder.ntokens, head_dim}); + dev_ctx.template Alloc(g_rope_emb_decoder.rope_emb_cos.get()); + + g_rope_emb_decoder.rope_emb_sin = std::make_shared(); + g_rope_emb_decoder.rope_emb_sin->Resize( + {g_seqlens_decoder.ntokens, head_dim}); + dev_ctx.template Alloc(g_rope_emb_decoder.rope_emb_sin.get()); + + // update + C_Device_st device{dev_ctx.GetPlace().GetDeviceId()}; + C_Stream stream = + const_cast(reinterpret_cast(dev_ctx.stream())); + + void *new_cos_data = g_rope_emb_decoder.rope_emb_cos->data(); + void *new_sin_data = g_rope_emb_decoder.rope_emb_sin->data(); + void *cos_data = const_cast(rope_emb.data()); + void *sin_data = + cos_data + rope_emb.numel() / 2 * phi::SizeOf(phi::DataType::BFLOAT16); + + uint64_t out_offset = 0; + uint64_t in_offset = 0; + uint64_t numel = 0; + int32_t *seqlens = reinterpret_cast(g_seqlens_decoder.host_ptr); + uint64_t seqlens_size = g_seqlens_decoder.size; + for (auto i = 0; i < seqlens_size; ++i) { + if (seqlens[i] > 0) { + in_offset = (seqlens[i] - 1) * head_dim; + out_offset += numel; + numel = head_dim; + AsyncMemCpyD2D( + &device, + stream, + new_cos_data + out_offset * phi::SizeOf(phi::DataType::BFLOAT16), + cos_data + in_offset * phi::SizeOf(phi::DataType::BFLOAT16), + numel * phi::SizeOf(phi::DataType::BFLOAT16)); + AsyncMemCpyD2D( + &device, + stream, + new_sin_data + out_offset * phi::SizeOf(phi::DataType::BFLOAT16), + sin_data + in_offset * phi::SizeOf(phi::DataType::BFLOAT16), + numel * phi::SizeOf(phi::DataType::BFLOAT16)); + } + } +} + +void FusedFaPaGlobalVar::update_slots_encoder(const phi::CustomContext &dev_ctx, + int64_t block_size, + 
int64_t max_block_num) { + if (g_seqlens_encoder.ntokens == 0) { + return; + } + static int32_t *g_slots = nullptr; + static int64_t g_slots_size = 0; + // init + if (g_slots_size < g_seqlens_encoder.ntokens) { + g_slots_size = g_seqlens_encoder.ntokens; + if (g_slots != nullptr) { + ACL_CHECK(aclrtFreeHost(g_slots)); + g_slots = nullptr; + } + ACL_CHECK(aclrtMallocHost(reinterpret_cast(&g_slots), + g_slots_size * sizeof(int32_t))); + g_slots_encoder = std::make_shared(); + g_slots_encoder->Resize({g_slots_size}); + dev_ctx.template Alloc(g_slots_encoder.get()); + } + + // update + int64_t idx = 0; + int64_t block_offset = 0; + int32_t *block_tables = reinterpret_cast(g_block_tables.data); + int32_t *seqlens = reinterpret_cast(g_seqlens_encoder.host_ptr); + uint64_t seqlens_size = g_seqlens_encoder.size; + for (int64_t i = 0; i < seqlens_size; ++i) { + int64_t len = seqlens[i]; + if (len > 0) { + int64_t need_block_num = len / block_size; + int64_t tail_len = len % block_size; + int64_t slot_offset = 0; + for (int64_t j = 0; j < need_block_num; ++j) { + slot_offset = block_tables[block_offset + j] * block_size; + for (int64_t k = 0; k < block_size; ++k) { + g_slots[idx++] = slot_offset + k; + } + len -= block_size; + } + slot_offset = block_tables[block_offset + need_block_num] * block_size; + for (int64_t k = 0; k < tail_len; ++k) { + g_slots[idx++] = slot_offset + k; + } + } + block_offset += max_block_num; + } + + // copy to device + ACL_CHECK(aclrtMemcpyAsync(g_slots_encoder->data(), + g_seqlens_encoder.ntokens * sizeof(int32_t), + g_slots, + g_seqlens_encoder.ntokens * sizeof(int32_t), + ACL_MEMCPY_HOST_TO_DEVICE, + reinterpret_cast(dev_ctx.stream()))); +} + +void FusedFaPaGlobalVar::update_slots_decoder(const phi::CustomContext &dev_ctx, + int64_t block_size, + int64_t max_block_num) { + if (g_seqlens_decoder.ntokens == 0) { + return; + } + static int32_t *g_slots = nullptr; + static int64_t g_slots_size = 0; + // init + if (g_slots_size < g_seqlens_decoder.ntokens) { + g_slots_size = g_seqlens_decoder.ntokens; + if (g_slots != nullptr) { + ACL_CHECK(aclrtFreeHost(g_slots)); + g_slots = nullptr; + } + ACL_CHECK(aclrtMallocHost(reinterpret_cast(&g_slots), + g_slots_size * sizeof(int32_t))); + g_slots_decoder = std::make_shared(); + g_slots_decoder->Resize({g_slots_size}); + dev_ctx.template Alloc(g_slots_decoder.get()); + } + + // update + int64_t idx = 0; + int64_t block_offset = 0; + int32_t *block_tables = reinterpret_cast(g_block_tables.data); + int32_t *seqlens = reinterpret_cast(g_seqlens_decoder.host_ptr); + uint64_t seqlens_size = g_seqlens_decoder.size; + for (int64_t i = 0; i < seqlens_size; i++) { + int64_t len = seqlens[i]; + if (len > 0) { + int64_t need_block_num = (len - 1) / block_size; + int64_t tail_len = (len - 1) % block_size; + int64_t slot_offset = + block_tables[block_offset + need_block_num] * block_size; + g_slots[idx++] = slot_offset + tail_len; + } + block_offset += max_block_num; + } + + // copy to device + ACL_CHECK(aclrtMemcpyAsync(g_slots_decoder->data(), + g_seqlens_decoder.ntokens * sizeof(int32_t), + g_slots, + g_seqlens_decoder.ntokens * sizeof(int32_t), + ACL_MEMCPY_HOST_TO_DEVICE, + reinterpret_cast(dev_ctx.stream()))); +} + +void FusedFaPaGlobalVar::update_casual_mask(const phi::CustomContext &dev_ctx, + uint64_t max_seq_len) { + if (!g_mask.get()) { + g_mask = std::make_shared(); + } + if (g_mask->numel() != max_seq_len * max_seq_len) { + LOG(INFO) << "update_mask: max_seq_len=" << max_seq_len; + g_mask->Resize({max_seq_len, max_seq_len}); + 
dev_ctx.template Alloc(g_mask.get()); + + phi::DenseTensor ones_tensor; + custom_kernel::FullKernel(dev_ctx, + {max_seq_len, max_seq_len}, + 1.0f, + phi::DataType::BFLOAT16, + &ones_tensor); + + phi::DenseTensor tril_ones_tensor; + tril_ones_tensor.Resize({max_seq_len, max_seq_len}); + custom_kernel::TrilKernel( + dev_ctx, ones_tensor, 0, &tril_ones_tensor); + + phi::DenseTensor tmp_mask; + tmp_mask.Resize({max_seq_len, max_seq_len}); + custom_kernel::ScaleKernel( + dev_ctx, tril_ones_tensor, -1.0f, 1.0f, true, g_mask.get()); + } +} + +void FusedFaPaGlobalVar::update_in_encoder(const phi::CustomContext &dev_ctx, + const paddle::Tensor &hidden) { + if (g_seqlens_encoder.ntokens == 0) { + return; + } + auto hidden_shape = hidden.shape(); + auto ntokens = g_seqlens_encoder.ntokens; + auto emb_dim = hidden_shape[1]; + // init + g_out_encoder.first = std::make_shared(); + g_out_encoder.first->Resize({ntokens, emb_dim}); + dev_ctx.template Alloc(g_out_encoder.first.get()); + + g_out_encoder.second = std::make_shared(); + g_out_encoder.second->Resize({ntokens, emb_dim}); + dev_ctx.template Alloc(g_out_encoder.second.get()); + + // udpate + void *in_data = const_cast(hidden.data()); + void *out_data = g_out_encoder.first->data(); + int32_t batch_size = g_seqlens_encoder.size; + int32_t *seqlens_encoder = + reinterpret_cast(g_seqlens_encoder.host_ptr); + int32_t *seqlens_decoder = + reinterpret_cast(g_seqlens_decoder.host_ptr); + + int64_t in_offset = 0, out_offset = 0, numel = 0; + for (auto i = 0; i < batch_size; ++i) { + if (seqlens_encoder[i] > 0) { + numel = seqlens_encoder[i] * emb_dim; + ACL_CHECK( + aclrtMemcpyAsync(out_data + out_offset * sizeof(phi::bfloat16), + numel * sizeof(phi::bfloat16), + in_data + in_offset * sizeof(phi::bfloat16), + numel * sizeof(phi::bfloat16), + ACL_MEMCPY_DEVICE_TO_DEVICE, + reinterpret_cast(dev_ctx.stream()))); + in_offset += numel; + out_offset += numel; + } else if (seqlens_decoder[i] > 0) { + in_offset += emb_dim; + } + } +} + +void FusedFaPaGlobalVar::update_in_decoder(const phi::CustomContext &dev_ctx, + const paddle::Tensor &hidden) { + if (g_seqlens_decoder.ntokens == 0) { + return; + } + auto hidden_shape = hidden.shape(); + auto ntokens = g_seqlens_decoder.ntokens; + auto emb_dim = hidden_shape[1]; + // init + g_out_decoder.first = std::make_shared(); + g_out_decoder.first->Resize({ntokens, emb_dim}); + dev_ctx.template Alloc(g_out_decoder.first.get()); + + g_out_decoder.second = std::make_shared(); + g_out_decoder.second->Resize({ntokens, emb_dim}); + dev_ctx.template Alloc(g_out_decoder.second.get()); + + // udpate + void *in_data = const_cast(hidden.data()); + void *out_data = g_out_decoder.first->data(); + int32_t batch_size = g_seqlens_decoder.size; + int32_t *seqlens_encoder = + reinterpret_cast(g_seqlens_encoder.host_ptr); + int32_t *seqlens_decoder = + reinterpret_cast(g_seqlens_decoder.host_ptr); + + int64_t in_offset = 0, out_offset = 0, numel = 0; + for (auto i = 0; i < batch_size; ++i) { + if (seqlens_encoder[i] > 0) { + in_offset += seqlens_encoder[i] * emb_dim; + } else if (seqlens_decoder[i] > 0) { + numel = emb_dim; + ACL_CHECK( + aclrtMemcpyAsync(out_data + out_offset * sizeof(phi::bfloat16), + numel * sizeof(phi::bfloat16), + in_data + in_offset * sizeof(phi::bfloat16), + numel * sizeof(phi::bfloat16), + ACL_MEMCPY_DEVICE_TO_DEVICE, + reinterpret_cast(dev_ctx.stream()))); + in_offset += emb_dim; + out_offset += emb_dim; + } + } +} + +void FusedFaPaGlobalVar::update_out_encoder(const phi::CustomContext &dev_ctx, + bool 
first_or_second, + paddle::Tensor *out) { + if (g_seqlens_encoder.ntokens == 0) { + return; + } + auto out_shape = out->shape(); + auto emb_dim = out_shape[1]; + auto ntokens = out_shape[0]; + + // udpate + void *in_data = first_or_second ? g_out_encoder.first->data() + : g_out_encoder.second->data(); + void *out_data = out->data(); + int32_t batch_size = g_seqlens_encoder.size; + int32_t *seqlens_encoder = + reinterpret_cast(g_seqlens_encoder.host_ptr); + int32_t *seqlens_decoder = + reinterpret_cast(g_seqlens_decoder.host_ptr); + + int64_t in_offset = 0, out_offset = 0; + for (auto i = 0; i < ntokens; ++i) { + out_offset = i * emb_dim; + ACL_CHECK( + aclrtMemcpyAsync(out_data + out_offset * sizeof(phi::bfloat16), + emb_dim * sizeof(phi::bfloat16), + in_data + out_offset * sizeof(phi::bfloat16), + emb_dim * sizeof(phi::bfloat16), + ACL_MEMCPY_DEVICE_TO_DEVICE, + reinterpret_cast(dev_ctx.stream()))); + } +} + +void FusedFaPaGlobalVar::update_out_decoder(const phi::CustomContext &dev_ctx, + bool first_or_second, + paddle::Tensor *out) { + if (g_seqlens_decoder.ntokens == 0) { + return; + } + auto out_shape = out->shape(); + auto emb_dim = out_shape[1]; + + // udpate + void *in_data = first_or_second ? g_out_decoder.first->data() + : g_out_decoder.second->data(); + void *out_data = out->data(); + int32_t batch_size = g_seqlens_encoder.size; + int32_t *seqlens_encoder = + reinterpret_cast(g_seqlens_encoder.host_ptr); + int32_t *seqlens_decoder = + reinterpret_cast(g_seqlens_decoder.host_ptr); + + int64_t in_offset = 0, out_offset = 0; + for (auto i = 0; i < batch_size; ++i) { + if (seqlens_decoder[i] > 0) { + out_offset = i * emb_dim; + ACL_CHECK( + aclrtMemcpyAsync(out_data + out_offset * sizeof(phi::bfloat16), + emb_dim * sizeof(phi::bfloat16), + in_data + in_offset * sizeof(phi::bfloat16), + emb_dim * sizeof(phi::bfloat16), + ACL_MEMCPY_DEVICE_TO_DEVICE, + reinterpret_cast(dev_ctx.stream()))); + in_offset += emb_dim; + } + } +} + +void FusedFaPaGlobalVar::update_qkv_deq_offset( + const phi::CustomContext &dev_ctx, int64_t sz) { + if (!g_qkv_deq_offset.get()) { + g_qkv_deq_offset = std::make_shared(); + g_qkv_deq_offset->Resize({sz}); + custom_kernel::FullKernel( + dev_ctx, {sz}, 0, phi::DataType::INT32, g_qkv_deq_offset.get()); + } +} + +void FusedFaPaGlobalVar::update_out_deq_offset( + const phi::CustomContext &dev_ctx, int64_t sz) { + if (!g_out_deq_offset.get()) { + g_out_deq_offset = std::make_shared(); + g_out_deq_offset->Resize({sz}); + custom_kernel::FullKernel( + dev_ctx, {sz}, 0, phi::DataType::INT32, g_out_deq_offset.get()); + } +} + +void FusedFaPaGlobalVar::update_ffn1_deq_offset( + const phi::CustomContext &dev_ctx, int64_t sz) { + if (!g_ffn1_deq_offset.get()) { + g_ffn1_deq_offset = std::make_shared(); + g_ffn1_deq_offset->Resize({sz}); + custom_kernel::FullKernel( + dev_ctx, {sz}, 0, phi::DataType::INT32, g_ffn1_deq_offset.get()); + } +} + +void FusedFaPaGlobalVar::update_ffn2_deq_offset( + const phi::CustomContext &dev_ctx, int64_t sz) { + if (!g_ffn2_deq_offset.get()) { + g_ffn2_deq_offset = std::make_shared(); + g_ffn2_deq_offset->Resize({sz}); + custom_kernel::FullKernel( + dev_ctx, {sz}, 0, phi::DataType::INT32, g_ffn2_deq_offset.get()); + } +} + +#endif diff --git a/backends/npu/custom_op/llama_infer/atb_ops/fused_fapa_attn_op_utils.h b/backends/npu/custom_op/llama_infer/atb_ops/fused_fapa_attn_op_utils.h new file mode 100644 index 00000000000..da2287157eb --- /dev/null +++ b/backends/npu/custom_op/llama_infer/atb_ops/fused_fapa_attn_op_utils.h @@ -0,0 +1,195 @@ +// 
Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#ifdef PADDLE_WITH_ATB + +#include +#include + +#include "atb_layers/fused_fapa_attention.h" +#include "atb_layers/runner.h" +#include "kernels/funcs/npu_funcs.h" +#include "kernels/funcs/npu_op_runner.h" +#include "paddle/extension.h" + +namespace fapa_layers { +void init_tensor(const phi::CustomContext &dev_ctx, + const phi::DataType &dtype, + const std::vector &shape, + paddle::Tensor *tensor); +} // namespace fapa_layers + +std::shared_ptr get_triu_mask( + const phi::CustomContext &dev_ctx, uint64_t max_seq_len); + +class FusedFaPaGlobalVar { + public: + struct SeqPack { + void *dev_ptr = nullptr; + void *host_ptr = nullptr; + uint64_t ntokens = 0; + uint64_t size = 0; + phi::DenseTensor dev_tensor; + }; + + struct RopePack { + std::shared_ptr rope_emb_cos; + std::shared_ptr rope_emb_sin; + }; + + struct OutPack { + std::shared_ptr first; + std::shared_ptr second; + }; + + struct Pack { + void *data = nullptr; + uint64_t size = 0; + }; + + static FusedFaPaGlobalVar &Instance(); + + SeqPack *get_seqlens_encoder() { return &g_seqlens_encoder; } + + SeqPack *get_seqlens_decoder() { return &g_seqlens_decoder; } + + Pack *get_block_tables() { return &g_block_tables; } + + Pack *get_batch_status() { return &g_batch_status; } + + RopePack *get_rope_encoder() { return &g_rope_emb_encoder; } + + RopePack *get_rope_decoder() { return &g_rope_emb_decoder; } + + void *get_slots_encoder() { return g_slots_encoder->data(); } + + void *get_slots_decoder() { return g_slots_decoder->data(); } + + void *get_casual_mask() { return g_mask->data(); } + + void *get_alibi_src_mask() { return alibi_src_mask; } + void *get_alibi_tgt_mask() { return alibi_tgt_mask; } + + OutPack *get_out_encoder() { return &g_out_encoder; } + OutPack *get_out_decoder() { return &g_out_decoder; } + + void *get_qkv_deq_offset() { return g_qkv_deq_offset->data(); } + + void *get_out_deq_offset() { return g_out_deq_offset->data(); } + + void *get_ffn1_deq_offset() { return g_ffn1_deq_offset->data(); } + + void *get_ffn2_deq_offset() { return g_ffn2_deq_offset->data(); } + + atb_layers::OperationRunner *get_encoder_runner(int64_t idx) { + return &g_encoder_runners[idx]; + } + + atb_layers::OperationRunner *get_decoder_runner(int64_t idx) { + return &g_decoder_runners[idx]; + } + + // async d2h + sync + async h2d + void update_seqlens_encoder(const phi::CustomContext &dev_ctx, + const paddle::Tensor &seqlen); + + // async d2h + sync + async h2d + void update_seqlens_decoder(const phi::CustomContext &dev_ctx, + const paddle::Tensor &seqlen); + + // async d2h + sync + void update_block_tables(const phi::CustomContext &dev_ctx, + const paddle::Tensor &block_tables); + + // async d2d + void update_rope_encoder(const phi::CustomContext &dev_ctx, + const paddle::Tensor &rope_emb, + int64_t max_seqlen, + int64_t head_dim); + + // async d2d + void update_rope_decoder(const phi::CustomContext &dev_ctx, + const 
paddle::Tensor &rope_emb, + int64_t max_seqlen, + int64_t head_dim); + + // async h2d + void update_slots_encoder(const phi::CustomContext &dev_ctx, + int64_t block_size, + int64_t max_block_num); + + // async h2d + void update_slots_decoder(const phi::CustomContext &dev_ctx, + int64_t block_size, + int64_t max_block_num); + + // async d2d + void update_casual_mask(const phi::CustomContext &dev_ctx, + uint64_t max_seq_len); + + void update_alibi_mask(void *src_mask, void *tgt_mask) { + alibi_src_mask = src_mask; + alibi_tgt_mask = tgt_mask; + } + + // async d2d + void update_in_encoder(const phi::CustomContext &dev_ctx, + const paddle::Tensor &hidden); + + // async d2d + void update_in_decoder(const phi::CustomContext &dev_ctx, + const paddle::Tensor &hidden); + + // async d2d + void update_out_encoder(const phi::CustomContext &dev_ctx, + bool, + paddle::Tensor *out); + + // async d2d + void update_out_decoder(const phi::CustomContext &dev_ctx, + bool, + paddle::Tensor *out); + + void update_qkv_deq_offset(const phi::CustomContext &dev_ctx, int64_t sz); + + void update_out_deq_offset(const phi::CustomContext &dev_ctx, int64_t sz); + + void update_ffn1_deq_offset(const phi::CustomContext &dev_ctx, int64_t sz); + + void update_ffn2_deq_offset(const phi::CustomContext &dev_ctx, int64_t sz); + + private: + SeqPack g_seqlens_encoder; + SeqPack g_seqlens_decoder; + RopePack g_rope_emb_encoder; + RopePack g_rope_emb_decoder; + Pack g_batch_status; + Pack g_block_tables; + std::shared_ptr g_slots_encoder{nullptr}; + std::shared_ptr g_slots_decoder{nullptr}; + std::shared_ptr g_mask{nullptr}; + OutPack g_out_encoder; + OutPack g_out_decoder; + std::shared_ptr g_qkv_deq_offset{nullptr}; + std::shared_ptr g_out_deq_offset{nullptr}; + std::shared_ptr g_ffn1_deq_offset{nullptr}; + std::shared_ptr g_ffn2_deq_offset{nullptr}; + std::unordered_map g_encoder_runners; + std::unordered_map g_decoder_runners; + void *alibi_src_mask; + void *alibi_tgt_mask; +}; + +#endif diff --git a/backends/npu/custom_op/llama_infer/atb_ops/fused_rms_norm_op.cc b/backends/npu/custom_op/llama_infer/atb_ops/fused_rms_norm_op.cc new file mode 100644 index 00000000000..66ce5b0528b --- /dev/null +++ b/backends/npu/custom_op/llama_infer/atb_ops/fused_rms_norm_op.cc @@ -0,0 +1,82 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
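+
+// Registers the `atb_rms_norm` custom op: a thin wrapper around the
+// atb_layers::RmsNormParam graph (optional residual add followed by
+// RmsNorm), driven by a single process-wide OperationRunner.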
+
+#ifdef PADDLE_WITH_ATB
+
+#include "atb_layers/fused_rms_norm.h"
+#include "fused_blha_layer_op_utils.h"  // NOLINT
+
+std::vector<std::vector<int64_t>> RmsNormShape(
+    const std::vector<int64_t>& x,
+    const std::vector<int64_t>& norm_weight,
+    const paddle::optional<std::vector<int64_t>>& residual,
+    float epsilon) {
+  std::vector<int64_t> out_dims = x;
+
+  return {out_dims};
+}
+
+static atb_layers::OperationRunner g_RmsNormRunner;
+
+std::vector<paddle::Tensor> RmsNormOp(
+    const paddle::Tensor& x,
+    const paddle::Tensor& norm_weight,
+    const paddle::optional<paddle::Tensor>& residual,
+    float epsilon) {
+  auto place = x.place();
+  const auto& dev_ctx = *static_cast<const phi::CustomContext*>(
+      paddle::experimental::DeviceContextPool::Instance().Get(place));
+
+  auto out_dtype = x.dtype();
+  std::vector<int64_t> out_shape = x.shape();
+
+  paddle::Tensor out(place);
+  init_tensor(dev_ctx, out_dtype, out_shape, &out);
+  paddle::Tensor residual_out(place);
+  init_tensor(dev_ctx, out_dtype, out_shape, &residual_out);
+
+  if (g_RmsNormRunner.is_initialized()) {
+    g_RmsNormRunner.reset_variant_pack();
+  }
+
+  atb_layers::RmsNormParam param;
+  param.epsilon = epsilon;
+  param.has_residual = residual.is_initialized();
+  g_RmsNormRunner.create(param);
+
+  g_RmsNormRunner.bind_input(x);
+  g_RmsNormRunner.bind_input(norm_weight);
+  if (residual.is_initialized()) {
+    g_RmsNormRunner.bind_input(residual.get());
+  }
+
+  g_RmsNormRunner.bind_output(&out);
+  if (residual.is_initialized()) {
+    g_RmsNormRunner.bind_output(&residual_out);
+  }
+  g_RmsNormRunner.run(dev_ctx);
+  return {out, residual_out};
+}
+
+PD_BUILD_OP(atb_rms_norm)  // atb_flash_attention rms_norm
+    .Inputs({"x", "norm_weight", "residual@OPTIONAL"})  // tensor
+    .Outputs({"out", "residual_out"})                   // tensor
+    .Attrs({
+        "epsilon: float",  // int/float/bool
+    })
+    .SetKernelFn(PD_KERNEL(RmsNormOp))               // kernel implementation
+    .SetInferShapeFn(PD_INFER_SHAPE(RmsNormShape));  // shape check
+// .SetInferDtypeFn(PD_INFER_DTYPE(RmsNormDType));   // dtype check
+
+#endif
diff --git a/backends/npu/custom_op/llama_infer/fused_step_op.cc b/backends/npu/custom_op/llama_infer/fused_step_op.cc
new file mode 100644
index 00000000000..b3f41c71862
--- /dev/null
+++ b/backends/npu/custom_op/llama_infer/fused_step_op.cc
@@ -0,0 +1,685 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
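+
+// Host-side block scheduling for the paged KV cache:
+//  - free_block reclaims the decoder blocks of finished sequences and
+//    records which running sequences need a new block;
+//  - dispatch_block preempts the sequences holding the most decoder blocks
+//    when free blocks run short, then hands one block to each sequence that
+//    requested one;
+//  - recover_block re-activates previously preempted sequences once enough
+//    free blocks are available (one sequence per step on NPU, to keep the
+//    encoder warmup valid).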
+ +#include + +// #include "atb_op_runner/deprecated/atb_op_runner.h" +// #include "atb_op_runner/deprecated/set_value_by_scalar.h" +#include "glog/logging.h" +#include "paddle/extension.h" +#include "runtime/runtime.h" +// FLAGS_DEFINE_bool(npu_step_paddle_debug, true, ""); + +bool in_need_block_list(int seq_id, + int32_t* need_block_list, + int32_t need_block_len) { + bool res = false; + for (int i = 0; i < need_block_len; ++i) { + if (seq_id == need_block_list[i]) { + res = true; + break; + } + } + return res; +} + +void free_block(const aclrtStream& stream, + uint8_t* stop_flags_cpu, + uint8_t* is_block_step_cpu, + int32_t* seq_lens_this_time_cpu, + int32_t* seq_lens_encoder_cpu, + int32_t* seq_lens_decoder_cpu, + int32_t* block_tables_cpu, + int32_t* block_tables_npu, + + int32_t* encoder_block_lens, + int32_t* used_list_len, + int32_t* free_list, + int32_t* free_list_len, + int32_t* need_block_list, + int32_t* need_block_list_len, + int64_t* first_token_ids_cpu, + int64_t max_batch_size, + int64_t block_num_per_seq, + int64_t block_size, + bool step_max_block_flag, + int32_t in_need_block_list_len) { + LOG(INFO) << "<><><><><>beging free_block"; + for (int i = 0; i < max_batch_size; ++i) { + if (i == 0) { + step_max_block_flag = false; + in_need_block_list_len = 0; + } + bool stop_flag = stop_flags_cpu[i]; + bool is_block_step = is_block_step_cpu[i]; + + if (stop_flags_cpu[i] && !is_block_step_cpu[i]) { + int encoder_block_len = encoder_block_lens[i]; // FIXME: 这里 + int decoder_used_len = used_list_len[i]; + if (decoder_used_len > 0) { + const int ori_free_list_len = free_list_len[0]; + free_list_len[0] += decoder_used_len; + // if (FLAGS_npu_step_paddle_debug) { + VLOG(0) << "free block seq_id: " << i + << ", free block num: " << decoder_used_len + << ", encoder_block_len: " << encoder_block_len + << ", ori_free_list_len: " << ori_free_list_len; + // } + for (int j = 0; j < decoder_used_len; ++j) { + free_list[ori_free_list_len + j] = + block_tables_cpu[i * block_num_per_seq + encoder_block_len + j]; + block_tables_cpu[i * block_num_per_seq + encoder_block_len + j] = -1; + } + encoder_block_lens[i] = 0; + used_list_len[i] = 0; + ACL_CHECK(aclrtMemcpyAsync( + &block_tables_npu[i * block_num_per_seq + encoder_block_len], + decoder_used_len * sizeof(int32_t), + &block_tables_cpu[i * block_num_per_seq + encoder_block_len], + decoder_used_len * sizeof(int32_t), + ACL_MEMCPY_HOST_TO_DEVICE, + stream)); + } + } else if (seq_lens_decoder_cpu[i] != 0 && + block_tables_cpu[i * block_num_per_seq + + (seq_lens_decoder_cpu[i] + 1) / block_size] == + -1) { + need_block_list[need_block_list_len[0]++] = i; + // if (FLAGS_npu_step_paddle_debug) { + VLOG(0) << "seq_id: " << i << " need block"; + // } + } + } +} + +void dispatch_block(const aclrtStream& stream, + uint8_t* stop_flags_cpu, + uint8_t* stop_flags_npu, + uint8_t* is_block_step_cpu, + uint8_t* is_block_step_npu, + int32_t* seq_lens_this_time_cpu, + int32_t* seq_lens_this_time_npu, + int32_t* seq_lens_encoder_cpu, + int32_t* seq_lens_encoder_npu, + int32_t* seq_lens_decoder_cpu, + int32_t* seq_lens_decoder_npu, + int32_t* block_tables_cpu, + int32_t* block_tables_npu, + + int32_t* encoder_block_lens, + int32_t* used_list_len, + int32_t* free_list, + int32_t* free_list_len, + int32_t* need_block_list, + int32_t* need_block_list_len, + int32_t* step_block_list, + int32_t* step_block_list_len, + int64_t* first_token_ids, + int64_t max_batch_size, + int64_t block_num_per_seq, + int64_t block_size, + bool step_max_block_flag, + int32_t 
in_need_block_list_len, + int64_t max_decoder_block_num) { + // if (FLAGS_npu_step_paddle_debug) { + VLOG(0) << "need_block_list_len: " << need_block_list_len[0] + << " free_list_len: " << free_list_len[0]; + // } + + // 调度block,根据used_list_len从大到小回收block,直到满足need_block_len, + // 已解码到最后一个block的query不发生调度 + while (need_block_list_len[0] > free_list_len[0]) { + int seq_id = 0; + int decoder_used_len = 0; + for (int i = 0; i < max_batch_size; ++i) { + int used_block_num = + (!is_block_step_cpu[i] && + (step_max_block_flag || used_list_len[i] != max_decoder_block_num)) + ? used_list_len[i] + : 0; + if (used_block_num > decoder_used_len) { + seq_id = i; + decoder_used_len = used_block_num; + } + } + if (decoder_used_len == 0) { + step_max_block_flag = true; + } else { + int encoder_block_len = encoder_block_lens[seq_id]; + + // if (FLAGS_npu_step_paddle_debug) { + VLOG(0) << "max_id: " << seq_id << ", max_num: " << decoder_used_len + << ", encoder_block_len: " << encoder_block_len; + // } + // decoder_used_len > 0 + for (int i = 0; i < decoder_used_len; ++i) { + free_list[free_list_len[0] + i] = + block_tables_cpu[seq_id * block_num_per_seq + encoder_block_len + + i]; + block_tables_cpu[seq_id * block_num_per_seq + encoder_block_len + i] = + -1; + } + ACL_CHECK(aclrtMemcpyAsync( + &block_tables_npu[seq_id * block_num_per_seq + encoder_block_len], + decoder_used_len * sizeof(int32_t), + &block_tables_cpu[seq_id * block_num_per_seq + encoder_block_len], + decoder_used_len * sizeof(int32_t), + ACL_MEMCPY_HOST_TO_DEVICE, + stream)); + step_block_list[step_block_list_len[0]] = seq_id; + if (in_need_block_list(seq_id, + need_block_list, + need_block_list_len[0] + in_need_block_list_len)) { + need_block_list_len[0] -= 1; + in_need_block_list_len += 1; + need_block_list[seq_id] = -1; + } + step_block_list_len[0] += 1; + free_list_len[0] += decoder_used_len; + stop_flags_cpu[seq_id] = true; + is_block_step_cpu[seq_id] = true; + seq_lens_this_time_cpu[seq_id] = 0; + seq_lens_decoder_cpu[seq_id] = 0; + seq_lens_encoder_cpu[seq_id] = 0; + ACL_CHECK(aclrtMemcpyAsync(&stop_flags_npu[seq_id], + 1 * sizeof(uint8_t), + &stop_flags_cpu[seq_id], + 1 * sizeof(uint8_t), + ACL_MEMCPY_HOST_TO_DEVICE, + stream)); + ACL_CHECK(aclrtMemcpyAsync(&is_block_step_npu[seq_id], + 1 * sizeof(uint8_t), + &is_block_step_cpu[seq_id], + 1 * sizeof(uint8_t), + ACL_MEMCPY_HOST_TO_DEVICE, + stream)); + ACL_CHECK(aclrtMemcpyAsync(&seq_lens_this_time_npu[seq_id], + 1 * sizeof(int32_t), + &seq_lens_this_time_cpu[seq_id], + 1 * sizeof(int32_t), + ACL_MEMCPY_HOST_TO_DEVICE, + stream)); + ACL_CHECK(aclrtMemcpyAsync(&seq_lens_decoder_npu[seq_id], + 1 * sizeof(int32_t), + &seq_lens_decoder_cpu[seq_id], + 1 * sizeof(int32_t), + ACL_MEMCPY_HOST_TO_DEVICE, + stream)); + ACL_CHECK(aclrtMemcpyAsync(&seq_lens_encoder_npu[seq_id], + 1 * sizeof(int32_t), + &seq_lens_encoder_cpu[seq_id], + 1 * sizeof(int32_t), + ACL_MEMCPY_HOST_TO_DEVICE, + stream)); + } + } + // 为需要block的位置分配block,每个位置分配一个block + for (auto i = 0; i < (need_block_list_len[0] + in_need_block_list_len); ++i) { + if (need_block_list[i] != -1) { + auto seq_id = need_block_list[i]; + if (!stop_flags_cpu[seq_id]) { + used_list_len[seq_id] += 1; + auto ori_free_list_len = free_list_len[0]; + free_list_len[0]--; + auto block_offset = (seq_lens_decoder_cpu[seq_id] + 1) / block_size; + block_tables_cpu[seq_id * block_num_per_seq + block_offset] = + free_list[ori_free_list_len - 1]; + ACL_CHECK(aclrtMemcpyAsync( + &block_tables_npu[seq_id * block_num_per_seq + block_offset], + 1 * 
sizeof(int32_t), + &block_tables_cpu[seq_id * block_num_per_seq + block_offset], + 1 * sizeof(int32_t), + ACL_MEMCPY_HOST_TO_DEVICE, + stream)); + } + need_block_list[i] = -1; + } + } + need_block_list_len[0] = 0; +} + +void recover_block(const aclrtStream& stream, + uint8_t* stop_flags_cpu, + uint8_t* stop_flags_npu, + uint8_t* is_block_step_cpu, + uint8_t* is_block_step_npu, + int32_t* seq_lens_this_time_cpu, + int32_t* seq_lens_this_time_npu, + int32_t* seq_lens_encoder_cpu, + int32_t* seq_lens_encoder_npu, + int32_t* seq_lens_decoder_cpu, + int32_t* seq_lens_decoder_npu, + int32_t* block_tables_cpu, + int32_t* block_tables_npu, + int64_t* input_ids_npu, + int64_t* pre_ids_npu, + int64_t* step_idx_cpu, + int64_t* next_tokens_npu, + int64_t* first_tokens_ids_cpu, + int32_t* recover_block_list, + int32_t* recover_len, + int32_t* encoder_block_lens, + int32_t* used_list_len, + int32_t* free_list, + int32_t* free_list_len, + int32_t* need_block_list, + int32_t* need_block_list_len, + int32_t* step_block_list, + int32_t* step_block_list_len, + int32_t* ori_seq_lens_encoder, + + int64_t max_batch_size, + int64_t max_seq_len, + int64_t pre_id_length, + int64_t block_num_per_seq, + int64_t block_size, + int64_t max_block_num) { + // 计算可以复原的query id + int ori_step_len = step_block_list_len[0]; + if (ori_step_len > 0) { + int ori_free_list_len = free_list_len[0]; + int ori_step_block_id = step_block_list[ori_step_len - 1]; + int tmp_used_len = used_list_len[ori_step_block_id]; + const int max_decoder_block_num_this_seq = + max_block_num - encoder_block_lens[ori_step_block_id]; + // 比之前调度时多分配一个block,防止马上恢复刚调度的query(比如回收的seq_id在need_block_list中) + int used_len = tmp_used_len + 1 < max_decoder_block_num_this_seq + ? tmp_used_len + 1 + : max_decoder_block_num_this_seq; + if (ori_step_len > 0 && ori_free_list_len >= used_len) { + // NPU 一次只复原一条数据,否则encoder warmup会失效 + // if (FLAGS_npu_step_paddle_debug) { + VLOG(0) << "recover seq_id:" << ori_step_block_id + << " , free_list_len: " << ori_free_list_len + << ", used_list_len: " << used_len; + // } + recover_block_list[recover_len[0]] = ori_step_block_id; + is_block_step_cpu[ori_step_block_id] = false; + ACL_CHECK(aclrtMemcpyAsync(&is_block_step_npu[ori_step_block_id], + 1 * sizeof(uint8_t), + &is_block_step_cpu[ori_step_block_id], + 1 * sizeof(uint8_t), + ACL_MEMCPY_HOST_TO_DEVICE, + stream)); + used_list_len[ori_step_block_id] = used_len; + ori_free_list_len -= used_len; + step_block_list[ori_step_len - 1] = -1; + step_block_list_len[0] -= 1; + recover_len[0] += 1; + ori_step_len = step_block_list_len[0]; + if (ori_step_len > 0) { + ori_step_block_id = step_block_list[ori_step_len - 1]; + tmp_used_len = used_list_len[ori_step_block_id]; + used_len = tmp_used_len + 1 < max_decoder_block_num_this_seq + ? 
tmp_used_len + 1 + : max_decoder_block_num_this_seq; + } + } + } + // if (recover_len[0] > 0 && FLAGS_npu_step_paddle_debug) { + VLOG(0) << "recover_len: " << recover_len[0]; + // } + for (int i = 0; i < recover_len[0]; ++i) { + auto seq_id = recover_block_list[i]; + auto ori_seq_len_encoder = ori_seq_lens_encoder[seq_id]; + auto step_id = step_idx_cpu[seq_id]; + auto seq_len = ori_seq_len_encoder + step_id; + auto encoder_block_len = encoder_block_lens[seq_id]; + auto decoder_used_len = used_list_len[seq_id]; + + seq_lens_this_time_cpu[seq_id] = seq_len; + seq_lens_encoder_cpu[seq_id] = seq_len; + stop_flags_cpu[seq_id] = false; + ACL_CHECK(aclrtMemcpyAsync(&seq_lens_this_time_npu[seq_id], + 1 * sizeof(int32_t), + &seq_lens_this_time_cpu[seq_id], + 1 * sizeof(int32_t), + ACL_MEMCPY_HOST_TO_DEVICE, + stream)); + ACL_CHECK(aclrtMemcpyAsync(&seq_lens_encoder_npu[seq_id], + 1 * sizeof(int32_t), + &seq_lens_encoder_cpu[seq_id], + 1 * sizeof(int32_t), + ACL_MEMCPY_HOST_TO_DEVICE, + stream)); + ACL_CHECK(aclrtMemcpyAsync(&stop_flags_npu[seq_id], + 1 * sizeof(int32_t), + &stop_flags_cpu[seq_id], + 1 * sizeof(int32_t), + ACL_MEMCPY_HOST_TO_DEVICE, + stream)); + auto ori_free_list_len = free_list_len[0]; + free_list_len[0] -= used_list_len[seq_id]; + // if (FLAGS_npu_step_paddle_debug) { + VLOG(0) << "seq_id: " << seq_id + << ", ori_seq_len_encoder: " << ori_seq_len_encoder + << ", step_idx_now: " << step_id << ", seq_len: " << seq_len + << ", decoder_used_len: " << decoder_used_len + << ", ori_free_list_len_tid0: " << ori_free_list_len + << ", free_list_len: " << free_list_len[0]; + // } + std::memcpy( + &block_tables_cpu[seq_id * block_num_per_seq + encoder_block_len], + &free_list[ori_free_list_len - decoder_used_len], + decoder_used_len * sizeof(int32_t)); + ACL_CHECK(aclrtMemcpyAsync( + &block_tables_npu[seq_id * block_num_per_seq + encoder_block_len], + decoder_used_len * sizeof(int32_t), + &block_tables_cpu[seq_id * block_num_per_seq + encoder_block_len], + decoder_used_len * sizeof(int32_t), + ACL_MEMCPY_HOST_TO_DEVICE, + stream)); + ACL_CHECK( + aclrtMemcpyAsync(&input_ids_npu[seq_id * max_seq_len + seq_len - 1], + 1 * sizeof(int64_t), + &next_tokens_npu[seq_id], + 1 * sizeof(int64_t), + ACL_MEMCPY_DEVICE_TO_DEVICE, + stream)); + ACL_CHECK(aclrtMemcpyAsync(&input_ids_npu[seq_id * max_seq_len], + 1 * sizeof(int64_t), + &first_tokens_ids_cpu[seq_id], + 1 * sizeof(int64_t), + ACL_MEMCPY_HOST_TO_DEVICE, + stream)); + ACL_CHECK(aclrtMemcpyAsync( + &input_ids_npu[seq_id * max_seq_len + ori_seq_len_encoder], + (step_id - 1) * sizeof(int64_t), + &pre_ids_npu[seq_id * pre_id_length + 1], + (step_id - 1) * sizeof(int64_t), + ACL_MEMCPY_DEVICE_TO_DEVICE, + stream)); + } + recover_len[0] = 0; +} + +static uint8_t* g_stop_flags_cpu = nullptr; +static uint8_t* g_is_block_step_cpu = nullptr; +static int32_t* g_seq_lens_this_time_cpu = nullptr; +static int32_t* g_seq_lens_encoder_cpu = nullptr; +static int32_t* g_seq_lens_decoder_cpu = nullptr; +static int32_t* g_block_tables_cpu = nullptr; +static int64_t* g_step_idx_cpu = nullptr; + +void AtbStepPaddle( + const paddle::Tensor& stop_flags, // [mbs, 1] + const paddle::Tensor& seq_lens_this_time, // [mbs, 1] + const paddle::Tensor& ori_seq_lens_encoder, // cpu + const paddle::Tensor& seq_lens_encoder, // [mbs, 1] + const paddle::Tensor& seq_lens_decoder, // [mbs, 1] + const paddle::Tensor& block_tables, // [bsz, block_num_per_seq] + const paddle::Tensor& encoder_block_lens, // cpu + const paddle::Tensor& is_block_step, // [mbs, 1] + const paddle::Tensor& 
step_block_list, // cpu + const paddle::Tensor& step_lens, // cpu + const paddle::Tensor& recover_block_list, // cpu + const paddle::Tensor& recover_lens, // cpu + const paddle::Tensor& need_block_list, // cpu + const paddle::Tensor& need_block_list_len, // cpu + const paddle::Tensor& used_list_len, // cpu + const paddle::Tensor& free_list, // cpu + const paddle::Tensor& free_list_len, // cpu + const paddle::Tensor& input_ids, + const paddle::Tensor& pre_ids, + const paddle::Tensor& step_idx, // [mbs, 1] + const paddle::Tensor& next_tokens, // [mbs, 1] + const paddle::Tensor& first_token_ids, // cpu, [mbs, 1] + const int block_size, + const int encoder_decoder_block_num) { + auto start_point = std::chrono::high_resolution_clock::now(); + + auto place = input_ids.place(); + auto dev_ctx = static_cast( + paddle::experimental::DeviceContextPool::Instance().Get(place)); + aclrtStream stream = reinterpret_cast(dev_ctx->stream()); + + const int max_batch_size = block_tables.shape()[0]; + const int block_num_per_seq = block_tables.shape()[1]; + const int max_seq_len = input_ids.shape()[1]; + const int pre_id_length = pre_ids.shape()[1]; + const int max_decoder_block_num = pre_id_length / block_size; + const int max_block_num = max_seq_len / block_size; + + if (!g_stop_flags_cpu) { + ACL_CHECK(aclrtMallocHost(reinterpret_cast(&g_stop_flags_cpu), + max_batch_size * sizeof(uint8_t))); + ACL_CHECK(aclrtMallocHost(reinterpret_cast(&g_is_block_step_cpu), + max_batch_size * sizeof(uint8_t))); + ACL_CHECK( + aclrtMallocHost(reinterpret_cast(&g_seq_lens_this_time_cpu), + max_batch_size * sizeof(int32_t))); + ACL_CHECK(aclrtMallocHost(reinterpret_cast(&g_seq_lens_encoder_cpu), + max_batch_size * sizeof(int32_t))); + ACL_CHECK(aclrtMallocHost(reinterpret_cast(&g_seq_lens_decoder_cpu), + max_batch_size * sizeof(int32_t))); + ACL_CHECK( + aclrtMallocHost(reinterpret_cast(&g_block_tables_cpu), + max_batch_size * block_num_per_seq * sizeof(int32_t))); + ACL_CHECK(aclrtMallocHost(reinterpret_cast(&g_step_idx_cpu), + max_batch_size * sizeof(int64_t))); + } + + ACL_CHECK(aclrtMemcpyAsync(g_stop_flags_cpu, + max_batch_size * sizeof(uint8_t), + stop_flags.data(), + max_batch_size * sizeof(uint8_t), + ACL_MEMCPY_DEVICE_TO_HOST, + stream)); + ACL_CHECK(aclrtMemcpyAsync(g_is_block_step_cpu, + max_batch_size * sizeof(uint8_t), + is_block_step.data(), + max_batch_size * sizeof(uint8_t), + ACL_MEMCPY_DEVICE_TO_HOST, + stream)); + ACL_CHECK(aclrtMemcpyAsync(g_seq_lens_this_time_cpu, + max_batch_size * sizeof(int32_t), + seq_lens_this_time.data(), + max_batch_size * sizeof(int32_t), + ACL_MEMCPY_DEVICE_TO_HOST, + stream)); + ACL_CHECK(aclrtMemcpyAsync(g_seq_lens_encoder_cpu, + max_batch_size * sizeof(int32_t), + seq_lens_encoder.data(), + max_batch_size * sizeof(int32_t), + ACL_MEMCPY_DEVICE_TO_HOST, + stream)); + ACL_CHECK(aclrtMemcpyAsync(g_seq_lens_decoder_cpu, + max_batch_size * sizeof(int32_t), + seq_lens_decoder.data(), + max_batch_size * sizeof(int32_t), + ACL_MEMCPY_DEVICE_TO_HOST, + stream)); + ACL_CHECK( + aclrtMemcpyAsync(g_block_tables_cpu, + max_batch_size * block_num_per_seq * sizeof(int32_t), + block_tables.data(), + max_batch_size * block_num_per_seq * sizeof(int32_t), + ACL_MEMCPY_DEVICE_TO_HOST, + stream)); + dev_ctx->Wait(); + + bool step_max_block_flag = false; + int32_t in_need_block_list_len = 0; + + LOG(INFO) << "<><><><><>first_token_ids.data()" << first_token_ids.data(); + LOG(INFO) << "<><><><><>reinterpret_cast(const_cast(first_" + "token_ids.data()))<><><><><>" + << reinterpret_cast( + 
const_cast(first_token_ids.data())); + + LOG(INFO) << "<><><><><>before free block"; + free_block( + stream, + g_stop_flags_cpu, + g_is_block_step_cpu, + g_seq_lens_this_time_cpu, + g_seq_lens_encoder_cpu, + g_seq_lens_decoder_cpu, + g_block_tables_cpu, + reinterpret_cast(const_cast(block_tables.data())), + reinterpret_cast(const_cast(encoder_block_lens.data())), + reinterpret_cast(const_cast(used_list_len.data())), + reinterpret_cast(const_cast(free_list.data())), + reinterpret_cast(const_cast(free_list_len.data())), + reinterpret_cast(const_cast(need_block_list.data())), + reinterpret_cast(const_cast(need_block_list_len.data())), + reinterpret_cast(const_cast(first_token_ids.data())), + max_batch_size, + block_num_per_seq, + block_size, + step_max_block_flag, + in_need_block_list_len); + LOG(INFO) << "<><><><><>after free block"; + + dispatch_block( + stream, + g_stop_flags_cpu, + reinterpret_cast(const_cast(stop_flags.data())), + g_is_block_step_cpu, + reinterpret_cast(const_cast(is_block_step.data())), + g_seq_lens_this_time_cpu, + reinterpret_cast(const_cast(seq_lens_this_time.data())), + g_seq_lens_encoder_cpu, + reinterpret_cast(const_cast(seq_lens_encoder.data())), + g_seq_lens_decoder_cpu, + reinterpret_cast(const_cast(seq_lens_decoder.data())), + g_block_tables_cpu, + reinterpret_cast(const_cast(block_tables.data())), + reinterpret_cast(const_cast(encoder_block_lens.data())), + reinterpret_cast(const_cast(used_list_len.data())), + reinterpret_cast(const_cast(free_list.data())), + reinterpret_cast(const_cast(free_list_len.data())), + reinterpret_cast(const_cast(need_block_list.data())), + reinterpret_cast(const_cast(need_block_list_len.data())), + reinterpret_cast(const_cast(step_block_list.data())), + reinterpret_cast(const_cast(step_lens.data())), + reinterpret_cast(const_cast(first_token_ids.data())), + max_batch_size, + block_num_per_seq, + block_size, + step_max_block_flag, + in_need_block_list_len, + max_decoder_block_num); + + ACL_CHECK(aclrtMemcpyAsync(g_step_idx_cpu, + max_batch_size * sizeof(int64_t), + step_idx.data(), + max_batch_size * sizeof(int64_t), + ACL_MEMCPY_DEVICE_TO_HOST, + stream)); + dev_ctx->Wait(); + recover_block( + stream, + g_stop_flags_cpu, + reinterpret_cast(const_cast(stop_flags.data())), + g_is_block_step_cpu, + reinterpret_cast(const_cast(is_block_step.data())), + g_seq_lens_this_time_cpu, + reinterpret_cast(const_cast(seq_lens_this_time.data())), + g_seq_lens_encoder_cpu, + reinterpret_cast(const_cast(seq_lens_encoder.data())), + g_seq_lens_decoder_cpu, + reinterpret_cast(const_cast(seq_lens_decoder.data())), + g_block_tables_cpu, + reinterpret_cast(const_cast(block_tables.data())), + reinterpret_cast(const_cast(input_ids.data())), + reinterpret_cast(const_cast(pre_ids.data())), + g_step_idx_cpu, + reinterpret_cast(const_cast(next_tokens.data())), + reinterpret_cast(const_cast(first_token_ids.data())), + reinterpret_cast(const_cast(recover_block_list.data())), + reinterpret_cast(const_cast(recover_lens.data())), + reinterpret_cast(const_cast(encoder_block_lens.data())), + reinterpret_cast(const_cast(used_list_len.data())), + reinterpret_cast(const_cast(free_list.data())), + reinterpret_cast(const_cast(free_list_len.data())), + reinterpret_cast(const_cast(need_block_list.data())), + reinterpret_cast(const_cast(need_block_list_len.data())), + reinterpret_cast(const_cast(step_block_list.data())), + reinterpret_cast(const_cast(step_lens.data())), + reinterpret_cast( + const_cast(ori_seq_lens_encoder.data())), + max_batch_size, + max_seq_len, + 
pre_id_length, + block_num_per_seq, + block_size, + max_block_num); + + dev_ctx->Wait(); + + auto end_point = std::chrono::high_resolution_clock::now(); +} + +PD_BUILD_OP(step_paddle_op) + .Inputs({"stop_flags", + "seq_lens_this_time", + "ori_seq_lens_encoder", + "seq_lens_encoder", + "seq_lens_decoder", + "block_tables", + "encoder_block_lens", + "is_block_step", + "step_block_list", + "step_lens", + "recover_block_list", + "recover_lens", + "need_block_list", + "need_block_len", + "used_list_len", + "free_list", + "free_list_len", + "input_ids", + "pre_ids", + "step_idx", + "next_tokens", + "first_tokens_ids"}) + .Attrs({"block_size: int", "encoder_decoder_block_num: int"}) + .Outputs({"stop_flags_out", + "seq_lens_this_time_out", + "seq_lens_encoder_out", + "seq_lens_decoder_out", + "block_tables_out", + "encoder_block_lens_out", + "is_block_step_out", + "step_block_list_out", + "step_lens_out", + "recover_block_list_out", + "recover_lens_out", + "need_block_list_out", + "need_block_len_out", + "used_list_len_out", + "free_list_out", + "free_list_len_out", + "input_ids_out", + "first_token_ids_out"}) + .SetInplaceMap({{"stop_flags", "stop_flags_out"}, // npu + {"seq_lens_this_time", "seq_lens_this_time_out"}, // npu + {"seq_lens_encoder", "seq_lens_encoder_out"}, // npu + {"seq_lens_decoder", "seq_lens_decoder_out"}, // npu + {"block_tables", "block_tables_out"}, // npu + {"encoder_block_lens", "encoder_block_lens_out"}, // cpu + {"is_block_step", "is_block_step_out"}, // npu + {"step_block_list", "step_block_list_out"}, + {"step_lens", "step_lens_out"}, + {"recover_block_list", "recover_block_list_out"}, + {"recover_lens", "recover_lens_out"}, + {"need_block_list", "need_block_list_out"}, + {"need_block_len", "need_block_len_out"}, + {"used_list_len", "used_list_len_out"}, + {"free_list", "free_list_out"}, + {"free_list_len", "free_list_len_out"}, + {"input_ids", "input_ids_out"}, + {"first_tokens_ids", "first_token_ids_out"}}) + .SetKernelFn(PD_KERNEL(AtbStepPaddle)); diff --git a/backends/npu/custom_op/llama_infer/fused_weight_only_linear.cc b/backends/npu/custom_op/llama_infer/fused_weight_only_linear.cc new file mode 100644 index 00000000000..fbddbbc853d --- /dev/null +++ b/backends/npu/custom_op/llama_infer/fused_weight_only_linear.cc @@ -0,0 +1,82 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
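For orientation, the step_paddle_op kernel registered above drives three host-side passes over the CPU copies of the scheduling state: free_block reclaims blocks from finished sequences, dispatch_block preempts the sequence holding the most decoder blocks when demand exceeds the free list, and recover_block restores one preempted sequence per call. All three share the same CPU-side free-list bookkeeping; a standalone sketch of the reclaim step is shown below, with hypothetical names and the host-to-device copies omitted.

// Standalone sketch (not part of the diff) of the per-sequence reclaim that
// free_block performs: decoder blocks after the encoder prefix go back to the
// free list and the corresponding block-table slots are cleared to -1.
#include <cstdint>
#include <vector>

void reclaim_decoder_blocks(std::vector<int32_t>* block_table_row,  // one sequence's row
                            int32_t encoder_block_len,
                            int32_t* used_list_len,  // decoder blocks currently held
                            std::vector<int32_t>* free_list) {
  for (int32_t j = 0; j < *used_list_len; ++j) {
    free_list->push_back((*block_table_row)[encoder_block_len + j]);
    (*block_table_row)[encoder_block_len + j] = -1;
  }
  *used_list_len = 0;  // the sequence no longer owns any decoder blocks
}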
+
+#include <memory>
+#include <vector>
+
+#include "kernels/funcs/npu_op_runner.h"
+#include "paddle/extension.h"
+
+std::vector<std::vector<int64_t>> WeightOnlyLinearInferShape(
+    const std::vector<int64_t>& x_shape,
+    const std::vector<int64_t>& weight_shape,
+    const std::vector<int64_t>& scale_shape) {
+  std::vector<int64_t> output_shape;
+  output_shape.push_back(x_shape[0]);
+  output_shape.push_back(weight_shape[1]);
+  return {output_shape};
+}
+
+std::vector<paddle::Tensor> weight_only_linear_npu(
+    const paddle::Tensor& x,
+    const paddle::Tensor& weight,
+    const paddle::Tensor& scale) {
+  auto dev_ctx = static_cast<const phi::CustomContext*>(
+      paddle::experimental::DeviceContextPool::Instance().Get(x.place()));
+
+  auto x_tensor = static_cast<const phi::DenseTensor*>(x.impl().get());
+  auto weight_tensor =
+      static_cast<const phi::DenseTensor*>(weight.impl().get());
+  auto scale_tensor = static_cast<const phi::DenseTensor*>(scale.impl().get());
+
+  std::vector<int64_t> out_dims = {x_tensor->dims()[0],
+                                   weight_tensor->dims()[1]};
+  // auto x_dims = x_tensor->dims();
+  // for (int i = 0; i < x_dims.size() - 1; i++) {
+  //   out_dims.push_back(x_dims[i]);
+  // }
+  // out_dims.push_back(weight_tensor->dims()[1]);
+
+  std::shared_ptr<phi::DenseTensor> out_tensor =
+      std::make_shared<phi::DenseTensor>();
+  out_tensor->Resize(phi::make_ddim(out_dims));
+  dev_ctx->Alloc(out_tensor.get(), x_tensor->dtype());
+
+  phi::DenseTensor* null1 = nullptr;
+  phi::DenseTensor* null2 = nullptr;
+  phi::DenseTensor* null3 = nullptr;
+  phi::DenseTensor* null4 = nullptr;
+
+  int64_t zero = 0;
+
+  EXEC_NPU_CMD(aclnnWeightQuantBatchMatmulV2,
+               *dev_ctx,
+               *x_tensor,
+               *weight_tensor,
+               *scale_tensor,
+               null1,
+               null2,
+               null3,
+               null4,
+               zero,
+               *out_tensor);
+  return {paddle::Tensor(out_tensor)};
+}
+
+PD_BUILD_OP(weight_only_linear_npu)
+    .Inputs({"x", "weight", "scale"})
+    .Outputs({"output"})
+    .SetKernelFn(PD_KERNEL(weight_only_linear_npu))
+    .SetInferShapeFn(PD_INFER_SHAPE(
+        WeightOnlyLinearInferShape));  // necessary if the op has multiple inputs
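The weight_only_linear_npu kernel keeps a commented-out branch for inputs with more than two dimensions. If that branch is ever enabled, the shape function would need the matching generalization; the sketch below is an assumption, mirroring that commented-out code under the premise that weight is stored as [k, n], as the existing 2-D path already assumes.

// Sketch only: shape inference generalized to rank-N x (keep all leading dims
// of x, replace the last dim with weight's output dim).
#include <cstdint>
#include <vector>

std::vector<int64_t> WeightOnlyLinearOutShapeND(
    const std::vector<int64_t>& x_shape,
    const std::vector<int64_t>& weight_shape) {
  std::vector<int64_t> out_shape(x_shape.begin(), x_shape.end() - 1);
  out_shape.push_back(weight_shape[1]);
  return out_shape;
}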