diff --git a/backends/intel_hpu/custom_ops/llama_infer/fused_block_attention.cc b/backends/intel_hpu/custom_ops/llama_infer/fused_block_attention.cc
index 40d342635d8..25be553c918 100644
--- a/backends/intel_hpu/custom_ops/llama_infer/fused_block_attention.cc
+++ b/backends/intel_hpu/custom_ops/llama_infer/fused_block_attention.cc
@@ -34,12 +34,41 @@ struct FusedBlockAttentionParams {
   bool use_neox_style = true;
   bool with_qkv_biases = false;
   bool transpose = true;
+  bool use_fp8 = false;
 };
 
-class FusedMHABlockAttention : public HpuFusedOperator {
+class FusedBlockAttentionBase : public HpuFusedOperator {
+ public:
+  explicit FusedBlockAttentionBase(const std::string& guid, bool is_eager)
+      : HpuFusedOperator(guid, is_eager) {}
+
+  template <typename T>
+  inline void AddNodeMixedPrecisionGemm(bool use_fp8,
+                                        ConvertTensors& ct,
+                                        int scale_x_index,
+                                        int scale_y_index,
+                                        std::vector<synTensor> inputs,
+                                        std::vector<synTensor> outputs,
+                                        synGEMMParams gemm_params,
+                                        const std::string& suffix) {
+    if (use_fp8) {
+      synTensor scale_x = createTensorFromCT(&ct, scale_x_index);
+      synTensor scale_y = createTensorFromCT(&ct, scale_x_index + 1);
+      inputs.push_back(scale_x);
+      inputs.push_back(scale_y);
+      AddNodeFusedFP8Gemm<T>(
+          inputs, outputs, gemm_params, guid_ + "fused_fp8_gemm_" + suffix);
+    } else {
+      AddNodeBatchGemm<T>(
+          inputs, outputs, gemm_params, guid_ + "batchgemm_" + suffix);
+    }
+  }
+};
+
+class FusedMHABlockAttention : public FusedBlockAttentionBase {
  public:
   explicit FusedMHABlockAttention(std::string guid_prefix, synDataType dtype)
-      : HpuFusedOperator(guid_prefix, false), dtype_(dtype) {}
+      : FusedBlockAttentionBase(guid_prefix, false), dtype_(dtype) {}
   template <typename T>
   void AddNode(ConvertTensors& ct, FusedBlockAttentionParams& params) {
     auto ins = ct.GetTensors();
@@ -58,7 +87,22 @@ class FusedMHABlockAttention : public HpuFusedOperator {
     int block_offsets_index = (index_base++);   // 9
     int qkv_weights_index = (index_base++);     // 10
     int linear_weights_index = (index_base++);  // 11
-    int qkv_biases_index = (index_base++);      // 12
+
+    int qk_scale_x_index = -1, qk_scale_y_index = -1, av_scale_x_index = -1,
+        av_scale_y_index = -1, o_linear_scale_x_index = -1,
+        o_linear_scale_y_index = -1, qkv_biases_index = -1;
+    if (params.use_fp8) {
+      qk_scale_x_index = (index_base++);
+      qk_scale_y_index = (index_base++);
+      av_scale_x_index = (index_base++);
+      av_scale_y_index = (index_base++);
+      o_linear_scale_x_index = (index_base++);
+      o_linear_scale_y_index = (index_base++);
+    }
+
+    if (params.with_qkv_biases) {
+      qkv_biases_index = (index_base++);
+    }
 
     std::vector<int64_t> src_dims = std::vector<int64_t>(ins[src_index].dims);
 
@@ -432,7 +476,15 @@ class FusedMHABlockAttention : public HpuFusedOperator {
     std::vector<synTensor> q_k_out;
     q_k_out.push_back(q_k);
 
-    AddNodeBatchGemm(q_k_in, q_k_out, gemm_params_f_t, guid_ + "batchgemm_q_k");
+    // Q*K^T
+    AddNodeMixedPrecisionGemm<T>(params.use_fp8,
+                                 ct,
+                                 qk_scale_x_index,
+                                 qk_scale_y_index,
+                                 q_k_in,
+                                 q_k_out,
+                                 gemm_params_f_t,
+                                 "q_k");
 
     /*******************************/
 
@@ -579,8 +631,15 @@ class FusedMHABlockAttention : public HpuFusedOperator {
     std::vector<synTensor> score_v_out;
     score_v_out.push_back(score_v);
 
-    AddNodeBatchGemm(
-        score_v_in, score_v_out, gemm_params_f_f, guid_ + "batchgemm_score_v");
+    // Score*V
+    AddNodeMixedPrecisionGemm<T>(params.use_fp8,
+                                 ct,
+                                 av_scale_x_index,
+                                 av_scale_y_index,
+                                 score_v_in,
+                                 score_v_out,
+                                 gemm_params_f_f,
+                                 "score_v");
 
     auto reduceSum = createTensorNoPresist("reduceSum", dtype_, block_max_dims);
     std::vector<synTensor> reduceSum_out;
@@ -727,18 +786,25 @@ class FusedMHABlockAttention : public HpuFusedOperator {
     std::vector<synTensor> proj_out;
     proj_out.push_back(linear_out);
 
-    AddNodeBatchGemm(
-        proj_in, proj_out, gemm_params_f_f, guid_ + "batchgemm_proj");
+    // Final linear
+    AddNodeMixedPrecisionGemm<T>(params.use_fp8,
+                                 ct,
+                                 o_linear_scale_x_index,
+                                 o_linear_scale_y_index,
+                                 proj_in,
+                                 proj_out,
+                                 gemm_params_f_f,
+                                 "proj");
   }
 
  protected:
   synDataType dtype_;
 };
 
-class FusedGQABlockAttention : public HpuFusedOperator {
+class FusedGQABlockAttention : public FusedBlockAttentionBase {
  public:
   explicit FusedGQABlockAttention(std::string guid_prefix, synDataType dtype)
-      : HpuFusedOperator(guid_prefix, false), dtype_(dtype) {}
+      : FusedBlockAttentionBase(guid_prefix, false), dtype_(dtype) {}
   template <typename T>
   void AddNode(ConvertTensors& ct, FusedBlockAttentionParams& params) {
     auto ins = ct.GetTensors();
@@ -757,7 +823,22 @@ class FusedGQABlockAttention : public HpuFusedOperator {
     int block_offsets_index = (index_base++);   // 9
     int qkv_weights_index = (index_base++);     // 10
     int linear_weights_index = (index_base++);  // 11
-    int qkv_biases_index = (index_base++);      // 12
+
+    int qk_scale_x_index = -1, qk_scale_y_index = -1, av_scale_x_index = -1,
+        av_scale_y_index = -1, o_linear_scale_x_index = -1,
+        o_linear_scale_y_index = -1, qkv_biases_index = -1;
+    if (params.use_fp8) {
+      qk_scale_x_index = (index_base++);
+      qk_scale_y_index = (index_base++);
+      av_scale_x_index = (index_base++);
+      av_scale_y_index = (index_base++);
+      o_linear_scale_x_index = (index_base++);
+      o_linear_scale_y_index = (index_base++);
+    }
+
+    if (params.with_qkv_biases) {
+      qkv_biases_index = (index_base++);
+    }
 
     std::vector<int64_t> src_dims = std::vector<int64_t>(ins[src_index].dims);
 
@@ -1157,7 +1238,15 @@ class FusedGQABlockAttention : public HpuFusedOperator {
     std::vector<synTensor> q_k_out;
     q_k_out.push_back(q_k);
 
-    AddNodeBatchGemm(q_k_in, q_k_out, gemm_params_f_t, guid_ + "batchgemm_q_k");
+    // Q*K^T
+    AddNodeMixedPrecisionGemm<T>(params.use_fp8,
+                                 ct,
+                                 qk_scale_x_index,
+                                 qk_scale_y_index,
+                                 q_k_in,
+                                 q_k_out,
+                                 gemm_params_f_t,
+                                 "q_k");
 
     /*******************************/
 
@@ -1308,8 +1397,15 @@ class FusedGQABlockAttention : public HpuFusedOperator {
     std::vector<synTensor> score_v_out;
     score_v_out.push_back(score_v);
 
-    AddNodeBatchGemm(
-        score_v_in, score_v_out, gemm_params_f_f, guid_ + "batchgemm_score_v");
+    // Score*V
+    AddNodeMixedPrecisionGemm<T>(params.use_fp8,
+                                 ct,
+                                 av_scale_x_index,
+                                 av_scale_y_index,
+                                 score_v_in,
+                                 score_v_out,
+                                 gemm_params_f_f,
+                                 "a_v");
 
     auto reduceSum = createTensorNoPresist("reduceSum", dtype_, block_max_dims);
     std::vector<synTensor> reduceSum_out;
@@ -1472,8 +1568,15 @@ class FusedGQABlockAttention : public HpuFusedOperator {
     std::vector<synTensor> proj_out;
     proj_out.push_back(linear_out);
 
-    AddNodeBatchGemm(
-        proj_in, proj_out, gemm_params_f_f, guid_ + "batchgemm_proj");
+    // Final Linear
+    AddNodeMixedPrecisionGemm<T>(params.use_fp8,
+                                 ct,
+                                 o_linear_scale_x_index,
+                                 o_linear_scale_y_index,
+                                 proj_in,
+                                 proj_out,
+                                 gemm_params_f_f,
+                                 "proj");
   }
 
  protected:
@@ -1496,6 +1599,12 @@ void FusedBlockAttentionKernel(
     const phi::DenseTensor& qkv_weights,
     const paddle::optional<phi::DenseTensor>& qkv_biases,
     const phi::DenseTensor& linear_weights,
+    const paddle::optional<phi::DenseTensor>& qk_scale_x,
+    const paddle::optional<phi::DenseTensor>& qk_scale_y,
+    const paddle::optional<phi::DenseTensor>& av_scale_x,
+    const paddle::optional<phi::DenseTensor>& av_scale_y,
+    const paddle::optional<phi::DenseTensor>& o_linear_scale_x,
+    const paddle::optional<phi::DenseTensor>& o_linear_scale_y,
     phi::DenseTensor* out_linear,
     const phi::Scalar& head_dim,
     const phi::Scalar& num_head,
@@ -1534,6 +1643,26 @@ void FusedBlockAttentionKernel(
   ct.Add(value_cache, false);
 
   std::string guid_prefix = "fused_block_attention_";
+
+  bool use_fp8 = false;
+  if (qk_scale_x || qk_scale_y || av_scale_x || av_scale_y ||
+      o_linear_scale_x || o_linear_scale_y) {
+    if (!qk_scale_x || !qk_scale_y || !av_scale_x || !av_scale_y ||
+        !o_linear_scale_x || !o_linear_scale_y) {
+      throw std::runtime_error(
+          "Please specify all scale values for FusedBlockAttentionKernel");
+    }
+
+    use_fp8 = true;
+    guid_prefix = "fused_fp8_block_attention_";
+    ct.Add(qk_scale_x.get());
+    ct.Add(qk_scale_y.get());
+    ct.Add(av_scale_x.get());
+    ct.Add(av_scale_y.get());
+    ct.Add(o_linear_scale_x.get());
+    ct.Add(o_linear_scale_y.get());
+  }
+
   if (qkv_biases) {
     ct.Add(qkv_biases.get());
     guid_prefix += "bias_";
@@ -1565,6 +1694,7 @@ void FusedBlockAttentionKernel(
   params.head_dim = head_dim_;
   params.num_head = num_head_;
   params.num_kv_head = num_kv_head;
+  params.use_fp8 = use_fp8;
   if (qkv_biases) {
     params.with_qkv_biases = true;
   }
@@ -1607,6 +1737,12 @@ void CallFusedBlockAttentionKernel(
     const phi::DenseTensor& qkv_weights,
     const paddle::optional<phi::DenseTensor>& qkv_biases,
     const phi::DenseTensor& linear_weights,
+    const paddle::optional<phi::DenseTensor>& qk_scale_x,
+    const paddle::optional<phi::DenseTensor>& qk_scale_y,
+    const paddle::optional<phi::DenseTensor>& av_scale_x,
+    const paddle::optional<phi::DenseTensor>& av_scale_y,
+    const paddle::optional<phi::DenseTensor>& o_linear_scale_x,
+    const paddle::optional<phi::DenseTensor>& o_linear_scale_y,
     phi::DenseTensor* out_linear,
     const phi::Scalar& head_dim,
     const phi::Scalar& num_head,
@@ -1629,6 +1765,12 @@ void CallFusedBlockAttentionKernel(
         qkv_weights,
         qkv_biases,
         linear_weights,
+        qk_scale_x,
+        qk_scale_y,
+        av_scale_x,
+        av_scale_y,
+        o_linear_scale_x,
+        o_linear_scale_y,
         out_linear,
         head_dim,
         num_head,
@@ -1651,6 +1793,12 @@ void CallFusedBlockAttentionKernel(
         qkv_weights,
         qkv_biases,
         linear_weights,
+        qk_scale_x,
+        qk_scale_y,
+        av_scale_x,
+        av_scale_y,
+        o_linear_scale_x,
+        o_linear_scale_y,
         out_linear,
         head_dim,
         num_head,
@@ -1738,6 +1886,12 @@ std::vector<paddle::Tensor> FusedBlockAttentionForward(
                                 *qkv_weights_tensor,
                                 qkv_biases_tensor,
                                 *linear_weights_tensor,
+                                paddle::optional<phi::DenseTensor>(),
+                                paddle::optional<phi::DenseTensor>(),
+                                paddle::optional<phi::DenseTensor>(),
+                                paddle::optional<phi::DenseTensor>(),
+                                paddle::optional<phi::DenseTensor>(),
+                                paddle::optional<phi::DenseTensor>(),
                                 out_linear.get(),
                                 phi::Scalar(head_dim),
                                 phi::Scalar(num_head),
@@ -1811,3 +1965,142 @@ PD_BUILD_OP(fused_block_attention)
     .SetKernelFn(PD_KERNEL(FusedBlockAttentionForward))
     .SetInferShapeFn(PD_INFER_SHAPE(FusedBlockAttentionShape))
     .SetInferDtypeFn(PD_INFER_DTYPE(FusedBlockAttentionDtype));
+
+std::vector<paddle::Tensor> FusedFp8BlockAttentionForward(
+    const paddle::Tensor& src,
+    const paddle::Tensor& rotary_embs,
+    const paddle::Tensor& key_cache,
+    const paddle::Tensor& value_cache,
+    const paddle::Tensor& block_groups,
+    const paddle::Tensor& block_list,
+    const paddle::Tensor& block_mapping,
+    const paddle::Tensor& block_bias,
+    const paddle::Tensor& block_indices,
+    const paddle::Tensor& block_offsets,
+    const paddle::Tensor& qkv_weights,
+    const paddle::optional<paddle::Tensor>& qkv_biases,
+    const paddle::Tensor& linear_weights,
+    const paddle::Tensor& qk_scale_x,
+    const paddle::Tensor& qk_scale_y,
+    const paddle::Tensor& av_scale_x,
+    const paddle::Tensor& av_scale_y,
+    const paddle::Tensor& o_linear_scale_x,
+    const paddle::Tensor& o_linear_scale_y,
+    int head_dim,
+    int num_head,
+    float scaling_factor,
+    bool transpose,
+    bool use_neox_style) {
+  auto dev_ctx = static_cast<const phi::CustomContext*>(
+      paddle::experimental::DeviceContextPool::Instance().Get(src.place()));
+  auto src_tensor = static_cast<const phi::DenseTensor*>(src.impl().get());
+  auto rotary_embs_tensor =
+      static_cast<const phi::DenseTensor*>(rotary_embs.impl().get());
+  auto key_cache_tensor =
+      static_cast<const phi::DenseTensor*>(key_cache.impl().get());
+  auto value_cache_tensor =
+      static_cast<const phi::DenseTensor*>(value_cache.impl().get());
+  auto block_groups_tensor =
+      static_cast<const phi::DenseTensor*>(block_groups.impl().get());
+  auto block_list_tensor =
+      static_cast<const phi::DenseTensor*>(block_list.impl().get());
+  auto block_mapping_tensor =
+      static_cast<const phi::DenseTensor*>(block_mapping.impl().get());
+  auto block_bias_tensor =
+      static_cast<const phi::DenseTensor*>(block_bias.impl().get());
+  auto block_indices_tensor =
+      static_cast<const phi::DenseTensor*>(block_indices.impl().get());
+  auto block_offsets_tensor =
+      static_cast<const phi::DenseTensor*>(block_offsets.impl().get());
+  auto qkv_weights_tensor =
+      static_cast<const phi::DenseTensor*>(qkv_weights.impl().get());
+  auto linear_weights_tensor =
+      static_cast<const phi::DenseTensor*>(linear_weights.impl().get());
+
+  auto qkv_biases_tensor = paddle::optional<phi::DenseTensor>();
+  if (qkv_biases) {
+    auto qkv_biases_dt =
+        static_cast<const phi::DenseTensor*>(qkv_biases->impl().get());
+    qkv_biases_tensor = paddle::optional<phi::DenseTensor>(*qkv_biases_dt);
+  }
+
+  auto qk_scale_x_tensor =
+      static_cast<const phi::DenseTensor*>(qk_scale_x.impl().get());
+  auto qk_scale_y_tensor =
+      static_cast<const phi::DenseTensor*>(qk_scale_y.impl().get());
+  auto av_scale_x_tensor =
+      static_cast<const phi::DenseTensor*>(av_scale_x.impl().get());
+  auto av_scale_y_tensor =
+      static_cast<const phi::DenseTensor*>(av_scale_y.impl().get());
+  auto o_linear_scale_x_tensor =
+      static_cast<const phi::DenseTensor*>(o_linear_scale_x.impl().get());
+  auto o_linear_scale_y_tensor =
+      static_cast<const phi::DenseTensor*>(o_linear_scale_y.impl().get());
+
+  // allocate memory on device.
+  int64_t batch_size = src.dims()[0];
+  int64_t out_features = linear_weights.dims()[1];
+
+  std::shared_ptr<phi::DenseTensor> out_linear =
+      std::make_shared<phi::DenseTensor>();
+  out_linear->Resize(phi::make_ddim({batch_size, out_features}));
+  dev_ctx->Alloc(out_linear.get(), src_tensor->dtype());
+
+  CallFusedBlockAttentionKernel(*dev_ctx,
+                                *src_tensor,
+                                *rotary_embs_tensor,
+                                *key_cache_tensor,
+                                *value_cache_tensor,
+                                *block_groups_tensor,
+                                *block_list_tensor,
+                                *block_mapping_tensor,
+                                *block_bias_tensor,
+                                *block_indices_tensor,
+                                *block_offsets_tensor,
+                                *qkv_weights_tensor,
+                                qkv_biases_tensor,
+                                *linear_weights_tensor,
+                                *qk_scale_x_tensor,
+                                *qk_scale_y_tensor,
+                                *av_scale_x_tensor,
+                                *av_scale_y_tensor,
+                                *o_linear_scale_x_tensor,
+                                *o_linear_scale_y_tensor,
+                                out_linear.get(),
+                                phi::Scalar(head_dim),
+                                phi::Scalar(num_head),
+                                phi::Scalar(scaling_factor),
+                                phi::Scalar(transpose),
+                                phi::Scalar(use_neox_style));
+  return {paddle::Tensor(out_linear)};
+}
+
+PD_BUILD_OP(fused_fp8_block_attention)
+    .Inputs({"src",
+             "rotary_embs",
+             "key_cache",
+             "value_cache",
+             "block_groups",
+             "block_list",
+             "block_mapping",
+             "block_bias",
+             "block_indices",
+             "block_offsets",
+             "qkv_weights",
+             paddle::Optional("qkv_biases"),
+             "linear_weights",
+             "qk_scale_x",
+             "qk_scale_y",
+             "av_scale_x",
+             "av_scale_y",
+             "o_linear_scale_x",
+             "o_linear_scale_y"})
+    .Outputs({"out_linear"})
+    .Attrs({"head_dim: int",
+            "num_head: int",
+            "scaling_factor: float",
+            "transpose: bool",
+            "use_neox_style: bool"})
+    .SetKernelFn(PD_KERNEL(FusedFp8BlockAttentionForward))
+    .SetInferShapeFn(PD_INFER_SHAPE(FusedBlockAttentionShape))
+    .SetInferDtypeFn(PD_INFER_DTYPE(FusedBlockAttentionDtype));
diff --git a/backends/intel_hpu/kernels/hpu_funcs.h b/backends/intel_hpu/kernels/hpu_funcs.h
index 5e372ba6155..6dbb48c1df2 100644
--- a/backends/intel_hpu/kernels/hpu_funcs.h
+++ b/backends/intel_hpu/kernels/hpu_funcs.h
@@ -573,9 +573,45 @@ class HpuFusedOperator : public HpuOperator {
     gemm_ins.push_back(y_tensor);
     if (!cast_x) {
       gemm_ins.push_back(inputs[2]);
+    } else {
+      synTensor one_tensor =
+          cloneTensor(node_name + "_one_x", inputs[2], syn_type_float);
+      ns_ConstantKernel::Params const_params;
+      const_params.constant.f = 1.0;
+      std::vector<synTensor> one;
+      one.push_back(one_tensor);
+      AddNodeFull(one, const_params, node_name + "_full_x_one");
+
+      std::vector<synTensor> div_in;
+      div_in.push_back(one_tensor);
+      div_in.push_back(inputs[2]);
+      synTensor d_scale_x_tensor =
+          cloneTensor(node_name + "_d_scale_x", inputs[2], syn_type_float);
+      std::vector<synTensor> div_out;
+      div_out.push_back(d_scale_x_tensor);
+      AddNodeDivide(div_in, div_out, node_name + "_div_scale_x");
+      gemm_ins.push_back(d_scale_x_tensor);
     }
     if (!cast_y) {
       gemm_ins.push_back(inputs[3]);
+    } else {
+      synTensor one_tensor =
+          cloneTensor(node_name + "_one_y", inputs[3], syn_type_float);
+      ns_ConstantKernel::Params const_params;
+      const_params.constant.f = 1.0;
+      std::vector<synTensor> one;
+      one.push_back(one_tensor);
+      AddNodeFull(one, const_params, node_name + "_full_y_one");
+
+      std::vector<synTensor> div_in;
+      div_in.push_back(one_tensor);
+      div_in.push_back(inputs[3]);
+      synTensor d_scale_y_tensor =
+          cloneTensor(node_name + "_d_scale_y", inputs[3], syn_type_float);
+      std::vector<synTensor> div_out;
+      div_out.push_back(d_scale_y_tensor);
+      AddNodeDivide(div_in, div_out, node_name + "_div_scale_y");
+      gemm_ins.push_back(d_scale_y_tensor);
     }
     AddNodeFP8Gemm(gemm_ins, outputs, params, node_name);
   }
diff --git a/backends/intel_hpu/tests/unittests/test_fused_fp8_block_attention.py b/backends/intel_hpu/tests/unittests/test_fused_fp8_block_attention.py
new file mode 100644
index 00000000000..445f6c8857e
--- /dev/null
+++ b/backends/intel_hpu/tests/unittests/test_fused_fp8_block_attention.py
@@ -0,0 +1,332 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import paddle
+import paddlenlp_ops
+import os
+
+intel_hpus_module_id = os.environ.get("FLAGS_selected_intel_hpus", 0)
+paddle.device.set_device(f"intel_hpu:{intel_hpus_module_id}")
+
+seed = 102
+paddle.seed(seed)
+np.random.seed(seed)
+
+
+class TestFusedBlockAttention:
+    def __init__(self):
+        self.head_dim = 128
+        self.num_head = 32
+        self.hidden_size = self.num_head * self.head_dim
+
+        self.epsilon = 1e-06
+
+        self.use_neox = True
+        self.position_offset = 0
+        self.rope_theta = 10000
+
+    def init_decode_MHA_params(self):
+        self.test_name = "Test_MHA_FusedBlockAttentionDecode"
+        self.kv_num_heads = 32
+        self.kv_hidden_size = self.head_dim * self.kv_num_heads
+        self.qkv_biases = None
+
+        self.batch_size = 16
+        self.seq_len = 1
+        self.block_size = 128
+        self.num_of_block = 32
+        self.total_block_num = 20
+        position_id = paddle.to_tensor([80])
+        self.position_ids = paddle.expand(
+            position_id, shape=[self.batch_size, self.seq_len]
+        )
+
+    def init_decode_GQA_params(self):
+        self.test_name = "Test_GQA_FusedBlockAttentionDecode"
+        self.kv_num_heads = 4
+        self.kv_hidden_size = self.head_dim * self.kv_num_heads
+        self.qkv_biases = None
+
+        self.batch_size = 16
+        self.seq_len = 1
+        self.block_size = 128
+        self.num_of_block = 32
+        self.total_block_num = 20
+        position_id = paddle.to_tensor([80])
+        self.position_ids = paddle.expand(
+            position_id, shape=[self.batch_size, self.seq_len]
+        )
+
+    def create_tensors(self):
+        device = paddle.get_device()
+
+        np_k_cache = np.random.rand(
+            self.total_block_num, self.block_size, self.kv_num_heads, self.head_dim
+        ).astype("float32")
+        self.k_cache = (
+            paddle.to_tensor(np_k_cache, place=paddle.CPUPlace())
+            .to(paddle.bfloat16)
+            .to(device)
+        )
+        self.k_cache_test = self.k_cache.clone()
+
+        np_v_cache = np.random.rand(
+            self.total_block_num, self.block_size, self.kv_num_heads, self.head_dim
+        ).astype("float32")
+        self.v_cache = (
+            paddle.to_tensor(np_v_cache, place=paddle.CPUPlace())
+            .to(paddle.bfloat16)
+            .to(device)
+        )
+        self.v_cache_test = self.v_cache.clone()
+
+        self.input_ids = paddle.zeros(
+            [self.batch_size, self.seq_len], dtype=paddle.bfloat16
+        )
+
+        np_src = np.random.rand(self.batch_size, self.seq_len, self.hidden_size).astype(
+            "float32"
+        )
+        self.src = (
+            paddle.to_tensor(np_src, place=paddle.CPUPlace())
+            .to(paddle.bfloat16)
+            .to(device)
+        )
+
+        np_residual = np.random.rand(
+            self.batch_size, self.seq_len, self.hidden_size
+        ).astype("float32")
+        self.residual = (
+            paddle.to_tensor(np_residual, place=paddle.CPUPlace())
+            .to(paddle.bfloat16)
+            .to(device)
+        )
+        self.residual_test = self.residual.clone()
+
+        np_ln_scales = np.random.rand(self.hidden_size).astype("float32")
+        self.ln_scales = (
+            paddle.to_tensor(np_ln_scales, place=paddle.CPUPlace())
+            .to(paddle.bfloat16)
+            .to(device)
+        )
+
+        np_qkv_weights = np.random.rand(
+            self.hidden_size + 2 * self.kv_hidden_size, self.hidden_size
+        ).astype("float32")
+        self.qkv_weights = (
+            paddle.to_tensor(np_qkv_weights, place=paddle.CPUPlace())
+            .to(paddle.bfloat16)
+            .to(device)
+        )
+
+        if self.qkv_biases is not None:
+            np_qkv_biases = np.random.rand(
+                self.hidden_size + 2 * self.kv_hidden_size
+            ).astype("float32")
+            self.qkv_biases = (
+                paddle.to_tensor(np_qkv_biases, place=paddle.CPUPlace())
+                .to(paddle.bfloat16)
+                .to(device)
+            )
+
+        np_linear_weights = np.random.rand(self.hidden_size, self.hidden_size).astype(
+            "float32"
+        )
+        self.linear_weights = (
+            paddle.to_tensor(np_linear_weights, place=paddle.CPUPlace())
+            .to(paddle.bfloat16)
+            .to(device)
+        )
+
+        self.head_dim_shape_tensor = paddle.ones(self.head_dim, dtype="int8")
+        self.new_rope = paddlenlp_ops.fused_get_rotary_embedding(
+            self.input_ids,
+            self.position_ids,
+            self.head_dim_shape_tensor,
+            self.position_offset,
+            self.rope_theta,
+            self.use_neox,
+        ).to(paddle.bfloat16)
+
+        self.block_indices = paddle.randint(
+            0,
+            self.total_block_num,
+            [
+                self.batch_size,
+            ],
+            dtype=paddle.int32,
+        )
+        self.block_offsets = paddle.randint(
+            0,
+            self.block_size,
+            [
+                self.batch_size,
+            ],
+            dtype=paddle.int32,
+        )
+
+        self.block_groups = paddle.randint(
+            0,
+            self.batch_size,
+            [
+                self.num_of_block,
+            ],
+            dtype=paddle.int32,
+        )
+        self.block_list = paddle.randint(
+            0,
+            self.num_of_block,
+            [
+                self.num_of_block,
+            ],
+            dtype=paddle.int32,
+        )
+        self.block_mapping = paddle.randint(
+            0, 2, [self.num_of_block, self.batch_size], dtype=paddle.int32
+        ).to(paddle.bfloat16)
+
+        np_block_bias = np.random.rand(self.num_of_block, self.block_size).astype(
+            "float32"
+        )
+        self.block_bias = (
+            paddle.to_tensor(np_block_bias, place=paddle.CPUPlace())
+            .to(paddle.bfloat16)
+            .to(device)
+        )
+
+        self.qk_scale_x = paddle.to_tensor([0.002]).to(device)
+        self.qk_scale_y = paddle.to_tensor([0.002]).to(device)
+        self.av_scale_x = paddle.to_tensor([0.1]).to(device)
+        self.av_scale_y = paddle.to_tensor([0.1]).to(device)
+        self.o_linear_scale_x = paddle.to_tensor([1.0]).to(device)
+        self.o_linear_scale_y = paddle.to_tensor([1.0]).to(device)
+
+    def run_test(self):
+        query_states, key_value_states = paddlenlp_ops.fused_rms_qkv_rope_t(
+            self.src,
+            self.ln_scales,
+            self.qkv_weights,
+            self.qkv_biases,
+            self.new_rope.transpose([0, 1, 3, 2, 4]),
+            self.residual,
+            self.epsilon,
+            self.head_dim,
+            self.num_head,
+        )
+        key_states = key_value_states[0].squeeze(1)
+        value_states = key_value_states[1].squeeze(1)
+
+        self.k_cache.index_put_((self.block_indices, self.block_offsets), key_states)
+        self.v_cache.index_put_((self.block_indices, self.block_offsets), value_states)
+
+        out_linear_out_ref = paddlenlp_ops.fused_flatpa_proj(
+            query_states,
+            self.k_cache,
+            self.v_cache,
+            self.block_groups,
+            self.block_list,
+            self.block_mapping,
+            self.block_bias,
+            self.linear_weights,
+            scaling_factor=self.head_dim**-0.5,
+        )
+
+        src, self.residual_test = paddle.incubate.nn.functional.fused_rms_norm(
+            self.src,
+            norm_weight=self.ln_scales,
+            norm_bias=None,
+            epsilon=self.epsilon,
+            begin_norm_axis=2,
+            bias=None,
+            residual=self.residual_test,
+        )
+
+        b, s, h = src.shape
+        src = src.reshape([-1, h])
+        out_linear_out = paddlenlp_ops.fused_fp8_block_attention(
+            src,
+            self.new_rope.transpose([0, 1, 3, 2, 4]).squeeze(2),
+            self.k_cache_test,
+            self.v_cache_test,
+            self.block_groups,
+            self.block_list,
+            self.block_mapping,
+            self.block_bias,
+            self.block_indices,
+            self.block_offsets,
+            self.qkv_weights,
+            self.qkv_biases,
+            self.linear_weights,
+            self.qk_scale_x,
+            self.qk_scale_y,
+            self.av_scale_x,
+            self.av_scale_y,
+            self.o_linear_scale_x,
+            self.o_linear_scale_y,
+            self.head_dim,
+            self.num_head,
+            scaling_factor=self.head_dim**-0.5,
+            transpose=True,
+            use_neox_style=True,
+        ).reshape([b, -1, h])
+
+        assert paddle.allclose(
+            out_linear_out_ref.to("cpu").to("float32"),
+            out_linear_out.to("cpu").to("float32"),
+            rtol=0.2,
+        ), f"Test failed for {self.test_name} fused_fp8_block_attention out_linear_out"
+
+        assert paddle.allclose(
+            self.k_cache.to("cpu").to("float32"),
+            self.k_cache_test.to("cpu").to("float32"),
+            rtol=1e-1,
+        ), f"Test failed for {self.test_name} fused_fp8_block_attention k_cache"
+
+        assert paddle.allclose(
+            self.v_cache.to("cpu").to("float32"),
+            self.v_cache_test.to("cpu").to("float32"),
+            rtol=1e-2,
+        ), f"Test failed for {self.test_name} fused_fp8_block_attention v_cache"
+
+        assert paddle.allclose(
+            self.residual.to("cpu").to("float32"),
+            self.residual_test.to("cpu").to("float32"),
+            rtol=1e-2,
+        ), f"Test failed for {self.test_name} fused_fp8_block_attention residual"
+
+        # ===============summary==============
+        print(f"Test Pass for {self.test_name} testcase")
+
+
+class test_case_decode_MHA(TestFusedBlockAttention):
+    def __init__(self):
+        super().__init__()
+        self.init_decode_MHA_params()
+        self.create_tensors()
+
+
+class test_case_decode_GQA(TestFusedBlockAttention):
+    def __init__(self):
+        super().__init__()
+        self.init_decode_GQA_params()
+        self.create_tensors()
+
+
+if __name__ == "__main__":
+    test_1 = test_case_decode_MHA()
+    test_1.run_test()
+
+    test_2 = test_case_decode_GQA()
+    test_2.run_test()