From 2c60da49e401e094cbfcedbeccca32bb0947c1dc Mon Sep 17 00:00:00 2001
From: yangjianfengo1
Date: Thu, 25 Sep 2025 20:53:57 +0800
Subject: [PATCH 1/7] w4afp8: support per-group weight scales
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Support per-group weight scales in the w4afp8 GEMM: add a WeightScaleGroup
parameter to the kernel traits, load the scales through TMA, and apply them
per K block in a new mma_pipeline path. Drop the input_row_sum compensation
and the TAIL_N kernel specialization, split the MoeFastHardamardWrapper
explicit instantiations into per-dtype .cu files (fusing output quantization
into the hardamard kernel), move the weight convert op to
w4afp8_gemm/weight.cu, add a w4afp8_gemm_weight_permute op, and update the
kernel auto-generation script accordingly.
---
 custom_ops/gpu_ops/append_attn/utils.cuh      |  32 +-
 .../fast_hardamard_kernel.h                   |   0
 .../fast_hardamard_kernel.hpp}                |  78 +----
 .../fast_hardamard_kernel_bf16_bf16.cu        |  34 +++
 .../fast_hardamard_kernel_bf16_fp8.cu         |  34 +++
 .../fast_hardamard_kernel_bf16_int8.cu        |  34 +++
 .../fast_hardamard_kernel_fp16_fp16.cu        |  34 +++
 .../fast_hardamard_kernel_fp16_int8.cu        |  34 +++
 custom_ops/gpu_ops/moe/fused_moe_op.h         |  92 ------
 .../gpu_ops/moe/moe_expert_ffn_wint2.cu       |   2 +-
 custom_ops/gpu_ops/moe/moe_ffn.cu             |  48 +--
 .../gpu_ops/w4afp8_gemm/kernel_traits.h       |  53 ++--
 custom_ops/gpu_ops/w4afp8_gemm/mainloop_fwd.h | 278 +++++++++++-------
 custom_ops/gpu_ops/w4afp8_gemm/utils.hpp      |   8 +-
 custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu |  55 +---
 custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.h  |   2 -
 .../w4afp8_gemm/w4afp8_gemm_kernel.hpp        | 128 ++++----
 custom_ops/gpu_ops/w4afp8_gemm/weight.cu      | 140 +++++++++
 .../utils/auto_gen_w4afp8_gemm_kernel.py      | 107 +++----
 .../layers/moe/fused_moe_cutlass_backend.py   |   4 +-
 20 files changed, 634 insertions(+), 563 deletions(-)
 rename custom_ops/gpu_ops/moe/{ => fast_hardmard}/fast_hardamard_kernel.h (100%)
 rename custom_ops/gpu_ops/moe/{fast_hardamard_kernel.cu => fast_hardmard/fast_hardamard_kernel.hpp} (96%)
 create mode 100644 custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_bf16_bf16.cu
 create mode 100644 custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_bf16_fp8.cu
 create mode 100644 custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_bf16_int8.cu
 create mode 100644 custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_fp16_fp16.cu
 create mode 100644 custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_fp16_int8.cu
 create mode 100644 custom_ops/gpu_ops/w4afp8_gemm/weight.cu

diff --git a/custom_ops/gpu_ops/append_attn/utils.cuh b/custom_ops/gpu_ops/append_attn/utils.cuh
index 12d86dade8..eb7fb6b1ae 100644
--- a/custom_ops/gpu_ops/append_attn/utils.cuh
+++ b/custom_ops/gpu_ops/append_attn/utils.cuh
@@ -404,39 +404,9 @@ __forceinline__ __host__ __device__ void vec_cast(
 }
 #define DISPATCH_GQA_GROUP_SIZE(group_size, GROUP_SIZE, ...)
\ - if (group_size == 1) { \ - constexpr size_t GROUP_SIZE = 1; \ - __VA_ARGS__ \ - } else if (group_size == 2) { \ - constexpr size_t GROUP_SIZE = 2; \ - __VA_ARGS__ \ - } else if (group_size == 3) { \ - constexpr size_t GROUP_SIZE = 3; \ - __VA_ARGS__ \ - } else if (group_size == 4) { \ - constexpr size_t GROUP_SIZE = 4; \ - __VA_ARGS__ \ - } else if (group_size == 5) { \ - constexpr size_t GROUP_SIZE = 5; \ - __VA_ARGS__ \ - } else if (group_size == 6) { \ - constexpr size_t GROUP_SIZE = 6; \ - __VA_ARGS__ \ - } else if (group_size == 7) { \ - constexpr size_t GROUP_SIZE = 7; \ - __VA_ARGS__ \ - } else if (group_size == 8) { \ + if (group_size == 8) { \ constexpr size_t GROUP_SIZE = 8; \ __VA_ARGS__ \ - } else if (group_size == 12) { \ - constexpr size_t GROUP_SIZE = 12; \ - __VA_ARGS__ \ - } else if (group_size == 14) { \ - constexpr size_t GROUP_SIZE = 14; \ - __VA_ARGS__ \ - } else if (group_size == 16) { \ - constexpr size_t GROUP_SIZE = 16; \ - __VA_ARGS__ \ } else { \ PD_THROW("not support the group_size", group_size); \ } diff --git a/custom_ops/gpu_ops/moe/fast_hardamard_kernel.h b/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel.h similarity index 100% rename from custom_ops/gpu_ops/moe/fast_hardamard_kernel.h rename to custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel.h diff --git a/custom_ops/gpu_ops/moe/fast_hardamard_kernel.cu b/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel.hpp similarity index 96% rename from custom_ops/gpu_ops/moe/fast_hardamard_kernel.cu rename to custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel.hpp index 1323cb4839..6d35654d8a 100644 --- a/custom_ops/gpu_ops/moe/fast_hardamard_kernel.cu +++ b/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel.hpp @@ -973,80 +973,4 @@ void MoeFastHardamardWrapper(const T *x_data, }); } } -} - -template void MoeFastHardamardWrapper( - const phi::dtype::float16 *x_data, - const int64_t *expert_idx_per_token, - const int64_t *recv_expert_count, - const phi::dtype::float16 *shift, - const phi::dtype::float16 *smooth, - const float* quant_scales, - const int quant_round_type, - const float quant_max_bound, - const float quant_min_bound, - const int64_t token_num, - const int64_t dim, - const int num_max_tokens_per_expert, - bool used_in_ep_low_latency, - const int hadamard_block_size, - phi::dtype::float16 *out, - cudaStream_t &stream -); - -template void MoeFastHardamardWrapper( - const phi::dtype::float16 *x_data, - const int64_t *expert_idx_per_token, - const int64_t *recv_expert_count, - const phi::dtype::float16 *shift, - const phi::dtype::float16 *smooth, - const float* quant_scales, - const int quant_round_type, - const float quant_max_bound, - const float quant_min_bound, - const int64_t token_num, - const int64_t dim, - const int num_max_tokens_per_expert, - bool used_in_ep_low_latency, - const int hadamard_block_size, - int8_t *out, - cudaStream_t &stream -); - -template void MoeFastHardamardWrapper( - const phi::dtype::bfloat16 *x_data, - const int64_t *expert_idx_per_token, - const int64_t *recv_expert_count, - const phi::dtype::bfloat16 *shift, - const phi::dtype::bfloat16 *smooth, - const float* quant_scales, - const int quant_round_type, - const float quant_max_bound, - const float quant_min_bound, - const int64_t token_num, - const int64_t dim, - const int num_max_tokens_per_expert, - bool used_in_ep_low_latency, - const int hadamard_block_size, - phi::dtype::bfloat16 *out, - cudaStream_t &stream -); - -template void MoeFastHardamardWrapper( - const 
phi::dtype::bfloat16 *x_data, - const int64_t *expert_idx_per_token, - const int64_t *recv_expert_count, - const phi::dtype::bfloat16 *shift, - const phi::dtype::bfloat16 *smooth, - const float* quant_scales, - const int quant_round_type, - const float quant_max_bound, - const float quant_min_bound, - const int64_t token_num, - const int64_t dim, - const int num_max_tokens_per_expert, - bool used_in_ep_low_latency, - const int hadamard_block_size, - int8_t *out, - cudaStream_t &stream -); +} \ No newline at end of file diff --git a/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_bf16_bf16.cu b/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_bf16_bf16.cu new file mode 100644 index 0000000000..64eb02e0be --- /dev/null +++ b/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_bf16_bf16.cu @@ -0,0 +1,34 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fast_hardamard_kernel.hpp" + +template void MoeFastHardamardWrapper( + const phi::dtype::bfloat16 *x_data, + const int64_t *expert_idx_per_token, + const int64_t *recv_expert_count, + const phi::dtype::bfloat16 *shift, + const phi::dtype::bfloat16 *smooth, + const float* quant_scales, + const int quant_round_type, + const float quant_max_bound, + const float quant_min_bound, + const int64_t token_num, + const int64_t dim, + const int num_max_tokens_per_expert, + bool used_in_ep_low_latency, + const int hadamard_block_size, + phi::dtype::bfloat16 *out, + cudaStream_t &stream +); \ No newline at end of file diff --git a/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_bf16_fp8.cu b/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_bf16_fp8.cu new file mode 100644 index 0000000000..bbc47d975b --- /dev/null +++ b/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_bf16_fp8.cu @@ -0,0 +1,34 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "fast_hardamard_kernel.hpp" + +template void MoeFastHardamardWrapper( + const phi::dtype::bfloat16 *x_data, + const int64_t *expert_idx_per_token, + const int64_t *recv_expert_count, + const phi::dtype::bfloat16 *shift, + const phi::dtype::bfloat16 *smooth, + const float* quant_scales, + const int quant_round_type, + const float quant_max_bound, + const float quant_min_bound, + const int64_t token_num, + const int64_t dim, + const int num_max_tokens_per_expert, + bool used_in_ep_low_latency, + const int hadamard_block_size, + phi::dtype::float8_e4m3fn *out, + cudaStream_t &stream +); diff --git a/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_bf16_int8.cu b/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_bf16_int8.cu new file mode 100644 index 0000000000..d821c34ba2 --- /dev/null +++ b/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_bf16_int8.cu @@ -0,0 +1,34 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fast_hardamard_kernel.hpp" + +template void MoeFastHardamardWrapper( + const phi::dtype::bfloat16 *x_data, + const int64_t *expert_idx_per_token, + const int64_t *recv_expert_count, + const phi::dtype::bfloat16 *shift, + const phi::dtype::bfloat16 *smooth, + const float* quant_scales, + const int quant_round_type, + const float quant_max_bound, + const float quant_min_bound, + const int64_t token_num, + const int64_t dim, + const int num_max_tokens_per_expert, + bool used_in_ep_low_latency, + const int hadamard_block_size, + int8_t *out, + cudaStream_t &stream +); \ No newline at end of file diff --git a/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_fp16_fp16.cu b/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_fp16_fp16.cu new file mode 100644 index 0000000000..ed4ab9df3d --- /dev/null +++ b/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_fp16_fp16.cu @@ -0,0 +1,34 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "fast_hardamard_kernel.hpp" + +template void MoeFastHardamardWrapper( + const phi::dtype::float16 *x_data, + const int64_t *expert_idx_per_token, + const int64_t *recv_expert_count, + const phi::dtype::float16 *shift, + const phi::dtype::float16 *smooth, + const float* quant_scales, + const int quant_round_type, + const float quant_max_bound, + const float quant_min_bound, + const int64_t token_num, + const int64_t dim, + const int num_max_tokens_per_expert, + bool used_in_ep_low_latency, + const int hadamard_block_size, + phi::dtype::float16 *out, + cudaStream_t &stream +); \ No newline at end of file diff --git a/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_fp16_int8.cu b/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_fp16_int8.cu new file mode 100644 index 0000000000..f5b934b606 --- /dev/null +++ b/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_fp16_int8.cu @@ -0,0 +1,34 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "fast_hardamard_kernel.hpp" + +template void MoeFastHardamardWrapper( + const phi::dtype::float16 *x_data, + const int64_t *expert_idx_per_token, + const int64_t *recv_expert_count, + const phi::dtype::float16 *shift, + const phi::dtype::float16 *smooth, + const float* quant_scales, + const int quant_round_type, + const float quant_max_bound, + const float quant_min_bound, + const int64_t token_num, + const int64_t dim, + const int num_max_tokens_per_expert, + bool used_in_ep_low_latency, + const int hadamard_block_size, + int8_t *out, + cudaStream_t &stream +); \ No newline at end of file diff --git a/custom_ops/gpu_ops/moe/fused_moe_op.h b/custom_ops/gpu_ops/moe/fused_moe_op.h index eeaecb716f..6b743aa79c 100644 --- a/custom_ops/gpu_ops/moe/fused_moe_op.h +++ b/custom_ops/gpu_ops/moe/fused_moe_op.h @@ -261,98 +261,6 @@ void quantize_moe_input( out); } -template -__global__ void masked_compute_row_sum_kernel( -const T* permuted_inputs, -const int64_t token_num, -const int64_t dim, -float* permuted_input_row_sum, -const int64_t* recv_expert_count, -const int num_max_tokens_per_expert) { -using LoadT = AlignedVector; -LoadT input_vec; -float scale_factor = -7.0f / 512.0f; -for (int token_idx = blockIdx.x; token_idx < token_num; token_idx += gridDim.x) { - const auto token_idx_in_expert = token_idx % num_max_tokens_per_expert; - const auto expert_id = token_idx / num_max_tokens_per_expert; - if (token_idx_in_expert >= recv_expert_count[expert_id]) { - auto next_expert_start_idx = (expert_id + 1) * num_max_tokens_per_expert; - auto num_iters_to_next_expert = (next_expert_start_idx - token_idx - 1) / gridDim.x; - token_idx += num_iters_to_next_expert * gridDim.x; - continue; - } - float thread_row_sum = 0.0f; - for(int idx = threadIdx.x; idx < dim / VecSize; idx += blockDim.x) { - int64_t offset = token_idx * dim + idx * VecSize; - Load(&permuted_inputs[offset], &input_vec); - #pragma unroll - for (int i = 0; i < VecSize; i++) { - 
thread_row_sum += static_cast(input_vec[i]); - } - } - float block_row_sum = BlockAllReduce(thread_row_sum); - permuted_input_row_sum[token_idx] = block_row_sum * scale_factor; - } -} - -template -__global__ void compute_row_sum_kernel( -const T* permuted_inputs, -const int64_t token_num, -const int64_t dim, -float* permuted_input_row_sum, -const int64_t* recv_expert_count, -const int num_max_tokens_per_expert) { -using LoadT = AlignedVector; -LoadT input_vec; -float scale_factor = -7.0f / 512.0f; -for (int token_idx = blockIdx.x; token_idx < token_num; token_idx += gridDim.x) { - float thread_row_sum = 0.0f; - for(int idx = threadIdx.x; idx < dim / VecSize; idx += blockDim.x) { - int64_t offset = token_idx * dim + idx * VecSize; - Load(&permuted_inputs[offset], &input_vec); - #pragma unroll - for (int i = 0; i < VecSize; i++) { - thread_row_sum += static_cast(input_vec[i]); - } - } - float block_row_sum = BlockAllReduce(thread_row_sum); - permuted_input_row_sum[token_idx] = block_row_sum * scale_factor; - } -} - -template -void compute_row_sum( - const T* permuted_inputs, - const int64_t token_num, - const int64_t dim, - float* permuted_input_row_sum, - const int64_t* recv_expert_count, - const int num_max_tokens_per_expert, - bool used_in_ep_low_latency, - cudaStream_t stream) { - constexpr int VecSize = 16 / sizeof(T); - constexpr int threads_per_block = 128; - const int dev_id = 0; - int sm_count; - int act_blocks_per_sm; - cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id); - assert(dim % VecSize == 0); - auto kernel = used_in_ep_low_latency ? masked_compute_row_sum_kernel : compute_row_sum_kernel; - cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &act_blocks_per_sm, kernel, threads_per_block, 0); - const int num_blocks_per_wave = sm_count * act_blocks_per_sm; - dim3 grid; - grid.x = min(static_cast(num_blocks_per_wave), token_num); - kernel<<>>( - permuted_inputs, - token_num, - dim, - permuted_input_row_sum, - recv_expert_count, - num_max_tokens_per_expert); - } - // ====================== Softmax things =============================== // We have our own implementation of softmax here so we can support transposing // the output in the softmax kernel when we extend this module to support diff --git a/custom_ops/gpu_ops/moe/moe_expert_ffn_wint2.cu b/custom_ops/gpu_ops/moe/moe_expert_ffn_wint2.cu index f3e51bfcfa..47bcba3087 100644 --- a/custom_ops/gpu_ops/moe/moe_expert_ffn_wint2.cu +++ b/custom_ops/gpu_ops/moe/moe_expert_ffn_wint2.cu @@ -17,7 +17,7 @@ #include "cutlass/numeric_conversion.h" #include "group_swiglu_with_masked.h" #include "helper.h" -#include "moe/fast_hardamard_kernel.h" +#include "moe/fast_hardmard/fast_hardamard_kernel.h" #include "moe/fused_moe_helper.h" template diff --git a/custom_ops/gpu_ops/moe/moe_ffn.cu b/custom_ops/gpu_ops/moe/moe_ffn.cu index c135903778..92af569971 100644 --- a/custom_ops/gpu_ops/moe/moe_ffn.cu +++ b/custom_ops/gpu_ops/moe/moe_ffn.cu @@ -18,7 +18,7 @@ #include "cutlass_kernels/w4a8_moe/w4a8_moe_gemm_kernel.h" #include "group_swiglu_with_masked.h" #include "helper.h" -#include "moe/fast_hardamard_kernel.h" +#include "moe/fast_hardmard/fast_hardamard_kernel.h" #include "moe/fused_moe_helper.h" #include "w4afp8_gemm/w4afp8_gemm.h" @@ -179,27 +179,11 @@ void MoeFFNKernel(const paddle::Tensor& permute_input, typedef typename traits_fp8::DataType DataType_fp8; typedef typename traits_fp8::data_t data_t_fp8; - Allocator::AllocationPtr ffn1_input_row_sum; - ffn1_input_row_sum = allocator->Allocate( - sizeof(float) * 
expanded_active_expert_rows); - - compute_row_sum( - permute_input.data(), - expanded_active_expert_rows, - hidden_size, - reinterpret_cast(ffn1_input_row_sum->ptr()), - const_cast(tokens_expert_prefix_sum.data()), - num_max_tokens_per_expert, - used_in_ep_low_latency, - stream); - - float* row_scale = nullptr; DisPatchW4AFp8GemmWrapper( reinterpret_cast(permute_input.data()), reinterpret_cast(up_gate_proj_weight.data()), const_cast(tokens_expert_prefix_sum.data()), - reinterpret_cast(ffn1_input_row_sum->ptr()), row_scale, const_cast(up_gate_proj_scale.get_ptr()) ->data(), @@ -323,18 +307,14 @@ void MoeFFNKernel(const paddle::Tensor& permute_input, Allocator::AllocationPtr fp8_act_out; fp8_act_out = allocator->Allocate( SizeOf(paddle::DataType::INT8) * act_out_tensor.numel()); - Allocator::AllocationPtr ffn2_input_row_sum; - ffn2_input_row_sum = allocator->Allocate( - sizeof(float) * expanded_active_expert_rows); - // note(yuanxiaolan): optimize this - MoeFastHardamardWrapper( + MoeFastHardamardWrapper( act_out_tensor.data(), expert_idx_per_token ? expert_idx_per_token.get().data() : nullptr, const_cast(tokens_expert_prefix_sum.data()), - ffn2_shift, // ffn2_shift->data(), - ffn2_smooth, // ffn2_smooth->data(), - nullptr, + ffn2_shift, + ffn2_smooth, + down_proj_in_scale ? const_cast(down_proj_in_scale.get_ptr())->data() : nullptr, 1, 448.0f, -448.0f, @@ -343,30 +323,14 @@ void MoeFFNKernel(const paddle::Tensor& permute_input, num_max_tokens_per_expert, used_in_ep_low_latency, hadamard_block_size, - act_out_tensor.data(), - stream - ); - - quantize_moe_input(act_out_tensor.data(), - expert_idx_per_token ? expert_idx_per_token.get().data() : nullptr, - down_proj_in_scale ? const_cast(down_proj_in_scale.get_ptr())->data() : nullptr, - 448.0f, - -448.0f, - expanded_active_expert_rows, - inter_size / 2, - reinterpret_cast(ffn2_input_row_sum->ptr()), - const_cast(tokens_expert_prefix_sum.data()), - num_max_tokens_per_expert, - used_in_ep_low_latency, reinterpret_cast(fp8_act_out->ptr()), stream - ); + ); DisPatchW4AFp8GemmWrapper( reinterpret_cast(fp8_act_out->ptr()), reinterpret_cast(down_proj_weight.data()), const_cast(tokens_expert_prefix_sum.data()), - reinterpret_cast(ffn2_input_row_sum->ptr()), row_scale, const_cast(down_proj_scale.get_ptr()) ->data(), diff --git a/custom_ops/gpu_ops/w4afp8_gemm/kernel_traits.h b/custom_ops/gpu_ops/w4afp8_gemm/kernel_traits.h index 71e37a8ba3..4478a7aead 100644 --- a/custom_ops/gpu_ops/w4afp8_gemm/kernel_traits.h +++ b/custom_ops/gpu_ops/w4afp8_gemm/kernel_traits.h @@ -23,34 +23,35 @@ using namespace cute; -template +template struct SharedStorage { union { struct { cute::array_aligned> smem_a; cute::array_aligned> smem_b; + cute::array_aligned> smem_scale; }; cute::array_aligned> smem_c; }; - - struct { + + struct { typename cutlass::PipelineTmaAsync::SharedStorage pipeline; }; }; template struct Kernel_traits { using Element = elem_type; - using ElementAccum = float; using ElementOutput = OutputType; + using ElementAccum = typename std::conditional_t; static_assert(cutlass::sizeof_bits_v == 8); static constexpr int kNWarps = kNWarps_; @@ -66,10 +67,10 @@ struct Kernel_traits { static constexpr int kTiles = kTiles_; static constexpr int TokenPackSize = TokenPackSize_; static constexpr int M = M_; - static constexpr int TAIL_N = TAIL_N_; + static constexpr int K = K_; + static constexpr int WeightScaleGroup = WeightScaleGroup_; using TileShape_MNK = Shape, Int, Int>; - using TileShape_MNK_TAIL = Shape, Int, Int>; static constexpr int kClusterM = 
kClusterM_; using ClusterShape_MNK = Shape, _1, _1>; @@ -77,16 +78,12 @@ struct Kernel_traits { static constexpr int kStages = kStages_; static_assert(kStages > 1); - using AtomLayoutMNK = Layout, _1, _1>>; + using AtomLayoutMNK = Layout, _1, _1>>; using TiledMma = decltype(cute::make_tiled_mma( cute::GMMA::rs_op_selector(), AtomLayoutMNK{})); - using TiledMma_TAIL = decltype(cute::make_tiled_mma( - cute::GMMA::rs_op_selector(), - AtomLayoutMNK{})); - using SmemLayoutAtomA = decltype( cutlass::gemm::collective::detail::rs_smem_selector< GMMA::Major::K, Element, Int, Int>()); @@ -97,30 +94,16 @@ struct Kernel_traits { using SmemLayoutAtomB = decltype( cutlass::gemm::collective::detail::rs_smem_selector< - GMMA::Major::K, Element, decltype(cute::get<1>(TileShape_MNK{})), + GMMA::Major::K, Element, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); using SmemLayoutB = decltype( tile_to_shape(SmemLayoutAtomB{}, make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}))); - - using SmemLayoutAtomB_TAIL = decltype( - cutlass::gemm::collective::detail::rs_smem_selector< - GMMA::Major::K, Element, decltype(cute::get<1>(TileShape_MNK_TAIL{})), - decltype(cute::get<2>(TileShape_MNK_TAIL{}))>()); - - using SmemLayoutB_TAIL = decltype( - tile_to_shape(SmemLayoutAtomB_TAIL{}, - make_shape( - shape<1>(TileShape_MNK_TAIL{}), - shape<2>(TileShape_MNK_TAIL{}), - Int{}) - )); - using SmemLayoutAtomC = decltype( cutlass::gemm::collective::detail::rs_smem_selector< GMMA::Major::K, ElementOutput, - decltype(cute::get<0>(TileShape_MNK{})), + decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>()); using SmemLayoutC = decltype(tile_to_shape(SmemLayoutAtomC{}, select<0, 1>(TileShape_MNK{}))); @@ -128,8 +111,10 @@ struct Kernel_traits { using SmemCopyAtomAB = Copy_Atom; using SmemCopyAtomC = Copy_Atom; + using SmemLayoutScale = Layout, Int>>; + using SharedStorage = SharedStorage< - kStages, Element, ElementOutput, SmemLayoutA, SmemLayoutB, SmemLayoutC>; + kStages, Element, ElementOutput, float, SmemLayoutA, SmemLayoutB, SmemLayoutC, SmemLayoutScale>; using MainloopPipeline = typename cutlass::PipelineTmaAsync; using PipelineState = typename cutlass::PipelineState; @@ -151,4 +136,4 @@ struct Kernel_traits { TiledCopyCThrLayout{}, // Thr layout TiledCopyCValLayout{} // Val layout )); -}; +}; \ No newline at end of file diff --git a/custom_ops/gpu_ops/w4afp8_gemm/mainloop_fwd.h b/custom_ops/gpu_ops/w4afp8_gemm/mainloop_fwd.h index cb46397d51..14cd7366b4 100644 --- a/custom_ops/gpu_ops/w4afp8_gemm/mainloop_fwd.h +++ b/custom_ops/gpu_ops/w4afp8_gemm/mainloop_fwd.h @@ -35,19 +35,19 @@ struct CollectiveMainloopFwd { using Element = typename Ktraits::Element; using ElementOutput = typename Ktraits::ElementOutput; using TileShape_MNK = typename Ktraits::TileShape_MNK; - using TileShape_MNK_TAIL = typename Ktraits::TileShape_MNK_TAIL; using ClusterShape = typename Ktraits::ClusterShape_MNK; using ElementAccum = typename Ktraits::ElementAccum; static constexpr int kStages = Ktraits::kStages; static constexpr int kBlockM = Ktraits::kBlockM; static constexpr int kBlockN = Ktraits::kBlockN; - static constexpr int TAIL_N = Ktraits::TAIL_N; static constexpr int kBlockK = Ktraits::kBlockK; - static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup; + static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup; static constexpr int kTiles = Ktraits::kTiles; static constexpr int M = Ktraits::M; + static constexpr int K = Ktraits::K; 
static constexpr int TokenPackSize = Ktraits::TokenPackSize; + static constexpr int WeightScaleGroup = Ktraits::WeightScaleGroup; using GmemTiledCopy = cute::SM90_TMA_LOAD; @@ -55,12 +55,16 @@ struct CollectiveMainloopFwd { using SmemLayoutA = typename Ktraits::SmemLayoutA; using SmemLayoutB = typename Ktraits::SmemLayoutB; using SmemLayoutC = typename Ktraits::SmemLayoutC; - using SmemLayoutB_TAIL = typename Ktraits::SmemLayoutB_TAIL; + using SmemLayoutScale = typename Ktraits::SmemLayoutScale; using ShapeT = cute::Shape; using StrideT = cute::Shape; using LayoutT = cute::Layout; + using ShapeTScale = cute::Shape; + using StrideTScale = cute::Shape<_1, int64_t, int64_t>; + using LayoutTScale = cute::Layout; + using TMA_A = decltype(make_tma_copy( GmemTiledCopy{}, make_tensor( @@ -83,6 +87,17 @@ struct CollectiveMainloopFwd { select<1, 2>(TileShape_MNK{}), size<0>(ClusterShape{}))); + using TMA_Scale = decltype(make_tma_copy( + GmemTiledCopy{}, + make_tensor( + make_gmem_ptr(static_cast(nullptr)), + ShapeTScale{}, + StrideTScale{} + ), + SmemLayoutScale{}(_, _0{}), + select<0>(Shape>{}), + size<0>(ClusterShape{}))); + static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma{}); using MainloopPipeline = typename Ktraits::MainloopPipeline; using PipelineParams = typename MainloopPipeline::Params; @@ -93,6 +108,7 @@ struct CollectiveMainloopFwd { static constexpr uint32_t TmaTransactionBytesA = static_cast(size(take<0, 2>(SmemLayoutA{})) * cutlass::sizeof_bits_v / 8); static constexpr uint32_t TmaTransactionBytesB = static_cast(size(take<0, 2>(SmemLayoutB{})) * cutlass::sizeof_bits_v / 8); + static constexpr uint32_t TmaTransactionBytesScale = static_cast(size(SmemLayoutScale{}(_, _0{})) * cutlass::sizeof_bits_v / 8); struct Arguments { Element const* ptr_A; @@ -102,18 +118,19 @@ struct CollectiveMainloopFwd { ElementOutput * ptr_C; LayoutT layout_C; const float *weight_scale; - const float *input_row_sum; + LayoutTScale layout_Scale; const int64_t * tokens; }; struct Params { LayoutT layout_A; LayoutT layout_B; - TMA_A tma_load_A; + LayoutTScale layout_Scale; + TMA_A tma_load_A; TMA_B tma_load_B; + TMA_Scale tma_load_Scale; ElementOutput * ptr_C; const float *weight_scale; - const float *input_row_sum; const int64_t * tokens; }; @@ -134,24 +151,35 @@ struct CollectiveMainloopFwd { SmemLayoutB{}(_, _, _0{}), select<1, 2>(TileShape_MNK{}), size<0>(ClusterShape{})); - - return {args.layout_A, args.layout_B, tma_load_A, tma_load_B, - args.ptr_C, args.weight_scale, args.input_row_sum, args.tokens}; + Tensor mScale = make_tensor(make_gmem_ptr(args.weight_scale), args.layout_Scale); + TMA_Scale tma_load_Scale = make_tma_copy( + GmemTiledCopy{}, + mScale, + SmemLayoutScale{}(_, _0{}), + select<0>(Shape>{}), + size<0>(ClusterShape{})); + + return { + args.layout_A, args.layout_B, args.layout_Scale, + tma_load_A, tma_load_B, tma_load_Scale, + args.ptr_C, args.weight_scale, args.tokens}; } CUTLASS_DEVICE static void prefetch_tma_descriptors(Params const& mainloop_params) { cute::prefetch_tma_descriptor(mainloop_params.tma_load_A.get_tma_descriptor()); cute::prefetch_tma_descriptor(mainloop_params.tma_load_B.get_tma_descriptor()); + if constexpr (WeightScaleGroup < K) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_Scale.get_tma_descriptor()); + } } - template + template CUTLASS_DEVICE void store(Params const& mainloop_params, FrgTensorO & tOrO, SharedStorage& shared_storage, TiledMma tiled_mma, - const float *input_row_sum, const float *weight_scale, const int64_t tokens, const 
int64_t pre_fix_tokens, @@ -159,28 +187,35 @@ struct CollectiveMainloopFwd { const int bidn, const int bidb, const int tidx) { - + using packHalf = typename PackedHalf::Type; Tensor tOrO_out = make_tensor(tOrO.layout()); - #pragma unroll - for (int i = 0; i < size(tOrO); i+=4) { - const int sum_idx = i * 2; - tOrO[i] = (tOrO[i] + input_row_sum[sum_idx]) * weight_scale[0]; - tOrO[i + 1] = (tOrO[i + 1] + input_row_sum[sum_idx + 1]) * weight_scale[0]; - tOrO[i + 2] = (tOrO[i + 2] + input_row_sum[sum_idx]) * weight_scale[1]; - tOrO[i + 3] = (tOrO[i + 3] + input_row_sum[sum_idx + 1]) * weight_scale[1]; - *reinterpret_cast(&tOrO_out[i]) = packHalf(tOrO[i], tOrO[i + 2]); - *reinterpret_cast(&tOrO_out[i + 2]) = packHalf(tOrO[i + 1], tOrO[i + 3]); + if constexpr (WeightScaleGroup == K) { + #pragma unroll + for (int i = 0; i < size(tOrO); i+=4) { + tOrO[i] = (tOrO[i]) * weight_scale[0]; + tOrO[i + 1] = tOrO[i + 1] * weight_scale[0]; + tOrO[i + 2] = tOrO[i + 2] * weight_scale[1]; + tOrO[i + 3] = tOrO[i + 3] * weight_scale[1]; + *reinterpret_cast(&tOrO_out[i]) = packHalf(tOrO[i], tOrO[i + 2]); + *reinterpret_cast(&tOrO_out[i + 2]) = packHalf(tOrO[i + 1], tOrO[i + 3]); + } + } else { + #pragma unroll + for (int i = 0; i < size(tOrO); i+=4) { + *reinterpret_cast(&tOrO_out[i]) = packHalf(float(tOrO[i]), float(tOrO[i + 2])); + *reinterpret_cast(&tOrO_out[i + 2]) = packHalf(float(tOrO[i + 1]), float(tOrO[i + 3])); + } } uint16_t *smem_c = reinterpret_cast(shared_storage.smem_c.data()); uint32_t * reg_data = reinterpret_cast(tOrO_out.data()); - + cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0); - constexpr int k_copy_times = CUR_N / 16; + constexpr int k_copy_times = kBlockN / 16; #pragma unroll for (int i = 0; i < k_copy_times; i++) { @@ -193,16 +228,16 @@ struct CollectiveMainloopFwd { } cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0); - const int batch_idx = TokenPackSize == 0 ? pre_fix_tokens * M : bidb * M * TokenPackSize; - ElementOutput * store_c = mainloop_params.ptr_C + batch_idx + bidn * (M * kBlockN) + bidm * kBlockM; - + const int expert_idx = TokenPackSize == 0 ? 
pre_fix_tokens * M : bidb * M * TokenPackSize; + ElementOutput * store_c = mainloop_params.ptr_C + expert_idx + bidn * (M * kBlockN) + bidm * kBlockM; + const int reamin_tokens = tokens - bidn * kBlockN; const int col = tidx % 2; constexpr int kPackSize = 16 / sizeof(ElementOutput); constexpr int kNumVecElem = kBlockM / kPackSize; - constexpr int copy_len = CUR_N * kNumVecElem; + constexpr int copy_len = kBlockN * kNumVecElem; #pragma unroll for (int idx = tidx; idx < copy_len; idx += NumMmaThreads) { const int idx_div2 = idx / 2; @@ -246,16 +281,17 @@ struct CollectiveMainloopFwd { Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem_a.data()), SmemLayoutA{}); Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem_b.data()), SmemLayoutB{}); + Tensor sScale = make_tensor(make_smem_ptr(shared_storage.smem_scale.data()), SmemLayoutScale{}); Tensor mA = mainloop_params.tma_load_A.get_tma_tensor(mainloop_params.layout_A.shape()); Tensor mB = mainloop_params.tma_load_B.get_tma_tensor(mainloop_params.layout_B.shape()); + Tensor mScale = mainloop_params.tma_load_Scale.get_tma_tensor(mainloop_params.layout_Scale.shape()); Tensor gA = local_tile(mA(_, _, bidb), select<0, 1>(Shape, Int>{}), make_coord(bidm, _)); + Tensor gScale = local_tile(mScale(_, bidm, bidb), select<0>(Shape>{}), make_coord(_)); auto [tAgA, tAsA] = tma_partition(mainloop_params.tma_load_A, _0{}, Layout{}, group_modes<0, 2>(sA), group_modes<0, 2>(gA)); - const int kIters = kTiles / kStages; - if constexpr (TokenPackSize == 0) { Tensor gB = get_local_no_packed_tensor( mB, @@ -267,90 +303,64 @@ struct CollectiveMainloopFwd { if (tidx == 0) { #pragma unroll - for (int kiter = 0; kiter < kIters; ++kiter) { - #pragma unroll - for (int s = 0; s < kStages; s++) { - const int i = kiter * kStages + s; - pipeline.producer_acquire(smem_pipe_write); - copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0), - tAgA(_, i), tAsA(_, smem_pipe_write.index())); - - copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0), - tBgB(_, i), tBsB(_, smem_pipe_write.index())); - ++smem_pipe_write; - } - } - - #pragma unroll - for (int i = kIters * kStages; i < kTiles; ++i) { + for (int kiter = 0; kiter < kTiles; ++kiter) { pipeline.producer_acquire(smem_pipe_write); copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0), - tAgA(_, i), tAsA(_, smem_pipe_write.index())); - + tAgA(_, kiter), tAsA(_, smem_pipe_write.index())); + copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0), - tBgB(_, i), tBsB(_, smem_pipe_write.index())); - ++smem_pipe_write; + tBgB(_, kiter), tBsB(_, smem_pipe_write.index())); + + if constexpr (WeightScaleGroup < K) { + copy(mainloop_params.tma_load_Scale.with(*pipeline.producer_get_barrier(smem_pipe_write), 0), + gScale(_, kiter), sScale(_, smem_pipe_write.index())); + } + + ++smem_pipe_write; } } } else { - auto mB_this_batch = make_tensor( - mB(_, _, bidb).data(), + auto mB_this_expert = make_tensor( + mB(_, _, bidb).data(), make_layout( cute::make_shape(tokens, size<1>(mB)), mB.stride() )); - Tensor gB = local_tile(mB_this_batch, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _)); + Tensor gB = local_tile(mB_this_expert, select<1, 2>(TileShape_MNK{}), make_coord(bidn, _)); auto [tBgB, tBsB] = tma_partition(mainloop_params.tma_load_B, _0{}, Layout{}, group_modes<0, 2>(sB), group_modes<0, 2>(gB)); if (tidx == 0) { #pragma unroll - for (int kiter = 0; kiter < kIters; ++kiter) { - #pragma 
unroll - for (int s = 0; s < kStages; s++) { - const int i = kiter * kStages + s; - pipeline.producer_acquire(smem_pipe_write); - copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0), - tAgA(_, i), tAsA(_, smem_pipe_write.index())); - - copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0), - tBgB(_, i), tBsB(_, smem_pipe_write.index())); - ++smem_pipe_write; - } - } - - #pragma unroll - for (int i = kIters * kStages; i < kTiles; ++i) { + for (int kiter = 0; kiter < kTiles; ++kiter) { pipeline.producer_acquire(smem_pipe_write); copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0), - tAgA(_, i), tAsA(_, smem_pipe_write.index())); - + tAgA(_, kiter), tAsA(_, smem_pipe_write.index())); + copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0), - tBgB(_, i), tBsB(_, smem_pipe_write.index())); - ++smem_pipe_write; + tBgB(_, kiter), tBsB(_, smem_pipe_write.index())); + + if constexpr (WeightScaleGroup < K) { + copy(mainloop_params.tma_load_Scale.with(*pipeline.producer_get_barrier(smem_pipe_write), 0), + gScale(_, kiter), sScale(_, smem_pipe_write.index())); + } + + ++smem_pipe_write; } } } } - template + template CUTLASS_DEVICE void mma(Params const& mainloop_params, TiledMma tiled_mma, MainloopPipeline pipeline, - PipelineState& smem_pipe_read, + PipelineState& smem_pipe_read, SharedStorage& shared_storage, FrgTensorO &tSrS, const int tidx) { - - using sMemBLayout = std::conditional_t< - CUR_N == kBlockN, - SmemLayoutB, - SmemLayoutB_TAIL - >; - Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem_a.data()), SmemLayoutA{}); - Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem_b.data()), sMemBLayout{}); - + Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem_b.data()), SmemLayoutB{}); tiled_mma.accumulate_ = GMMA::ScaleOut::One; auto threadMma = tiled_mma.get_thread_slice(tidx); @@ -358,37 +368,105 @@ struct CollectiveMainloopFwd { auto smem_tiled_copy_A = make_tiled_copy_A(SmemCopyAtomAB{}, tiled_mma); auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(tidx); - Tensor tSrA = threadMma.partition_fragment_A(sA(_, _, 0)); + Tensor tSrA = threadMma.partition_fragment_A(sA(_, _, 0)); Tensor tSrB = threadMma.partition_fragment_B(sB); auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) { auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); pipeline.consumer_wait(smem_pipe_read, barrier_token); }; + #pragma unroll + for (int kiter = 0; kiter < kTiles; ++kiter) { + Tensor tSsA = smem_thr_copy_A.partition_S(sA(_, _, smem_pipe_read.index())); + consumer_wait(pipeline, smem_pipe_read); + gemm(tiled_mma, tSrA, tSsA, tSrB(_, _, _, smem_pipe_read.index()), tSrS, smem_tiled_copy_A, smem_thr_copy_A); + pipeline.consumer_release(smem_pipe_read); + ++smem_pipe_read; + } + + } + + template + CUTLASS_DEVICE void + mma_pipeline(Params const& mainloop_params, + TiledMma tiled_mma, + MainloopPipeline pipeline, + PipelineState& smem_pipe_read, + SharedStorage& shared_storage, + FrgTensorO &tSrS, + const int tidx) { + + Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem_a.data()), SmemLayoutA{}); + Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem_b.data()), SmemLayoutB{}); + float2 *weight_scale = reinterpret_cast(shared_storage.smem_scale.data()) + tidx / 4; + + Tensor tSrS1 = make_fragment_like(tSrS); + Tensor tSrS2 = make_fragment_like(tSrS); - const int kIters = kTiles / kStages; + __half2 * tSrS_data = 
reinterpret_cast<__half2*>(raw_pointer_cast(tSrS.data())); + __half2 * tSrS1_data = reinterpret_cast<__half2*>(raw_pointer_cast(tSrS1.data())); + __half2 * tSrS2_data = reinterpret_cast<__half2*>(raw_pointer_cast(tSrS2.data())); - constexpr int B_STEPS = CUR_N == 0 ? 1 : (kBlockN / CUR_N); + auto threadMma = tiled_mma.get_thread_slice(tidx); + auto smem_tiled_copy_A = make_tiled_copy_A(SmemCopyAtomAB{}, tiled_mma); + auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(tidx); + + Tensor tSrA = threadMma.partition_fragment_A(sA(_, _, 0)); + Tensor tSrB = threadMma.partition_fragment_B(sB); + + auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) { + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + }; + + __half2 scale1, scale2, scale3, scale4; + float2 scale_cur_k; #pragma unroll - for (int kiter = 0; kiter < kIters; ++kiter) { - #pragma unroll - for (int s = 0; s < kStages; s++) { - Tensor tSsA = smem_thr_copy_A.partition_S(sA(_, _, s)); + for (int kiter = 0; kiter < kTiles;) { + Tensor tSsA1 = smem_thr_copy_A.partition_S(sA(_, _, smem_pipe_read.index())); + consumer_wait(pipeline, smem_pipe_read); + scale_cur_k = *(weight_scale + smem_pipe_read.index() * (kBlockM / 2)); + scale1 = __half2(scale_cur_k.x, scale_cur_k.x); + scale2 = __half2(scale_cur_k.y, scale_cur_k.y); + + gemm(tiled_mma, tSrA, tSsA1, tSrB(_, _, _, smem_pipe_read.index()), tSrS1, smem_tiled_copy_A, smem_thr_copy_A); + pipeline.consumer_release(smem_pipe_read); + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + if (kiter > 0) { + for (int i = 0; i < size(tSrS) / 2; i+=2) { + tSrS_data[i] = __hfma2(tSrS2_data[i], scale3, tSrS_data[i]); + tSrS_data[i+1] = __hfma2(tSrS2_data[i + 1], scale4, tSrS_data[i+1]); + } + } + + ++smem_pipe_read; + ++kiter; + + if (kiter < kTiles) { + Tensor tSsA2 = smem_thr_copy_A.partition_S(sA(_, _, smem_pipe_read.index())); consumer_wait(pipeline, smem_pipe_read); - gemm(tiled_mma, tSrA, tSsA, tSrB(_, _, _, s * B_STEPS), tSrS, smem_tiled_copy_A, smem_thr_copy_A); + scale_cur_k = *(weight_scale + smem_pipe_read.index() * (kBlockM / 2)); + scale3 = __half2(scale_cur_k.x, scale_cur_k.x); + scale4 = __half2(scale_cur_k.y, scale_cur_k.y); + + gemm(tiled_mma, tSrA, tSsA2, tSrB(_, _, _, smem_pipe_read.index()), tSrS2, smem_tiled_copy_A, smem_thr_copy_A); pipeline.consumer_release(smem_pipe_read); ++smem_pipe_read; + ++kiter; + } + + for (int i = 0; i < size(tSrS) / 2; i+=2) { + tSrS_data[i] = __hfma2(tSrS1_data[i], scale1, tSrS_data[i]); + tSrS_data[i+1] = __hfma2(tSrS1_data[i + 1], scale2, tSrS_data[i+1]); } + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; } - #pragma unroll - for (int i = 0; i < kTiles % kStages; ++i) { - Tensor tSsA = smem_thr_copy_A.partition_S(sA(_, _, i)); - consumer_wait(pipeline, smem_pipe_read); - gemm(tiled_mma, tSrA, tSsA, tSrB(_, _, _, i * B_STEPS), tSrS, smem_tiled_copy_A, smem_thr_copy_A); - pipeline.consumer_release(smem_pipe_read); - ++smem_pipe_read; + for (int i = 0; i < size(tSrS) / 2; i+=2) { + tSrS_data[i] = __hfma2(tSrS2_data[i], scale3, tSrS_data[i]); + tSrS_data[i+1] = __hfma2(tSrS2_data[i + 1], scale4, tSrS_data[i+1]); } } -}; +}; \ No newline at end of file diff --git a/custom_ops/gpu_ops/w4afp8_gemm/utils.hpp b/custom_ops/gpu_ops/w4afp8_gemm/utils.hpp index 2c0f685fe7..cb1e9e9bbb 100644 --- a/custom_ops/gpu_ops/w4afp8_gemm/utils.hpp +++ b/custom_ops/gpu_ops/w4afp8_gemm/utils.hpp @@ -62,8 +62,12 @@ template __forceinline__ __device__ void convert_c4_2_fp8(const int32_t * src, 
int32_t * dst1, int32_t * dst2) { #pragma unroll for (int i = 0; i < numel; ++i) { - dst1[i] = (src[i] >> 4) & 0x0f0f0f0f; - dst2[i] = src[i] & 0x0f0f0f0f; + uint32_t head1 = src[i] & 0x80808080; + dst1[i] = (src[i] >> 4) & 0x07070707; + dst1[i] = dst1[i] | head1; + uint32_t head2 = (src[i] & 0x08080808) << 4; + dst2[i] = src[i] & 0x07070707; + dst2[i] = dst2[i] | head2; } } diff --git a/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu b/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu index 53685c5c97..c5dd03c1dd 100644 --- a/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu +++ b/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu @@ -22,23 +22,7 @@ #include "w4afp8_gemm.h" -void weight_convert(const uint8_t *weight, uint8_t *weight_new, int batch, int M, int K) { - assert(K % 64 == 0); - for (int b = 0; b < batch; ++b) { - for (int m = 0; m < M; ++m) { - for (int k = 0; k < K; k+=64) { - for (int k_inner = 0; k_inner < 32; ++k_inner) { - uint8_t temp = 0; - uint8_t left = weight[b * M * K + m * K + k + k_inner]; - uint8_t right = weight[b * M * K + m * K + k + k_inner + 32]; - temp |= left << 4; - temp |= right; - weight_new[b * M * K / 2 + m * K / 2 + k / 2 + k_inner] = *reinterpret_cast(&temp); - } - } - } - } -} + template class NVTraits; @@ -65,26 +49,23 @@ void DisPatchW4AFp8Gemm( const cutlass::float_e4m3_t* input, const cutlass::float_e4m3_t* weight, const int64_t * tokens, - const float * input_row_sum, const float * weight_scale, OutputType * out, const int64_t token_padding_size, const int64_t max_tokens, - const int batch_size, + const int Experts, const int64_t M, const int64_t K, cudaStream_t stream) { int kBlockN = 256; - int TailN = 0; if constexpr (std::is_same_v) { GEMM_SWITCH_BF16( - M, K, batch_size, token_padding_size, kBlockN, TailN, + M, K, Experts, token_padding_size, kBlockN, K, weight, input, out, weight_scale, - input_row_sum, tokens, max_tokens, stream) @@ -97,14 +78,13 @@ std::vector W4AFp8Gemm( const paddle::Tensor& input, const paddle::Tensor& weight, const paddle::Tensor& tokens, // If tokenpadding=0, this tensor represents the prefix sum of tensors, otherwise it represents the number of tokens in each group - const paddle::Tensor& input_row_sum, const paddle::Tensor& weight_scale, const int64_t token_padding_size, const int64_t max_tokens, const bool is_bfloat16) { - const int batch_size = weight.dims()[0]; + const int Experts = weight.dims()[0]; const int M = weight.dims()[1]; const int K = weight.dims()[2] * 2; @@ -121,12 +101,11 @@ std::vector W4AFp8Gemm( reinterpret_cast(input.data()), reinterpret_cast(weight.data()), tokens.data(), - input_row_sum.data(), weight_scale.data(), reinterpret_cast(out_data), token_padding_size, max_tokens, - batch_size, + Experts, M, K, input.stream()); @@ -136,18 +115,17 @@ std::vector W4AFp8Gemm( } } else { if (is_bfloat16) { - paddle::Tensor out = paddle::empty({batch_size, token_padding_size, M}, paddle::DataType::BFLOAT16, input.place()); + paddle::Tensor out = paddle::empty({Experts, token_padding_size, M}, paddle::DataType::BFLOAT16, input.place()); phi::dtype::bfloat16 * out_data = out.data(); DisPatchW4AFp8Gemm( reinterpret_cast(input.data()), reinterpret_cast(weight.data()), tokens.data(), - input_row_sum.data(), weight_scale.data(), reinterpret_cast(out_data), token_padding_size, max_tokens, - batch_size, + Experts, M, K, input.stream()); @@ -163,7 +141,6 @@ void DisPatchW4AFp8GemmWrapper( const InputType* input, const InputType* weight, const int64_t* total_rows_before_expert, - const float* input_row_sum, const float* 
row_scale, const float* weight_scale, OutputType * out, @@ -179,7 +156,6 @@ void DisPatchW4AFp8GemmWrapper( reinterpret_cast(input), reinterpret_cast(weight), total_rows_before_expert, - input_row_sum, weight_scale, reinterpret_cast(out), token_padding_size, @@ -191,14 +167,7 @@ void DisPatchW4AFp8GemmWrapper( } -std::vector W4AFp8GemmWeightConvert(const paddle::Tensor& weight) { - const int batch_size = weight.dims()[0]; - const int M = weight.dims()[1]; - const int K = weight.dims()[2]; - paddle::Tensor weight_new = paddle::empty({batch_size, M, K / 2}, paddle::DataType::UINT8, weight.place()); - weight_convert(weight.data(), weight_new.data(), batch_size, M, K); - return {weight_new}; -} + template __global__ void permute_scale_kernel( @@ -261,7 +230,6 @@ PD_BUILD_STATIC_OP(w4afp8_gemm) .Inputs({"input", "weight", "tokens", - "input_row_sum", "weight_scale"}) .Outputs({"out"}) .Attrs({"token_padding_size: int64_t", @@ -269,16 +237,12 @@ PD_BUILD_STATIC_OP(w4afp8_gemm) "is_bfloat16: bool"}) .SetKernelFn(PD_KERNEL(W4AFp8Gemm)); -PD_BUILD_STATIC_OP(w4afp8_gemm_weight_convert) - .Inputs({"weight"}) - .Outputs({"converted_weight"}) - .SetKernelFn(PD_KERNEL(W4AFp8GemmWeightConvert)); + template void DisPatchW4AFp8GemmWrapper<__nv_fp8_e4m3, __nv_bfloat16>( const __nv_fp8_e4m3* input, const __nv_fp8_e4m3* weight, const int64_t * tokens, - const float * input_row_sum, const float * row_scale, const float * weight_scale, __nv_bfloat16 * out, @@ -294,7 +258,6 @@ template void DisPatchW4AFp8GemmWrapper<__nv_fp8_e4m3, half>( const __nv_fp8_e4m3* input, const __nv_fp8_e4m3* weight, const int64_t * tokens, - const float * input_row_sum, const float * row_scale, const float * weight_scale, half * out, diff --git a/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.h b/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.h index c2474d419b..16dd286452 100644 --- a/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.h +++ b/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.h @@ -24,7 +24,6 @@ std::vector W4AFp8Gemm( const paddle::Tensor& input, const paddle::Tensor& weight, const paddle::Tensor& tokens, // If tokenpadding=0, this tensor represents the prefix sum of tensors, otherwise it represents the number of tokens in each group - const paddle::Tensor& input_row_sum, const paddle::Tensor& weight_scale, const int64_t token_padding_size, const int64_t max_tokens, @@ -35,7 +34,6 @@ void DisPatchW4AFp8GemmWrapper( const InputType* input, const InputType* weight, const int64_t * tokens, - const float * input_row_sum, const float * row_scale, const float * weight_scale, OutputType * out, diff --git a/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_kernel.hpp b/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_kernel.hpp index 01a8dd114c..400f068db4 100644 --- a/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_kernel.hpp +++ b/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_kernel.hpp @@ -34,7 +34,6 @@ void __global__ __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp static_assert(cutlass::sizeof_bits_v == 8); using TileShape_MNK = typename Ktraits::TileShape_MNK; - using TileShape_MNK_TAIL = typename Ktraits::TileShape_MNK_TAIL; using ClusterShape = typename Ktraits::ClusterShape_MNK; static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma{}); @@ -42,8 +41,9 @@ void __global__ __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp static constexpr int kBlockN = Ktraits::kBlockN; static constexpr int kBlockM = Ktraits::kBlockM; static constexpr int M = Ktraits::M; + static constexpr int K = Ktraits::K; static constexpr int 
TokenPackSize = Ktraits::TokenPackSize; - static constexpr int TAIL_N = Ktraits::TAIL_N; + static constexpr int WeightScaleGroup = Ktraits::WeightScaleGroup; using CollectiveMainloop = CollectiveMainloopFwd; @@ -66,9 +66,13 @@ void __global__ __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp // Obtain warp index int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup; - + PipelineParams pipeline_params; - pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesA + CollectiveMainloop::TmaTransactionBytesB; + if constexpr (WeightScaleGroup == K) { + pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesA + CollectiveMainloop::TmaTransactionBytesB; + } else { + pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesA + CollectiveMainloop::TmaTransactionBytesB + CollectiveMainloop::TmaTransactionBytesScale; + } int warp_group_idx = cutlass::canonical_warp_group_idx(); pipeline_params.role = warp_group_idx == 0 ? MainloopPipeline::ThreadCategory::Producer @@ -96,9 +100,6 @@ void __global__ __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp return; } - float* input_row_sum = reinterpret_cast( - shared_memory + sizeof(typename Ktraits::SharedStorage)); - if (warp_group_idx == 0) { cutlass::arch::warpgroup_reg_dealloc(); PipelineState smem_pipe_write = cutlass::make_producer_start_state(); @@ -119,95 +120,81 @@ void __global__ __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp typename Ktraits::TiledMma tiled_mma; - typename Ktraits::TiledMma_TAIL tiled_mma_tail; - const int mma_tidx = tidx - NumCopyThreads; - const int lane_id = mma_tidx % 4 * 2; - const float2 weight_scale = reinterpret_cast(mainloop_params.weight_scale + bidb * M + bidm * kBlockM)[mma_tidx / 4]; + float2 weight_scale; - if constexpr (TokenPackSize == 0) { - const int input_sum_idx = pre_fix_tokens + bidn * kBlockN; - if (mma_tidx < kBlockN) { - reinterpret_cast(input_row_sum)[mma_tidx] = reinterpret_cast(mainloop_params.input_row_sum + input_sum_idx)[mma_tidx]; - } - } else { - const int input_sum_idx = bidb * TokenPackSize + bidn * kBlockN; - if (mma_tidx < kBlockN / 4) { - reinterpret_cast(input_row_sum)[mma_tidx] = reinterpret_cast(mainloop_params.input_row_sum + input_sum_idx)[mma_tidx]; - } + if constexpr (WeightScaleGroup == K) { + weight_scale = reinterpret_cast(mainloop_params.weight_scale + bidb * M + bidm * kBlockM)[mma_tidx / 4]; } + Tensor tSrS = partition_fragment_C(tiled_mma, select<0, 1>(TileShape_MNK{})); - const int reamin_tokens = tokens - bidn * kBlockN; - - if (TAIL_N > 0 && reamin_tokens < kBlockN) { - Tensor tSrS_tail = partition_fragment_C(tiled_mma_tail, select<0, 1>(TileShape_MNK_TAIL{})); - collective_mainloop.mma( + if constexpr (WeightScaleGroup == K) { + collective_mainloop.mma( mainloop_params, - tiled_mma_tail, - pipeline, + tiled_mma, + pipeline, smem_pipe_read, shared_storage, - tSrS_tail, - mma_tidx); - collective_mainloop.store( - mainloop_params, - tSrS_tail, - shared_storage, - tiled_mma_tail, - input_row_sum + lane_id, - reinterpret_cast(&weight_scale), - tokens, - pre_fix_tokens, - bidm, - bidn, - bidb, + tSrS, mma_tidx); } else { - Tensor tSrS = partition_fragment_C(tiled_mma, select<0, 1>(TileShape_MNK{})); - collective_mainloop.mma( + collective_mainloop.mma_pipeline( mainloop_params, tiled_mma, - pipeline, + pipeline, smem_pipe_read, shared_storage, tSrS, mma_tidx); - collective_mainloop.store( - mainloop_params, - tSrS, - shared_storage, - tiled_mma, - 
input_row_sum + lane_id, - reinterpret_cast(&weight_scale), - tokens, - pre_fix_tokens, - bidm, - bidn, - bidb, - mma_tidx); } + + + collective_mainloop.store( + mainloop_params, + tSrS, + shared_storage, + tiled_mma, + reinterpret_cast(&weight_scale), + tokens, + pre_fix_tokens, + bidm, + bidn, + bidb, + mma_tidx); } } -template +template auto get_gmem_layout(const int Rows, const int Cols) { return make_layout( make_shape( static_cast(Rows), static_cast(Cols), - static_cast(Batch)), + static_cast(Experts)), make_stride( static_cast(Cols), cute::_1{}, static_cast(Rows * Cols))); } +template +auto get_scale_layout(const int Rows, const int Cols) { + return make_layout( + make_shape( + static_cast(Cols), + static_cast(Rows), + static_cast(Experts)), + make_stride( + cute::_1{}, + static_cast(Cols), + static_cast(Rows * Cols))); +} + -template -void run_gemm(const InputType * A, const InputType * B, OutputType * C, const float *weight_scale, - const float *input_row_sum, const int64_t * tokens, const int64_t max_tokens, cudaStream_t stream) { +template +void run_gemm(const InputType * A, const InputType * B, OutputType * C, const float *weight_scale, const int64_t * tokens, const int max_tokens, cudaStream_t stream) { using ElementOutput = typename Kernel_traits::ElementOutput; using Element = typename Kernel_traits::Element; @@ -216,24 +203,27 @@ void run_gemm(const InputType * A, const InputType * B, OutputType * C, const fl constexpr int M_nums = (M + Kernel_traits::kBlockM - 1) / Kernel_traits::kBlockM; const int N_nums = (max_tokens + Kernel_traits::kBlockN - 1) / Kernel_traits::kBlockN; + constexpr int K_scale_nums = K / Kernel_traits::kBlockM; + static_assert(K % WeightScaleGroup == 0); + static_assert(WeightScaleGroup == 128 || WeightScaleGroup == K); typename CollectiveMainloop::Params mainloop_params = CollectiveMainloop::to_underlying_arguments({ static_cast(A), - get_gmem_layout(M, K / 2), + get_gmem_layout(M, K / 2), static_cast(B), - get_gmem_layout(TokenPackSize == 0 ? max_tokens: TokenPackSize, K), + get_gmem_layout(TokenPackSize == 0 ? max_tokens: TokenPackSize, K), static_cast(C), - get_gmem_layout(M, TokenPackSize == 0 ? max_tokens : TokenPackSize), + get_gmem_layout(M, TokenPackSize == 0 ? max_tokens : TokenPackSize), weight_scale, - input_row_sum, + get_scale_layout(M_nums, K_scale_nums * Kernel_traits::kBlockM), tokens }); void *kernel; kernel = (void *)w4afp8_gemm_kernel; - int smem_size = sizeof(typename Kernel_traits::SharedStorage) + sizeof(float) * Kernel_traits::kBlockN; + int smem_size = sizeof(typename Kernel_traits::SharedStorage); if (smem_size >= 48 * 1024) { cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); @@ -242,7 +232,7 @@ void run_gemm(const InputType * A, const InputType * B, OutputType * C, const fl dim3 grid_dims; grid_dims.x = M_nums; grid_dims.y = N_nums; - grid_dims.z = Batch; + grid_dims.z = Experts; static constexpr int ctaSize = Kernel_traits::kNWarps * 32; dim3 block_dims(ctaSize); dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{})); diff --git a/custom_ops/gpu_ops/w4afp8_gemm/weight.cu b/custom_ops/gpu_ops/w4afp8_gemm/weight.cu new file mode 100644 index 0000000000..b8191eeadd --- /dev/null +++ b/custom_ops/gpu_ops/w4afp8_gemm/weight.cu @@ -0,0 +1,140 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + +#include "helper.h" +#include "paddle/extension.h" + +void weight_convert(const uint8_t *weight, uint8_t *weight_new, int experts, int M, int K) { + assert(K % 64 == 0); + for (int b = 0; b < experts; ++b) { + for (int m = 0; m < M; ++m) { + for (int k = 0; k < K; k+=64) { + for (int k_inner = 0; k_inner < 32; ++k_inner) { + uint8_t temp = 0; + uint8_t left = weight[b * M * K + m * K + k + k_inner]; + uint8_t right = weight[b * M * K + m * K + k + k_inner + 32]; + temp |= left << 4; + temp |= right; + weight_new[b * M * K / 2 + m * K / 2 + k / 2 + k_inner] = *reinterpret_cast(&temp); + } + } + } + } +} + +__global__ void weight_permute_interleave_kernelw4afp8( + const int8_t* input_data, + int8_t* output_data, + const int original_k, + const int original_n) { + + const int numel = original_k * original_n / 4; + const int pack_group_size = 64; + const int thread_group_size = pack_group_size / 4; // 16 + const int thread_k_stride = original_k / 4; + + const int linear_idx = blockIdx.x * blockDim.x + threadIdx.x; + + if (linear_idx >= numel) return; + + const int n_id = linear_idx / thread_k_stride; + const int k_id = linear_idx % thread_k_stride; + const int k_group_idx = k_id / thread_group_size; + const int k_idx_in_group = k_id % thread_group_size; + + const int8_t* src = input_data + + k_group_idx * pack_group_size / 2 * original_n + + k_idx_in_group * original_n + n_id; + + int8_t tmp0 = src[0]; + int8_t tmp1 = src[pack_group_size / 4 * original_n]; + + int8_t tmp00 = (tmp0 & 0xF0) + 112; + int8_t tmp01 = ((tmp0 << 4) & 0xF0) + 112; + int8_t tmp10 = (tmp1 & 0xF0) + 112; + int8_t tmp11 = ((tmp1 << 4) & 0xF0) + 112; + + uint8_t utmp00 = *(reinterpret_cast(&tmp00)); + uint8_t utmp01 = *(reinterpret_cast(&tmp01)); + uint8_t utmp10 = *(reinterpret_cast(&tmp10)); + uint8_t utmp11 = *(reinterpret_cast(&tmp11)); + + utmp00 = (utmp00 & 0xF0) >> 4; + utmp01 = (utmp01 & 0xF0) >> 4; + utmp10 = (utmp10 & 0xF0) >> 4; + utmp11 = (utmp11 & 0xF0) >> 4; + + tmp00 = *(reinterpret_cast(&utmp00)) - 7; + tmp01 = *(reinterpret_cast(&utmp01)) - 7; + tmp10 = *(reinterpret_cast(&utmp10)) - 7; + tmp11 = *(reinterpret_cast(&utmp11)) - 7; + + if (tmp00 <= 0) { + tmp00 = 8 - tmp00; + } + if (tmp01 <= 0) { + tmp01 = 8 - tmp01; + } + if (tmp10 <= 0) { + tmp10 = 8 - tmp10; + } + if (tmp11 <= 0) { + tmp11 = 8 - tmp11; + } + + int8_t dst0 = (tmp01 << 4) | tmp11; + int8_t dst1 = (tmp00 << 4) | tmp10; + + int8_t* dst = output_data + n_id * original_k / 2 + (k_group_idx * pack_group_size / 2) + k_idx_in_group * 2; + dst[0] = dst0; + dst[1] = dst1; +} + +std::vector W4AFp8GemmWeightPermute(const paddle::Tensor& weight) { + const int original_k = weight.dims()[0] * 2; + const int original_n = weight.dims()[1]; + paddle::Tensor weight_new = paddle::empty(weight.shape(), paddle::DataType::INT8, weight.place()); + const int block_dim = 256; + const int original_numel = original_k * original_n; + const int grid_size = (original_numel + block_dim - 1) / block_dim; + + weight_permute_interleave_kernelw4afp8<<>>( 
+ weight.data(), weight_new.data(), original_k, original_n); + + return {weight_new}; +} + +std::vector W4AFp8GemmWeightConvert(const paddle::Tensor& weight) { + const int experts = weight.dims()[0]; + const int M = weight.dims()[1]; + const int K = weight.dims()[2]; + paddle::Tensor weight_new = paddle::empty({experts, M, K / 2}, paddle::DataType::UINT8, weight.place()); + weight_convert(weight.data(), weight_new.data(), experts, M, K); + return {weight_new}; +} + +PD_BUILD_STATIC_OP(w4afp8_gemm_weight_permute) + .Inputs({"weight"}) + .Outputs({"converted_weight"}) + .SetKernelFn(PD_KERNEL(W4AFp8GemmWeightPermute)); + +PD_BUILD_STATIC_OP(w4afp8_gemm_weight_convert) + .Inputs({"weight"}) + .Outputs({"converted_weight"}) + .SetKernelFn(PD_KERNEL(W4AFp8GemmWeightConvert)); + diff --git a/custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py b/custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py index 1acf3c80ae..edca4a740f 100644 --- a/custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py +++ b/custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os +import re file_dir = "./gpu_ops/w4afp8_gemm/" @@ -30,12 +32,11 @@ #include """ gemm_template_case = """ -void w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}( +void w4afp8_gemm_M{M}_N{N}_G{GROUPSIZE}_K{K}_E{EXPERTS}_P{PADDING}_{TYPE}( const cutlass::float_e4m3_t * weight, const cutlass::float_e4m3_t * input, {cutlass_type} * out, const float *weight_scale, - const float *input_row_sum, const int64_t *tokens, const int64_t max_tokens, cudaStream_t stream); @@ -48,22 +49,21 @@ """ gemm_template_cu_template = """ -void w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}( +void w4afp8_gemm_M{M}_N{N}_G{GROUPSIZE}_K{K}_E{EXPERTS}_P{PADDING}_{TYPE}( const cutlass::float_e4m3_t * weight, const cutlass::float_e4m3_t * input, {cutlass_type} * out, const float *weight_scale, - const float *input_row_sum, const int64_t *tokens, const int64_t max_tokens, cudaStream_t stream) {{ constexpr static int M = {M}; constexpr static int K = {K}; - constexpr static int Batch = {BATCH}; + constexpr static int EXPERTS = {EXPERTS}; constexpr static int TokenPackSize = {PADDING}; constexpr static int kBlockN = {N}; - constexpr static int kBlockN_TAIL = {TAILN}; + constexpr static int kGroupSize = {GROUPSIZE}; constexpr static int kBlockM = 128; constexpr static int kBlockK = 128; constexpr static int kNWarps = 4 + kBlockM / 16; @@ -74,22 +74,24 @@ using Kernel_traits = Kernel_traits< kBlockM, kBlockN, kBlockK, kNWarps, kStages, kTiles, - M, TokenPackSize, kBlockN_TAIL, kCluster, cutlass::float_e4m3_t, + M, K, TokenPackSize, kGroupSize, kCluster, cutlass::float_e4m3_t, {cutlass_type}>; run_gemm - (weight, input, out, weight_scale, - input_row_sum, tokens, max_tokens, stream); + Kernel_traits, M, K, EXPERTS, TokenPackSize, kGroupSize> + (weight, input, out, weight_scale, tokens, max_tokens, stream); }} """ +# [M, K, Number of experts, token Padding Size, weight K group size] gemm_case = [ - [8192, 3584, 8, 0], # eb45T ffn1 - [8192, 3584, 8, 2048], # eb45T ffn1 - [7168, 8192, 8, 0], # eb45T ffn2 - [7168, 8192, 8, 2048], # eb45T ffn2 - [1792, 8192, 64, 0], # eb45t ffn1 - [8192, 896, 64, 0], # eb45t ffn2 + [8192, 3584, 8, 0, 3584], # eb45T ffn1 + [8192, 3584, 8, 2048, 3584], # eb45T ffn1 + [7168, 8192, 8, 0, 8192], # eb45T ffn2 + [7168, 8192, 8, 2048, 8192], # 
eb45T ffn2 + [1792, 8192, 64, 0, 8192], # eb45t ffn1 + [8192, 896, 64, 0, 896], # eb45t ffn2 + [1792, 8192, 64, 0, 128], # eb45t ffn1 + [8192, 896, 64, 0, 128], # eb45t ffn2 ] dtype = ["BF16"] @@ -97,6 +99,19 @@ use_fast_compile = True n_range = [256] if use_fast_compile else [i for i in range(16, 257, 16)] +all_cu_files = [] +for type in dtype: + for case in gemm_case: + for n in n_range: + all_cu_files.append(f"w4afp8_gemm_M{case[0]}_N{n}_G{case[4]}_K{case[1]}_E{case[2]}_P{case[3]}_{type}.cu") + +for file_path, empty_list, file_name_list in os.walk(file_dir): + for file_name in file_name_list: + if re.match(r'^w4afp8_gemm_M\d+_N\d+_.*\.cu$', file_name): + if file_name not in all_cu_files: + print("delete w4afp8 kernel file", file_path + file_name) + os.remove(file_path + file_name) + def get_cutlass_type(type): if type == "BF16": @@ -116,28 +131,16 @@ def get_cutlass_type(type): M=case[0], K=case[1], N=n, - BATCH=case[2], + EXPERTS=case[2], TYPE=type, PADDING=case[3], - TAILN=0, - cutlass_type=get_cutlass_type(type), - ) - ) - template_head_file.write( - gemm_template_case.format( - M=case[0], - K=case[1], - N=256, - BATCH=case[2], - TYPE=type, - PADDING=case[3], - TAILN=n - 16, + GROUPSIZE=case[4], cutlass_type=get_cutlass_type(type), ) ) template_cu_file = open( - f"{file_dir}w4afp8_gemm_M{case[0]}_N{n}_TAILN{0}_K{case[1]}_B{case[2]}_P{case[3]}_{type}.cu", "w" + f"{file_dir}w4afp8_gemm_M{case[0]}_N{n}_G{case[4]}_K{case[1]}_E{case[2]}_P{case[3]}_{type}.cu", "w" ) template_cu_file.write(gemm_template_cu_head) template_cu_file.write( @@ -145,29 +148,10 @@ def get_cutlass_type(type): M=case[0], K=case[1], N=n, - BATCH=case[2], + EXPERTS=case[2], TYPE=type, PADDING=case[3], - TAILN=0, - cutlass_type=get_cutlass_type(type), - ) - ) - - template_cu_file.close() - - template_cu_file = open( - f"{file_dir}w4afp8_gemm_M{case[0]}_N{256}_TAILN{n-16}_K{case[1]}_B{case[2]}_P{case[3]}_{type}.cu", "w" - ) - template_cu_file.write(gemm_template_cu_head) - template_cu_file.write( - gemm_template_cu_template.format( - M=case[0], - K=case[1], - N=256, - BATCH=case[2], - TYPE=type, - PADDING=case[3], - TAILN=n - 16, + GROUPSIZE=case[4], cutlass_type=get_cutlass_type(type), ) ) @@ -177,8 +161,8 @@ def get_cutlass_type(type): for type in dtype: template_head_file.write("\n") template_head_file.write( - """#define GEMM_SWITCH_{TYPE}(_M, _K, _BATCH, _TokenPaddingSize, _kBlockN, _TailN, ...) {{ \\ - if (_M == 0 && _K == 0 && _BATCH == 0 && _TokenPaddingSize == 0 && _kBlockN == 0 && _TailN == 0) {{ \\""".format( + """#define GEMM_SWITCH_{TYPE}(_M, _K, _EXPERTS, _TokenPaddingSize, _kBlockN, _GROUPSIZE, ...) 
{{ \\ + if (_M == 0 && _K == 0 && _EXPERTS == 0 && _TokenPaddingSize == 0 && _kBlockN == 0 && _GROUPSIZE == 0) {{ \\""".format( TYPE=type ) ) @@ -188,23 +172,16 @@ def get_cutlass_type(type): for case in gemm_case: for n in n_range: template_head_file.write( - """ }} else if (_M == {M} && _K == {K} && _BATCH == {BATCH} && _TokenPaddingSize == {PADDING} && _kBlockN == {N} && _TailN == {TAILN}) {{ \\ - w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(__VA_ARGS__); \\""".format( - M=case[0], K=case[1], N=n, BATCH=case[2], TYPE=type, PADDING=case[3], TAILN=0 - ) - ) - template_head_file.write("\n") - template_head_file.write( - """ }} else if (_M == {M} && _K == {K} && _BATCH == {BATCH} && _TokenPaddingSize == {PADDING} && _kBlockN == {N} && _TailN == {TAILN}) {{ \\ - w4afp8_gemm_M{M}_N{N}_TAILN{TAILN}_K{K}_B{BATCH}_P{PADDING}_{TYPE}(__VA_ARGS__); \\""".format( - M=case[0], K=case[1], N=256, BATCH=case[2], TYPE=type, PADDING=case[3], TAILN=n - 16 + """ }} else if (_M == {M} && _K == {K} && _EXPERTS == {EXPERTS} && _TokenPaddingSize == {PADDING} && _kBlockN == {N} && _GROUPSIZE == {GROUPSIZE}) {{ \\ + w4afp8_gemm_M{M}_N{N}_G{GROUPSIZE}_K{K}_E{EXPERTS}_P{PADDING}_{TYPE}(__VA_ARGS__); \\""".format( + M=case[0], K=case[1], N=n, EXPERTS=case[2], TYPE=type, PADDING=case[3], GROUPSIZE=case[4] ) ) template_head_file.write("\n") template_head_file.write( """ } else { \\ - PADDLE_THROW(phi::errors::Unimplemented("W4aFp8 not supported m=%d k=%d batch=%d token_padding_size=%d kBlockN=%d tailN=%d\\n", _M, _K, _BATCH, _TokenPaddingSize, _kBlockN, _TailN)); \\ + PADDLE_THROW(phi::errors::Unimplemented("W4aFp8 not supported m=%d k=%d experts=%d token_padding_size=%d kBlockN=%d groupsize=%d\\n", _M, _K, _EXPERTS, _TokenPaddingSize, _kBlockN, _GROUPSIZE)); \\ } \\ }""" ) diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py index eaa46448c4..6acdc93aa9 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py @@ -30,7 +30,7 @@ from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch, moe_expert_reduce try: - from fastdeploy.model_executor.ops.gpu import w4afp8_gemm_scale_permute + from fastdeploy.model_executor.ops.gpu import w4afp8_gemm_scale_permute, w4afp8_gemm_weight_permute except: logger.warning("import w4afp8_gemm_scale_permute Failed!") elif current_platform.is_iluvatar(): @@ -785,7 +785,7 @@ def process_loaded_weights(self, layer: nn.Layer, state_dict): weight_name = self.added_weight_attrs[idx] weight_list = [] for i in range(layer.num_local_experts): - quant_weight, scale = weight_quantize(weight_tensor[i], algo=self.moe_quant_type, arch=80) + quant_weight = w4afp8_gemm_weight_permute(weight_tensor[i]) weight_list.append(quant_weight) quanted_weight = paddle.stack(weight_list, axis=0) getattr(layer, weight_name).set_value(quanted_weight) From a9a4a892a5ce4974d80f365f975a1a45519d841c Mon Sep 17 00:00:00 2001 From: yangjianfengo1 Date: Fri, 26 Sep 2025 09:42:41 +0800 Subject: [PATCH 2/7] code style --- .../fast_hardmard/fast_hardamard_kernel.hpp | 2 +- .../fast_hardamard_kernel_bf16_bf16.cu | 2 +- .../fast_hardamard_kernel_bf16_int8.cu | 2 +- .../fast_hardamard_kernel_fp16_fp16.cu | 2 +- .../fast_hardamard_kernel_fp16_int8.cu | 2 +- .../gpu_ops/w4afp8_gemm/kernel_traits.h | 20 +++--- custom_ops/gpu_ops/w4afp8_gemm/mainloop_fwd.h | 48 +++++++-------- 
custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu | 6 +- .../w4afp8_gemm/w4afp8_gemm_kernel.hpp | 18 +++--- .../{weight.cu => weight_kernel.hpp} | 61 +++++++------------ .../utils/auto_gen_w4afp8_gemm_kernel.py | 2 +- .../layers/moe/fused_moe_cutlass_backend.py | 5 +- 12 files changed, 80 insertions(+), 90 deletions(-) rename custom_ops/gpu_ops/w4afp8_gemm/{weight.cu => weight_kernel.hpp} (70%) diff --git a/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel.hpp b/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel.hpp index 6d35654d8a..02906fa9de 100644 --- a/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel.hpp +++ b/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel.hpp @@ -973,4 +973,4 @@ void MoeFastHardamardWrapper(const T *x_data, }); } } -} \ No newline at end of file +} diff --git a/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_bf16_bf16.cu b/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_bf16_bf16.cu index 64eb02e0be..21800dc636 100644 --- a/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_bf16_bf16.cu +++ b/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_bf16_bf16.cu @@ -31,4 +31,4 @@ template void MoeFastHardamardWrapper( const int hadamard_block_size, int8_t *out, cudaStream_t &stream -); \ No newline at end of file +); diff --git a/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_fp16_fp16.cu b/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_fp16_fp16.cu index ed4ab9df3d..e61bf44378 100644 --- a/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_fp16_fp16.cu +++ b/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_fp16_fp16.cu @@ -31,4 +31,4 @@ template void MoeFastHardamardWrapper( const int hadamard_block_size, phi::dtype::float16 *out, cudaStream_t &stream -); \ No newline at end of file +); diff --git a/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_fp16_int8.cu b/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_fp16_int8.cu index f5b934b606..e4edb32b5f 100644 --- a/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_fp16_int8.cu +++ b/custom_ops/gpu_ops/moe/fast_hardmard/fast_hardamard_kernel_fp16_int8.cu @@ -31,4 +31,4 @@ template void MoeFastHardamardWrapper( const int hadamard_block_size, int8_t *out, cudaStream_t &stream -); \ No newline at end of file +); diff --git a/custom_ops/gpu_ops/w4afp8_gemm/kernel_traits.h b/custom_ops/gpu_ops/w4afp8_gemm/kernel_traits.h index 4478a7aead..48025e8962 100644 --- a/custom_ops/gpu_ops/w4afp8_gemm/kernel_traits.h +++ b/custom_ops/gpu_ops/w4afp8_gemm/kernel_traits.h @@ -23,19 +23,19 @@ using namespace cute; -template struct SharedStorage { union { struct { cute::array_aligned> smem_a; cute::array_aligned> smem_b; - cute::array_aligned> smem_scale; + cute::array_aligned> smem_scale; }; cute::array_aligned> smem_c; }; - - struct { + + struct { typename cutlass::PipelineTmaAsync::SharedStorage pipeline; }; }; @@ -45,7 +45,7 @@ template struct Kernel_traits { @@ -78,7 +78,7 @@ struct Kernel_traits { static constexpr int kStages = kStages_; static_assert(kStages > 1); - using AtomLayoutMNK = Layout, _1, _1>>; + using AtomLayoutMNK = Layout, _1, _1>>; using TiledMma = decltype(cute::make_tiled_mma( cute::GMMA::rs_op_selector(), @@ -94,7 +94,7 @@ struct Kernel_traits { using SmemLayoutAtomB = decltype( cutlass::gemm::collective::detail::rs_smem_selector< - GMMA::Major::K, Element, decltype(cute::get<1>(TileShape_MNK{})), + GMMA::Major::K, Element, decltype(cute::get<1>(TileShape_MNK{})), 
decltype(cute::get<2>(TileShape_MNK{}))>()); using SmemLayoutB = decltype( @@ -103,7 +103,7 @@ struct Kernel_traits { using SmemLayoutAtomC = decltype( cutlass::gemm::collective::detail::rs_smem_selector< GMMA::Major::K, ElementOutput, - decltype(cute::get<0>(TileShape_MNK{})), + decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>()); using SmemLayoutC = decltype(tile_to_shape(SmemLayoutAtomC{}, select<0, 1>(TileShape_MNK{}))); @@ -114,7 +114,7 @@ struct Kernel_traits { using SmemLayoutScale = Layout, Int>>; using SharedStorage = SharedStorage< - kStages, Element, ElementOutput, float, SmemLayoutA, SmemLayoutB, SmemLayoutC, SmemLayoutScale>; + kStages, Element, ElementOutput, SmemLayoutA, SmemLayoutB, SmemLayoutC, SmemLayoutScale>; using MainloopPipeline = typename cutlass::PipelineTmaAsync; using PipelineState = typename cutlass::PipelineState; @@ -136,4 +136,4 @@ struct Kernel_traits { TiledCopyCThrLayout{}, // Thr layout TiledCopyCValLayout{} // Val layout )); -}; \ No newline at end of file +}; diff --git a/custom_ops/gpu_ops/w4afp8_gemm/mainloop_fwd.h b/custom_ops/gpu_ops/w4afp8_gemm/mainloop_fwd.h index 14cd7366b4..2050bf862f 100644 --- a/custom_ops/gpu_ops/w4afp8_gemm/mainloop_fwd.h +++ b/custom_ops/gpu_ops/w4afp8_gemm/mainloop_fwd.h @@ -42,7 +42,7 @@ struct CollectiveMainloopFwd { static constexpr int kBlockM = Ktraits::kBlockM; static constexpr int kBlockN = Ktraits::kBlockN; static constexpr int kBlockK = Ktraits::kBlockK; - static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup; + static constexpr int NumCopyThreads = cutlass::NumThreadsPerWarpGroup; static constexpr int kTiles = Ktraits::kTiles; static constexpr int M = Ktraits::M; static constexpr int K = Ktraits::K; @@ -90,8 +90,8 @@ struct CollectiveMainloopFwd { using TMA_Scale = decltype(make_tma_copy( GmemTiledCopy{}, make_tensor( - make_gmem_ptr(static_cast(nullptr)), - ShapeTScale{}, + make_gmem_ptr(static_cast(nullptr)), + ShapeTScale{}, StrideTScale{} ), SmemLayoutScale{}(_, _0{}), @@ -126,7 +126,7 @@ struct CollectiveMainloopFwd { LayoutT layout_A; LayoutT layout_B; LayoutTScale layout_Scale; - TMA_A tma_load_A; + TMA_A tma_load_A; TMA_B tma_load_B; TMA_Scale tma_load_Scale; ElementOutput * ptr_C; @@ -158,9 +158,9 @@ struct CollectiveMainloopFwd { SmemLayoutScale{}(_, _0{}), select<0>(Shape>{}), size<0>(ClusterShape{})); - + return { - args.layout_A, args.layout_B, args.layout_Scale, + args.layout_A, args.layout_B, args.layout_Scale, tma_load_A, tma_load_B, tma_load_Scale, args.ptr_C, args.weight_scale, args.tokens}; } @@ -187,7 +187,7 @@ struct CollectiveMainloopFwd { const int bidn, const int bidb, const int tidx) { - + using packHalf = typename PackedHalf::Type; Tensor tOrO_out = make_tensor(tOrO.layout()); @@ -212,7 +212,7 @@ struct CollectiveMainloopFwd { uint16_t *smem_c = reinterpret_cast(shared_storage.smem_c.data()); uint32_t * reg_data = reinterpret_cast(tOrO_out.data()); - + cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0); constexpr int k_copy_times = kBlockN / 16; @@ -230,7 +230,7 @@ struct CollectiveMainloopFwd { cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0); const int expert_idx = TokenPackSize == 0 ? 
pre_fix_tokens * M : bidb * M * TokenPackSize; ElementOutput * store_c = mainloop_params.ptr_C + expert_idx + bidn * (M * kBlockN) + bidm * kBlockM; - + const int reamin_tokens = tokens - bidn * kBlockN; const int col = tidx % 2; @@ -307,7 +307,7 @@ struct CollectiveMainloopFwd { pipeline.producer_acquire(smem_pipe_write); copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0), tAgA(_, kiter), tAsA(_, smem_pipe_write.index())); - + copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0), tBgB(_, kiter), tBsB(_, smem_pipe_write.index())); @@ -315,13 +315,13 @@ struct CollectiveMainloopFwd { copy(mainloop_params.tma_load_Scale.with(*pipeline.producer_get_barrier(smem_pipe_write), 0), gScale(_, kiter), sScale(_, smem_pipe_write.index())); } - + ++smem_pipe_write; } } } else { auto mB_this_expert = make_tensor( - mB(_, _, bidb).data(), + mB(_, _, bidb).data(), make_layout( cute::make_shape(tokens, size<1>(mB)), mB.stride() @@ -335,7 +335,7 @@ struct CollectiveMainloopFwd { pipeline.producer_acquire(smem_pipe_write); copy(mainloop_params.tma_load_A.with(*pipeline.producer_get_barrier(smem_pipe_write), 0), tAgA(_, kiter), tAsA(_, smem_pipe_write.index())); - + copy(mainloop_params.tma_load_B.with(*pipeline.producer_get_barrier(smem_pipe_write), 0), tBgB(_, kiter), tBsB(_, smem_pipe_write.index())); @@ -343,7 +343,7 @@ struct CollectiveMainloopFwd { copy(mainloop_params.tma_load_Scale.with(*pipeline.producer_get_barrier(smem_pipe_write), 0), gScale(_, kiter), sScale(_, smem_pipe_write.index())); } - + ++smem_pipe_write; } } @@ -355,7 +355,7 @@ struct CollectiveMainloopFwd { mma(Params const& mainloop_params, TiledMma tiled_mma, MainloopPipeline pipeline, - PipelineState& smem_pipe_read, + PipelineState& smem_pipe_read, SharedStorage& shared_storage, FrgTensorO &tSrS, const int tidx) { @@ -368,7 +368,7 @@ struct CollectiveMainloopFwd { auto smem_tiled_copy_A = make_tiled_copy_A(SmemCopyAtomAB{}, tiled_mma); auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(tidx); - Tensor tSrA = threadMma.partition_fragment_A(sA(_, _, 0)); + Tensor tSrA = threadMma.partition_fragment_A(sA(_, _, 0)); Tensor tSrB = threadMma.partition_fragment_B(sB); auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) { @@ -383,7 +383,7 @@ struct CollectiveMainloopFwd { pipeline.consumer_release(smem_pipe_read); ++smem_pipe_read; } - + } template @@ -391,7 +391,7 @@ struct CollectiveMainloopFwd { mma_pipeline(Params const& mainloop_params, TiledMma tiled_mma, MainloopPipeline pipeline, - PipelineState& smem_pipe_read, + PipelineState& smem_pipe_read, SharedStorage& shared_storage, FrgTensorO &tSrS, const int tidx) { @@ -412,14 +412,14 @@ struct CollectiveMainloopFwd { auto smem_tiled_copy_A = make_tiled_copy_A(SmemCopyAtomAB{}, tiled_mma); auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(tidx); - Tensor tSrA = threadMma.partition_fragment_A(sA(_, _, 0)); + Tensor tSrA = threadMma.partition_fragment_A(sA(_, _, 0)); Tensor tSrB = threadMma.partition_fragment_B(sB); auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) { auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); pipeline.consumer_wait(smem_pipe_read, barrier_token); }; - + __half2 scale1, scale2, scale3, scale4; float2 scale_cur_k; #pragma unroll @@ -429,7 +429,7 @@ struct CollectiveMainloopFwd { scale_cur_k = *(weight_scale + smem_pipe_read.index() * (kBlockM / 2)); scale1 = __half2(scale_cur_k.x, scale_cur_k.x); scale2 = __half2(scale_cur_k.y, scale_cur_k.y); - 
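As an aside on the math here: the per-group scaling that mma_pipeline folds into the accumulator via __hfma2 is equivalent to the following scalar reference, given as a standalone sketch (kGroupSize, wq, scale and act are placeholder names, not the kernel's actual buffers). One float scale is applied per (output row, K group); when the group size equals K this degenerates to the single-scale case handled by the plain mma() path.

#include <cstdint>
#include <vector>

// Scalar reference for a group-wise dequantized GEMM:
// out[m][n] = sum over K groups g of scale[m][g] * sum_{k in g} wq[m][k] * act[k][n].
void groupwise_dequant_gemm(const std::vector<int8_t>& wq,    // [M*K] int4 codes in [-7, 7]
                            const std::vector<float>& scale,  // [M*(K/kGroupSize)] per-group scales
                            const std::vector<float>& act,    // [K*N] activations
                            std::vector<float>& out,          // [M*N]
                            int M, int N, int K, int kGroupSize) {
  const int groups = K / kGroupSize;
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.f;
      for (int g = 0; g < groups; ++g) {
        float partial = 0.f;  // one K group accumulated at the weight's quantized precision
        for (int k = g * kGroupSize; k < (g + 1) * kGroupSize; ++k) {
          partial += static_cast<float>(wq[m * K + k]) * act[k * N + n];
        }
        acc += partial * scale[m * groups + g];  // the group's scale is applied once per group
      }
      out[m * N + n] = acc;
    }
  }
}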
+ gemm(tiled_mma, tSrA, tSsA1, tSrB(_, _, _, smem_pipe_read.index()), tSrS1, smem_tiled_copy_A, smem_thr_copy_A); pipeline.consumer_release(smem_pipe_read); tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; @@ -440,7 +440,7 @@ struct CollectiveMainloopFwd { tSrS_data[i+1] = __hfma2(tSrS2_data[i + 1], scale4, tSrS_data[i+1]); } } - + ++smem_pipe_read; ++kiter; @@ -450,7 +450,7 @@ struct CollectiveMainloopFwd { scale_cur_k = *(weight_scale + smem_pipe_read.index() * (kBlockM / 2)); scale3 = __half2(scale_cur_k.x, scale_cur_k.x); scale4 = __half2(scale_cur_k.y, scale_cur_k.y); - + gemm(tiled_mma, tSrA, tSsA2, tSrB(_, _, _, smem_pipe_read.index()), tSrS2, smem_tiled_copy_A, smem_thr_copy_A); pipeline.consumer_release(smem_pipe_read); ++smem_pipe_read; @@ -469,4 +469,4 @@ struct CollectiveMainloopFwd { tSrS_data[i+1] = __hfma2(tSrS2_data[i + 1], scale4, tSrS_data[i+1]); } } -}; \ No newline at end of file +}; diff --git a/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu b/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu index c5dd03c1dd..c1e6630ebf 100644 --- a/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu +++ b/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu @@ -20,6 +20,7 @@ #include "paddle/extension.h" #include "w4afp8_gemm_template.h" #include "w4afp8_gemm.h" +#include "weight_kernel.hpp" @@ -237,7 +238,10 @@ PD_BUILD_STATIC_OP(w4afp8_gemm) "is_bfloat16: bool"}) .SetKernelFn(PD_KERNEL(W4AFp8Gemm)); - +PD_BUILD_STATIC_OP(w4afp8_gemm_weight_convert) + .Inputs({"weight"}) + .Outputs({"converted_weight"}) + .SetKernelFn(PD_KERNEL(W4AFp8GemmWeightConvert)); template void DisPatchW4AFp8GemmWrapper<__nv_fp8_e4m3, __nv_bfloat16>( const __nv_fp8_e4m3* input, diff --git a/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_kernel.hpp b/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_kernel.hpp index 400f068db4..1c7250f6e1 100644 --- a/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_kernel.hpp +++ b/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_kernel.hpp @@ -66,7 +66,7 @@ void __global__ __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp // Obtain warp index int const warp_group_thread_idx = threadIdx.x % cutlass::NumThreadsPerWarpGroup; - + PipelineParams pipeline_params; if constexpr (WeightScaleGroup == K) { pipeline_params.transaction_bytes = CollectiveMainloop::TmaTransactionBytesA + CollectiveMainloop::TmaTransactionBytesB; @@ -127,13 +127,13 @@ void __global__ __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp if constexpr (WeightScaleGroup == K) { weight_scale = reinterpret_cast(mainloop_params.weight_scale + bidb * M + bidm * kBlockM)[mma_tidx / 4]; } - Tensor tSrS = partition_fragment_C(tiled_mma, select<0, 1>(TileShape_MNK{})); + Tensor tSrS = partition_fragment_C(tiled_mma, select<0, 1>(TileShape_MNK{})); if constexpr (WeightScaleGroup == K) { collective_mainloop.mma( mainloop_params, tiled_mma, - pipeline, + pipeline, smem_pipe_read, shared_storage, tSrS, @@ -142,22 +142,22 @@ void __global__ __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp collective_mainloop.mma_pipeline( mainloop_params, tiled_mma, - pipeline, + pipeline, smem_pipe_read, shared_storage, tSrS, mma_tidx); } - + collective_mainloop.store( - mainloop_params, - tSrS, - shared_storage, + mainloop_params, + tSrS, + shared_storage, tiled_mma, reinterpret_cast(&weight_scale), tokens, - pre_fix_tokens, + pre_fix_tokens, bidm, bidn, bidb, diff --git a/custom_ops/gpu_ops/w4afp8_gemm/weight.cu b/custom_ops/gpu_ops/w4afp8_gemm/weight_kernel.hpp similarity index 70% rename from custom_ops/gpu_ops/w4afp8_gemm/weight.cu rename 
to custom_ops/gpu_ops/w4afp8_gemm/weight_kernel.hpp index b8191eeadd..f155515ed8 100644 --- a/custom_ops/gpu_ops/w4afp8_gemm/weight.cu +++ b/custom_ops/gpu_ops/w4afp8_gemm/weight_kernel.hpp @@ -11,11 +11,6 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. - -#ifndef PD_BUILD_STATIC_OP -#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) -#endif - #include "helper.h" #include "paddle/extension.h" @@ -38,9 +33,9 @@ void weight_convert(const uint8_t *weight, uint8_t *weight_new, int experts, int } __global__ void weight_permute_interleave_kernelw4afp8( - const int8_t* input_data, - int8_t* output_data, - const int original_k, + const int8_t* input_data, + int8_t* output_data, + const int original_k, const int original_n) { const int numel = original_k * original_n / 4; @@ -105,36 +100,24 @@ __global__ void weight_permute_interleave_kernelw4afp8( dst[1] = dst1; } -std::vector W4AFp8GemmWeightPermute(const paddle::Tensor& weight) { - const int original_k = weight.dims()[0] * 2; - const int original_n = weight.dims()[1]; - paddle::Tensor weight_new = paddle::empty(weight.shape(), paddle::DataType::INT8, weight.place()); - const int block_dim = 256; - const int original_numel = original_k * original_n; - const int grid_size = (original_numel + block_dim - 1) / block_dim; - - weight_permute_interleave_kernelw4afp8<<>>( - weight.data(), weight_new.data(), original_k, original_n); - - return {weight_new}; -} - std::vector W4AFp8GemmWeightConvert(const paddle::Tensor& weight) { - const int experts = weight.dims()[0]; - const int M = weight.dims()[1]; - const int K = weight.dims()[2]; - paddle::Tensor weight_new = paddle::empty({experts, M, K / 2}, paddle::DataType::UINT8, weight.place()); - weight_convert(weight.data(), weight_new.data(), experts, M, K); - return {weight_new}; + if (weight.place() == paddle::PlaceType::CPU) { + const int experts = weight.dims()[0]; + const int M = weight.dims()[1]; + const int K = weight.dims()[2]; + paddle::Tensor weight_new = paddle::empty({experts, M, K / 2}, paddle::DataType::UINT8, weight.place()); + weight_convert(weight.data(), weight_new.data(), experts, M, K); + return {weight_new}; + } else { + const int original_k = weight.dims()[0] * 2; + const int original_n = weight.dims()[1]; + paddle::Tensor weight_new = paddle::empty(weight.shape(), paddle::DataType::INT8, weight.place()); + const int block_dim = 256; + const int original_numel = original_k * original_n; + const int grid_size = (original_numel + block_dim - 1) / block_dim; + + weight_permute_interleave_kernelw4afp8<<>>( + weight.data(), weight_new.data(), original_k, original_n); + return {weight_new}; + } } - -PD_BUILD_STATIC_OP(w4afp8_gemm_weight_permute) - .Inputs({"weight"}) - .Outputs({"converted_weight"}) - .SetKernelFn(PD_KERNEL(W4AFp8GemmWeightPermute)); - -PD_BUILD_STATIC_OP(w4afp8_gemm_weight_convert) - .Inputs({"weight"}) - .Outputs({"converted_weight"}) - .SetKernelFn(PD_KERNEL(W4AFp8GemmWeightConvert)); - diff --git a/custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py b/custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py index edca4a740f..9cf502236e 100644 --- a/custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py +++ b/custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py @@ -107,7 +107,7 @@ for file_path, empty_list, file_name_list in os.walk(file_dir): for file_name in file_name_list: - if re.match(r'^w4afp8_gemm_M\d+_N\d+_.*\.cu$', file_name): + if 
re.match(r"^w4afp8_gemm_M\d+_N\d+_.*\.cu$", file_name): if file_name not in all_cu_files: print("delete w4afp8 kernel file", file_path + file_name) os.remove(file_path + file_name) diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py index 6acdc93aa9..1829e3fd32 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py @@ -30,7 +30,10 @@ from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch, moe_expert_reduce try: - from fastdeploy.model_executor.ops.gpu import w4afp8_gemm_scale_permute, w4afp8_gemm_weight_permute + from fastdeploy.model_executor.ops.gpu import ( + w4afp8_gemm_scale_permute, + w4afp8_gemm_weight_permute, + ) except: logger.warning("import w4afp8_gemm_scale_permute Failed!") elif current_platform.is_iluvatar(): From 2d691dfdf463d1c418e94ee8a80f2865b1ad077f Mon Sep 17 00:00:00 2001 From: yangjianfengo1 Date: Fri, 26 Sep 2025 15:08:05 +0800 Subject: [PATCH 3/7] =?UTF-8?q?=E7=B2=BE=E5=BA=A6=E5=AE=8C=E6=88=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- custom_ops/gpu_ops/moe/moe_ffn.cu | 14 ++-- custom_ops/gpu_ops/w4afp8_gemm/utils.hpp | 3 +- custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu | 65 +++--------------- custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.h | 1 + .../gpu_ops/w4afp8_gemm/weight_kernel.hpp | 2 +- .../w4afp8_gemm/weight_scale_kernel.hpp | 66 +++++++++++++++++++ docs/features/plas_attention.md | 4 +- docs/zh/features/plas_attention.md | 4 +- .../layers/moe/fused_moe_cutlass_backend.py | 45 ++++++++++++- tests/operators/test_w4afp8_gemm.py | 29 ++++---- 10 files changed, 150 insertions(+), 83 deletions(-) create mode 100644 custom_ops/gpu_ops/w4afp8_gemm/weight_scale_kernel.hpp diff --git a/custom_ops/gpu_ops/moe/moe_ffn.cu b/custom_ops/gpu_ops/moe/moe_ffn.cu index 92af569971..fedb8b2c97 100644 --- a/custom_ops/gpu_ops/moe/moe_ffn.cu +++ b/custom_ops/gpu_ops/moe/moe_ffn.cu @@ -178,21 +178,22 @@ void MoeFFNKernel(const paddle::Tensor& permute_input, typedef PDTraits traits_fp8; typedef typename traits_fp8::DataType DataType_fp8; typedef typename traits_fp8::data_t data_t_fp8; - + paddle::Tensor weight_scale_tensor = *const_cast(up_gate_proj_scale.get_ptr()); + const int weight_scale_group_size = weight_scale_tensor.dims().size() == 2 ? hidden_size : weight_scale_tensor.dims()[3]; float* row_scale = nullptr; DisPatchW4AFp8GemmWrapper( reinterpret_cast(permute_input.data()), reinterpret_cast(up_gate_proj_weight.data()), const_cast(tokens_expert_prefix_sum.data()), row_scale, - const_cast(up_gate_proj_scale.get_ptr()) - ->data(), + weight_scale_tensor.data(), reinterpret_cast(fc1_out), used_in_ep_low_latency ? num_max_tokens_per_expert : 0, used_in_ep_low_latency ? num_max_tokens_per_expert : permute_input.dims()[0], num_experts, inter_size, hidden_size, + weight_scale_group_size, stream); } else { typename cutlass::WintQuantTraits::Arguments quant_args; @@ -327,19 +328,22 @@ void MoeFFNKernel(const paddle::Tensor& permute_input, stream ); + paddle::Tensor weight_scale_tensor = *const_cast(down_proj_scale.get_ptr()); + const int weight_scale_group_size = weight_scale_tensor.dims().size() == 2 ? 
inter_size / 2 : weight_scale_tensor.dims()[3]; + DisPatchW4AFp8GemmWrapper( reinterpret_cast(fp8_act_out->ptr()), reinterpret_cast(down_proj_weight.data()), const_cast(tokens_expert_prefix_sum.data()), row_scale, - const_cast(down_proj_scale.get_ptr()) - ->data(), + weight_scale_tensor.data(), reinterpret_cast(ffn_out_data), used_in_ep_low_latency ? num_max_tokens_per_expert : 0, used_in_ep_low_latency ? num_max_tokens_per_expert : act_out_tensor.dims()[0], num_experts, hidden_size, inter_size / 2, + weight_scale_group_size, stream); } else { typename cutlass::WintQuantTraits::Arguments quant_args; diff --git a/custom_ops/gpu_ops/w4afp8_gemm/utils.hpp b/custom_ops/gpu_ops/w4afp8_gemm/utils.hpp index cb1e9e9bbb..128ea564aa 100644 --- a/custom_ops/gpu_ops/w4afp8_gemm/utils.hpp +++ b/custom_ops/gpu_ops/w4afp8_gemm/utils.hpp @@ -92,7 +92,6 @@ __forceinline__ __device__ void gemm( warpgroup_arrive(); } constexpr int numel = decltype(size(tCrA(_, _, 0)))::value / 4; - Tensor tCrA_copy_view = thr_copy_A.retile_D(tCrA); cute::copy(tiled_copy_A, tCsA(_, _, _0{}), tCrA_copy_view(_, _, _0{})); @@ -107,7 +106,9 @@ __forceinline__ __device__ void gemm( convert_c4_2_fp8(tCrA_data, tCrA1_data, tCrA2_data); cute::gemm(tiled_mma, tCrA1(_,_,k_block), tCrB(_,_,2 * k_block), tCrC); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; cute::gemm(tiled_mma, tCrA2(_,_,k_block), tCrB(_,_, 2 * k_block + 1), tCrC); + } if constexpr (commit) { warpgroup_commit_batch(); diff --git a/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu b/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu index c1e6630ebf..1d307bbd7c 100644 --- a/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu +++ b/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu @@ -21,6 +21,7 @@ #include "w4afp8_gemm_template.h" #include "w4afp8_gemm.h" #include "weight_kernel.hpp" +#include "weight_scale_kernel.hpp" @@ -57,12 +58,13 @@ void DisPatchW4AFp8Gemm( const int Experts, const int64_t M, const int64_t K, + const int WeightScaleGroup, cudaStream_t stream) { int kBlockN = 256; if constexpr (std::is_same_v) { GEMM_SWITCH_BF16( - M, K, Experts, token_padding_size, kBlockN, K, + M, K, Experts, token_padding_size, kBlockN, WeightScaleGroup, weight, input, out, @@ -88,6 +90,7 @@ std::vector W4AFp8Gemm( const int Experts = weight.dims()[0]; const int M = weight.dims()[1]; const int K = weight.dims()[2] * 2; + const int WeightScaleGroup = weight_scale.dims().size() == 2 ? 
K : weight_scale.dims()[3]; if (input.dtype() != paddle::DataType::FLOAT8_E4M3FN) { PD_THROW("Only supported dtype in ['FLOAT8_E4M3FN']."); @@ -109,6 +112,7 @@ std::vector W4AFp8Gemm( Experts, M, K, + WeightScaleGroup, input.stream()); return {out}; } else { @@ -129,6 +133,7 @@ std::vector W4AFp8Gemm( Experts, M, K, + WeightScaleGroup, input.stream()); return {out}; } else { @@ -150,6 +155,7 @@ void DisPatchW4AFp8GemmWrapper( const int num_experts, const int64_t M, const int64_t K, + const int WeightScaleGroup, cudaStream_t stream) { using InType = typename NVTraits::data_t; using OutType = typename NVTraits::data_t; @@ -164,63 +170,10 @@ void DisPatchW4AFp8GemmWrapper( num_experts, M, K, + WeightScaleGroup, stream); } - - - -template -__global__ void permute_scale_kernel( - T* input_data, - const int numel) { - using LoadT = AlignedVector; - LoadT input_vec; - LoadT dst_vec; - const int load_idx = (blockIdx.x * blockDim.x + threadIdx.x) * kPackSize; - if (load_idx >= numel) { - return; - } - Load(&input_data[load_idx], &input_vec); - - for (int i = 0; i < kPackSize; i+=2) { - dst_vec[i] = input_vec[i / 2]; - dst_vec[i + 1] = input_vec[i / 2 + 8]; - } - - Store(dst_vec, &input_data[load_idx]); -} - -void W4AFp8GemmScalePermute(const paddle::Tensor& scale) { - const int row = scale.dims().size() == 2 ? scale.dims()[0] : 1; - const int col = scale.dims().size() == 2 ? scale.dims()[1] : scale.dims()[0]; - if (col % 16 != 0) { - PD_THROW("Only supported when col is divisible by 16."); - } - const int numel = row * col; - const int threads = 128; - const int kPackSize = 16; - const int grid_size = (numel / kPackSize + threads - 1) / threads; - - if (scale.dtype() == paddle::DataType::BFLOAT16) { - permute_scale_kernel<<>>( - const_cast(scale.data()), - numel - ); - } else if (scale.dtype() == paddle::DataType::FLOAT16) { - permute_scale_kernel<<>>( - const_cast(scale.data()), - numel - ); - } else if (scale.dtype() == paddle::DataType::FLOAT32) { - permute_scale_kernel<<>>( - const_cast(scale.data()), - numel - ); - } - -} - PD_BUILD_STATIC_OP(w4afp8_gemm_scale_permute) .Inputs({"weight_scale"}) .Outputs({"permute_scale"}) @@ -255,6 +208,7 @@ template void DisPatchW4AFp8GemmWrapper<__nv_fp8_e4m3, __nv_bfloat16>( const int num_experts, const int64_t M, const int64_t K, + const int WeightScaleGroup, cudaStream_t stream ); @@ -270,5 +224,6 @@ template void DisPatchW4AFp8GemmWrapper<__nv_fp8_e4m3, half>( const int num_experts, const int64_t M, const int64_t K, + const int WeightScaleGroup, cudaStream_t stream ); diff --git a/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.h b/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.h index 16dd286452..1c1db0e12c 100644 --- a/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.h +++ b/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.h @@ -42,4 +42,5 @@ void DisPatchW4AFp8GemmWrapper( const int num_experts, const int64_t M, const int64_t K, + const int WeightScaleGroup, cudaStream_t stream); diff --git a/custom_ops/gpu_ops/w4afp8_gemm/weight_kernel.hpp b/custom_ops/gpu_ops/w4afp8_gemm/weight_kernel.hpp index f155515ed8..7501bdaebc 100644 --- a/custom_ops/gpu_ops/w4afp8_gemm/weight_kernel.hpp +++ b/custom_ops/gpu_ops/w4afp8_gemm/weight_kernel.hpp @@ -101,7 +101,7 @@ __global__ void weight_permute_interleave_kernelw4afp8( } std::vector W4AFp8GemmWeightConvert(const paddle::Tensor& weight) { - if (weight.place() == paddle::PlaceType::CPU) { + if (weight.place() == paddle::CPUPlace()) { const int experts = weight.dims()[0]; const int M = weight.dims()[1]; const int K = weight.dims()[2]; 
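To make the byte layout produced by weight_convert concrete, here is a small host-side sketch (standalone, illustration only) that packs one 64-element group along K and unpacks a byte back. It assumes the inputs are already 4-bit codes in 0..15, as produced by the quantization step; element k_inner lands in the high nibble and element k_inner + 32 in the low nibble, matching the loop above.

#include <array>
#include <cstdint>
#include <utility>

// Pack one 64-wide K group: byte i holds (w[i] << 4) | w[i + 32].
std::array<uint8_t, 32> pack_group(const std::array<uint8_t, 64>& w) {
  std::array<uint8_t, 32> packed{};
  for (int i = 0; i < 32; ++i) {
    packed[i] = static_cast<uint8_t>((w[i] << 4) | (w[i + 32] & 0x0F));
  }
  return packed;
}

// Recover the two 4-bit codes stored in one packed byte.
std::pair<uint8_t, uint8_t> unpack_byte(uint8_t b) {
  return {static_cast<uint8_t>(b >> 4),     // element k_inner
          static_cast<uint8_t>(b & 0x0F)};  // element k_inner + 32
}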
diff --git a/custom_ops/gpu_ops/w4afp8_gemm/weight_scale_kernel.hpp b/custom_ops/gpu_ops/w4afp8_gemm/weight_scale_kernel.hpp new file mode 100644 index 0000000000..9a37bf4a2f --- /dev/null +++ b/custom_ops/gpu_ops/w4afp8_gemm/weight_scale_kernel.hpp @@ -0,0 +1,66 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "helper.h" +#include "paddle/extension.h" + +template +__global__ void permute_scale_kernel( + T* input_data, + const int numel) { + using LoadT = AlignedVector; + LoadT input_vec; + LoadT dst_vec; + const int load_idx = (blockIdx.x * blockDim.x + threadIdx.x) * kPackSize; + if (load_idx >= numel) { + return; + } + Load(&input_data[load_idx], &input_vec); + + for (int i = 0; i < kPackSize; i+=2) { + dst_vec[i] = input_vec[i / 2]; + dst_vec[i + 1] = input_vec[i / 2 + 8]; + } + + Store(dst_vec, &input_data[load_idx]); +} + +void W4AFp8GemmScalePermute(const paddle::Tensor& scale) { + const int row = scale.dims().size() == 2 ? scale.dims()[0] : 1; + const int col = scale.dims().size() == 2 ? scale.dims()[1] : scale.dims()[0]; + if (col % 16 != 0) { + PD_THROW("Only supported when col is divisible by 16."); + } + const int numel = row * col; + const int threads = 128; + const int kPackSize = 16; + const int grid_size = (numel / kPackSize + threads - 1) / threads; + + if (scale.dtype() == paddle::DataType::BFLOAT16) { + permute_scale_kernel<<>>( + const_cast(scale.data()), + numel + ); + } else if (scale.dtype() == paddle::DataType::FLOAT16) { + permute_scale_kernel<<>>( + const_cast(scale.data()), + numel + ); + } else if (scale.dtype() == paddle::DataType::FLOAT32) { + permute_scale_kernel<<>>( + const_cast(scale.data()), + numel + ); + } + +} diff --git a/docs/features/plas_attention.md b/docs/features/plas_attention.md index 6551b4e8f7..b096fceeb6 100644 --- a/docs/features/plas_attention.md +++ b/docs/features/plas_attention.md @@ -15,7 +15,7 @@ In terms of training efficiency, the training cost is very low because only the Following the approaches of NSA and MoBA, we partition the KV into multiple blocks. During both the prefill and decode stages, instead of performing attention computation over all KV, we dynamically select the top-K blocks with the highest attention scores for each query token, thereby enabling efficient sparse attention computation.
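As a plain-CPU illustration of the selection rule described in the preceding paragraph (block_scores and top_k are placeholders; the real selection runs inside the attention kernels), picking the top-K blocks for one query token amounts to:

#include <algorithm>
#include <numeric>
#include <vector>

// Given one query token's gate scores over KV blocks, e.g. softmax(q · Kc^T),
// return the indices of the top_k highest-scoring blocks; only these blocks
// take part in the sparse attention computation for that token.
std::vector<int> select_topk_blocks(const std::vector<float>& block_scores, int top_k) {
  std::vector<int> idx(block_scores.size());
  std::iota(idx.begin(), idx.end(), 0);
  top_k = std::min<int>(top_k, static_cast<int>(idx.size()));
  std::partial_sort(idx.begin(), idx.begin() + top_k, idx.end(),
                    [&](int a, int b) { return block_scores[a] > block_scores[b]; });
  idx.resize(top_k);
  return idx;
}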
-Attention Gate Module +Attention Gate Module
* **Attention Gate Module**: As illustrated in the figure above, to estimate the importance of each block with low computational overhead, we design a lightweight attention gate module. This module first compresses each K block via a MLP layer to generate a representative low-dimensional representation: $K_c^T=W_{kp}K^T$, where $W_{kp}$ denotes the MLP layer weights. Compared to directly applying mean pooling, the learnable MLP can more effectively capture semantic relationships and importance distributions among different tokens, thereby providing a refined representation of each block. After obtaining the compressed representation $K_c$, the importance of each query token with respect to each block is estimated via: $Softmax(Q\cdot K_c^T)$. To enhance the discriminative ability of the MLP layer, we use the full attention result after 1D max pooling $1DMaxPooling(Softmax(Q \cdot K^T))$ as the ground truth. By minimizing the distribution divergence between the two, the MLP layer is guided to learn feature representations that better align with the true attention distribution. @@ -27,7 +27,7 @@ Following the approaches of NSA and MoBA, we partition the KV into multiple bloc During sparse attention computation, each query token may dynamically select different KV blocks, leading to highly irregular memory access patterns in HBM. It is feasible to simply process each query token separately, but it will lead to excessively fine-grained computing, which cannot make full use of the tensor core, thus significantly reducing the GPU computing efficiency.
-Token/Head Union +Token/Head Union
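Returning to the w4afp8 scale layout: the permute_scale_kernel added in weight_scale_kernel.hpp above rewrites the weight scales in place, pack by pack. A minimal CPU sketch of what happens to one 16-element pack (standalone, illustration only):

#include <array>

// [s0..s7, s8..s15] -> [s0, s8, s1, s9, ..., s7, s15], mirroring
// dst_vec[i] = input_vec[i / 2] and dst_vec[i + 1] = input_vec[i / 2 + 8].
std::array<float, 16> permute_scale_pack(const std::array<float, 16>& src) {
  std::array<float, 16> dst{};
  for (int j = 0; j < 8; ++j) {
    dst[2 * j] = src[j];
    dst[2 * j + 1] = src[j + 8];
  }
  return dst;
}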
To optimize performance in both the prefill and decode stages, we design a special joint strategy to adapt to their respective characteristics: diff --git a/docs/zh/features/plas_attention.md b/docs/zh/features/plas_attention.md index a49cb25fde..0d8fcb2b97 100644 --- a/docs/zh/features/plas_attention.md +++ b/docs/zh/features/plas_attention.md @@ -15,7 +15,7 @@ 借鉴 NSA 和 MoBA 的方法,我们将键值对 (KV) 划分为多个块。在预填充和解码阶段,我们不再对所有键值进行注意力计算,而是动态地为每个查询 token 选择注意力得分最高的前 K 个块,从而实现高效的稀疏注意力计算。
-Attention Gate Module +Attention Gate Module
* **Attention Gate Module**: 如上图所示,为了以较低的计算开销估计每个块的重要性,我们设计了一个轻量级的注意力门模块。该模块首先通过一个MLP层压缩每个K个块,生成一个具有代表性的低维表示: $K_c^T=W_{kp}K^T$ ,其中 $W_{kp}$ 表示 MLP 层的权重。与直接应用均值池化相比,可学习的 MLP 可以更有效地捕捉不同 token 之间的语义关系和重要性分布,从而提供每个块的精细表示。在获得压缩表示 $K_c$ 之后,通过以下公式估计每个查询 token 相对于每个块的重要性:$Softmax(Q\cdot K_c^T)$。为了增强 MLP 层的判别能力,我们使用一维最大池化后的完整注意力结果 $1DMaxPooling(Softmax(Q \cdot K^T))$ 作为 ground truth。通过最小化两者之间的分布差异,引导 MLP 层学习更符合真实注意力分布的特征表示。 @@ -29,7 +29,7 @@ 在稀疏注意力计算过程中,每个查询 token 可能会动态选择不同的 KV 块,导致 HBM 的内存访问模式非常不规则。简单地对每个查询 token 进行单独处理是可行的,但这会导致计算粒度过细,无法充分利用张量核,从而显著降低 GPU 的计算效率。
-Token/Head Union +Token/Head Union
为了优化预填充和解码阶段的性能,我们设计了一种特殊的联合策略来适应各自的特点: diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py index 1829e3fd32..cc02bf7f18 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py @@ -32,7 +32,7 @@ try: from fastdeploy.model_executor.ops.gpu import ( w4afp8_gemm_scale_permute, - w4afp8_gemm_weight_permute, + w4afp8_gemm_weight_convert, ) except: logger.warning("import w4afp8_gemm_scale_permute Failed!") @@ -788,7 +788,7 @@ def process_loaded_weights(self, layer: nn.Layer, state_dict): weight_name = self.added_weight_attrs[idx] weight_list = [] for i in range(layer.num_local_experts): - quant_weight = w4afp8_gemm_weight_permute(weight_tensor[i]) + quant_weight = w4afp8_gemm_weight_convert(weight_tensor[i]) weight_list.append(quant_weight) quanted_weight = paddle.stack(weight_list, axis=0) getattr(layer, weight_name).set_value(quanted_weight) @@ -888,7 +888,46 @@ def _process_weight_scale(name: str, weight_scales: list[paddle.Tensor], process processed_weight_scale = ( paddle.stack(weight_scales, axis=0) / (448 * 7 * 2 ** (-9)) / processed_in_scale[:, None] ) - processed_weight_scale = _permute_weight_scale(processed_weight_scale) + + if len(processed_weight_scale.shape) == 3: + if name == "up_gate_proj_weight_scale" and processed_weight_scale.shape[-1] * 128 != layer.hidden_size: + assert ( + layer.hidden_size // 128 % processed_weight_scale.shape[-1] == 0 + ), "weight_scale_group_size must be a multiple of 128" + # If it is a multiple of 128, repeat to 128 + processed_weight_scale = processed_weight_scale.repeat_interleave( + layer.hidden_size // 128 // processed_weight_scale.shape[-1], axis=-1 + ) + elif name == "down_proj_weight_scale": + assert ( + layer.moe_intermediate_size // 128 % processed_weight_scale.shape[-1] == 0 + ), "weight_scale_group_size must be a multiple of 128" + # If it is a multiple of 128, repeat to 128 + processed_weight_scale = processed_weight_scale.repeat_interleave( + layer.moe_intermediate_size // 128 // processed_weight_scale.shape[-1], axis=-1 + ) + else: + raise ValueError(f"Invalid weight scale name: {name}") + + origin_shape = processed_weight_scale.shape + processed_weight_scale = processed_weight_scale.transpose([0, 2, 1]) + processed_weight_scale = processed_weight_scale.reshape([-1, processed_weight_scale.shape[-1]]) + processed_weight_scale = _permute_weight_scale(processed_weight_scale) + processed_weight_scale = processed_weight_scale.reshape( + [origin_shape[0], origin_shape[2], origin_shape[1] // 128, 128] + ) + processed_weight_scale = processed_weight_scale.transpose([0, 2, 1, 3]) + setattr( + layer, + name, + layer.create_parameter( + shape=processed_weight_scale.shape, + dtype="float32", + default_initializer=paddle.nn.initializer.Constant(0), + ), + ) + else: + processed_weight_scale = _permute_weight_scale(processed_weight_scale) getattr(layer, name).set_value(processed_weight_scale) # 1. 
Init scale containers and maps diff --git a/tests/operators/test_w4afp8_gemm.py b/tests/operators/test_w4afp8_gemm.py index 29459ddf39..9c97840575 100644 --- a/tests/operators/test_w4afp8_gemm.py +++ b/tests/operators/test_w4afp8_gemm.py @@ -23,10 +23,10 @@ class TestW4AFP8GEMM(unittest.TestCase): def setUp(self): paddle.seed(0) - self.tokens_per_group = 256 - self.N = 256 - self.K = 256 - self.BATCH = 1 + self.tokens_per_group = 1 + self.N = 1792 + self.K = 8192 + self.BATCH = 64 self.TokenPadding = 0 tokens = [self.tokens_per_group] * self.BATCH @@ -38,14 +38,15 @@ def setUp(self): self.input_fp8 = paddle.randn([self.all_tokens, self.K], dtype="bfloat16").astype(paddle.float8_e4m3fn) self.input_bf16 = self.input_fp8.astype("bfloat16") - self.weight = paddle.randn([self.BATCH, self.N, self.K], dtype="bfloat16") / 10 + self.weight = paddle.randn([self.BATCH, self.N, self.K], dtype="bfloat16") self.weight_scale = 7 / self.weight.abs().max(axis=-1).reshape([self.BATCH, self.N, 1]) - self.weight_quant = (self.weight * self.weight_scale).astype("int") + 7 - self.weight_quant = paddle.clip(self.weight_quant, 0, 14) + self.weight_quant = (self.weight * self.weight_scale).astype("int") + self.weight_quant = paddle.clip(self.weight_quant, -7, 7) + self.weight_quant_naive = self.weight_quant.astype("float32") self.weight_quant = self.weight_quant.astype("bfloat16") + self.weight_quant = paddle.where(self.weight_quant > 0, self.weight_quant, 8 - self.weight_quant) self.weight_dequant_scale = 1 / self.weight_scale.astype("float32") - self.input_row_sum = self.input_bf16.sum(axis=1) * -7 / 512 self.max_tokens = int(self.tokens.max()) def w4afp8_gemm_naive(self, input_bf16, weight_quant, tokens, weight_dequant_scale): @@ -54,7 +55,7 @@ def w4afp8_gemm_naive(self, input_bf16, weight_quant, tokens, weight_dequant_sca pre_fix_token = 0 for i in range(self.BATCH): input = input_bf16[pre_fix_token : pre_fix_token + tokens[i], :] - weight = (weight_quant[i] - 7.0) * weight_dequant_scale[i] + weight = weight_quant[i] * weight_dequant_scale[i] out_i = paddle.matmul(input, weight.astype("bfloat16"), transpose_y=True) out[pre_fix_token : pre_fix_token + tokens[i], :] = out_i pre_fix_token += tokens[i] @@ -72,7 +73,9 @@ def permute_scale(self, weight_scale): return weight_scale def test_w4afp8_gemm(self): - out_naive = self.w4afp8_gemm_naive(self.input_bf16, self.weight_quant, self.tokens, self.weight_dequant_scale) + out_naive = self.w4afp8_gemm_naive( + self.input_bf16, self.weight_quant_naive, self.tokens, self.weight_dequant_scale + ) weight_dequant_scale = paddle.to_tensor(self.permute_scale(self.weight_dequant_scale) * 512) weight_int4 = w4afp8_gemm_weight_convert(self.weight_quant.astype("uint8").cpu()) @@ -82,10 +85,9 @@ def test_w4afp8_gemm(self): self.input_fp8, weight_int4.cuda(), self.tokens_prefix_sum, - self.input_row_sum.astype("float32"), weight_dequant_scale.astype("float32"), int(self.TokenPadding), - self.max_tokens, + self.all_tokens, True, ) else: @@ -93,7 +95,6 @@ def test_w4afp8_gemm(self): self.input_fp8, weight_int4.cuda(), self.tokens, - self.input_row_sum.astype("float32"), weight_dequant_scale.astype("float32"), int(self.TokenPadding), self.max_tokens, @@ -101,7 +102,7 @@ def test_w4afp8_gemm(self): ) gap = (out_cuda - out_naive).abs() - self.assertLess(float(gap.mean()), 0.07) + self.assertLess(float(gap.mean()), 0.11) if __name__ == "__main__": From 5d6303f7894b7dbcc6a447b9b55a5775cceb2f36 Mon Sep 17 00:00:00 2001 From: yangjianfengo1 Date: Fri, 26 Sep 2025 15:10:55 +0800 Subject: 
[PATCH 4/7] revert append attn utils --- custom_ops/gpu_ops/append_attn/utils.cuh | 32 +++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/custom_ops/gpu_ops/append_attn/utils.cuh b/custom_ops/gpu_ops/append_attn/utils.cuh index eb7fb6b1ae..12d86dade8 100644 --- a/custom_ops/gpu_ops/append_attn/utils.cuh +++ b/custom_ops/gpu_ops/append_attn/utils.cuh @@ -404,9 +404,39 @@ __forceinline__ __host__ __device__ void vec_cast( } #define DISPATCH_GQA_GROUP_SIZE(group_size, GROUP_SIZE, ...) \ - if (group_size == 8) { \ + if (group_size == 1) { \ + constexpr size_t GROUP_SIZE = 1; \ + __VA_ARGS__ \ + } else if (group_size == 2) { \ + constexpr size_t GROUP_SIZE = 2; \ + __VA_ARGS__ \ + } else if (group_size == 3) { \ + constexpr size_t GROUP_SIZE = 3; \ + __VA_ARGS__ \ + } else if (group_size == 4) { \ + constexpr size_t GROUP_SIZE = 4; \ + __VA_ARGS__ \ + } else if (group_size == 5) { \ + constexpr size_t GROUP_SIZE = 5; \ + __VA_ARGS__ \ + } else if (group_size == 6) { \ + constexpr size_t GROUP_SIZE = 6; \ + __VA_ARGS__ \ + } else if (group_size == 7) { \ + constexpr size_t GROUP_SIZE = 7; \ + __VA_ARGS__ \ + } else if (group_size == 8) { \ constexpr size_t GROUP_SIZE = 8; \ __VA_ARGS__ \ + } else if (group_size == 12) { \ + constexpr size_t GROUP_SIZE = 12; \ + __VA_ARGS__ \ + } else if (group_size == 14) { \ + constexpr size_t GROUP_SIZE = 14; \ + __VA_ARGS__ \ + } else if (group_size == 16) { \ + constexpr size_t GROUP_SIZE = 16; \ + __VA_ARGS__ \ } else { \ PD_THROW("not support the group_size", group_size); \ } From 7bcccaab20391573da2c3b34eca64e25e32c0299 Mon Sep 17 00:00:00 2001 From: yangjianfengo1 Date: Sun, 28 Sep 2025 16:34:17 +0800 Subject: [PATCH 5/7] =?UTF-8?q?ffn1=20=E5=8A=A8=E6=80=81=E9=87=8F=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- custom_ops/gpu_ops/cpp_extensions.cc | 1 + custom_ops/gpu_ops/moe/fused_moe_helper.h | 2 +- custom_ops/gpu_ops/moe/fused_moe_op.h | 80 ++++++++++++++----- custom_ops/gpu_ops/moe/moe_dispatch.cu | 48 ++++++++--- custom_ops/gpu_ops/moe/moe_ffn.cu | 17 +++- custom_ops/gpu_ops/w4afp8_gemm/mainloop_fwd.h | 55 +++++++++---- custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu | 12 ++- custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.h | 3 +- .../w4afp8_gemm/w4afp8_gemm_kernel.hpp | 24 +++++- .../utils/auto_gen_w4afp8_gemm_kernel.py | 4 +- .../layers/moe/fused_moe_cutlass_backend.py | 20 +++-- 11 files changed, 201 insertions(+), 65 deletions(-) diff --git a/custom_ops/gpu_ops/cpp_extensions.cc b/custom_ops/gpu_ops/cpp_extensions.cc index 85f88cf123..de99799350 100644 --- a/custom_ops/gpu_ops/cpp_extensions.cc +++ b/custom_ops/gpu_ops/cpp_extensions.cc @@ -249,6 +249,7 @@ paddle::Tensor MoeExpertFFNFunc( const paddle::Tensor& permute_input, const paddle::Tensor& tokens_expert_prefix_sum, const paddle::Tensor& up_gate_proj_weight, const paddle::Tensor& down_proj_weight, + const paddle::optional& up_proj_in_scale, const paddle::optional& up_gate_proj_bias, const paddle::optional& up_gate_proj_scale, const paddle::optional& down_proj_scale, diff --git a/custom_ops/gpu_ops/moe/fused_moe_helper.h b/custom_ops/gpu_ops/moe/fused_moe_helper.h index 703a7c11f0..f24f12ea70 100644 --- a/custom_ops/gpu_ops/moe/fused_moe_helper.h +++ b/custom_ops/gpu_ops/moe/fused_moe_helper.h @@ -250,7 +250,7 @@ template class MoeHelper { initialize_moe_routing_kernelLauncher( input_activations, permuted_data_, permuted_rows_, nullptr, nullptr, - expanded_source_row_to_expanded_dest_row, num_rows, 
num_rows, + expanded_source_row_to_expanded_dest_row, nullptr, num_rows, num_rows, hidden_size, k, stream); const int64_t expanded_active_expert_rows = k * num_rows; diff --git a/custom_ops/gpu_ops/moe/fused_moe_op.h b/custom_ops/gpu_ops/moe/fused_moe_op.h index 6b743aa79c..f8e755ef15 100644 --- a/custom_ops/gpu_ops/moe/fused_moe_op.h +++ b/custom_ops/gpu_ops/moe/fused_moe_op.h @@ -128,6 +128,17 @@ struct SumOp { __device__ __forceinline__ T operator()(T const& x, T const& y) { return x + y; } }; +template +struct MaxOp { +__device__ inline T operator()(T const & x, T const & y) { return x > y ? x : y; } +}; + +template <> +struct MaxOp { +__device__ inline float operator()(float const &x, float const &y) { return fmax(x, y); } +}; + + template __forceinline__ __device__ OutType QuantHelperFunc(const InType input, const float scale, @@ -1145,7 +1156,7 @@ void topk_gating_softmax_kernelLauncher(const T* input, // to row 0 in the original matrix. Thus, to know where to read in the source // matrix, we simply take the modulus of the expanded index. -template +template __global__ void initialize_moe_routing_kernel( const T* unpermuted_input, OutT* permuted_output, @@ -1153,6 +1164,7 @@ __global__ void initialize_moe_routing_kernel( const int *expert_idx_per_token, const float *w4a8_in_scale, int* expanded_source_row_to_expanded_dest_row, + float *dequant_scale, const int64_t num_rows, const int64_t active_rows, const int64_t cols, @@ -1174,15 +1186,49 @@ __global__ void initialize_moe_routing_kernel( expanded_dest_row; } + extern __shared__ char smem_[]; + + T * data_smem = reinterpret_cast(smem_); + if (expanded_dest_row < active_rows) { const int expert_idx = expert_idx_per_token[expanded_dest_row]; - const float scale = w4a8_in_scale ? w4a8_in_scale[expert_idx] : -1; + float scale; const int source_row = expanded_source_row % num_rows; const T* source_row_ptr = unpermuted_input + source_row * cols; OutT *dest_row_ptr = permuted_output + expanded_dest_row * cols; + if constexpr(std::is_same::value) { + if (dequant_scale != nullptr) { + float abs_max = 0.f; + for (int tid = threadIdx.x * VecSize; tid < cols; + tid += blockDim.x * VecSize) { + Load(&source_row_ptr[tid], &src_vec); + Store(src_vec, &data_smem[tid]); + for (int j = 0; j < VecSize; j++) { + abs_max = fmaxf(abs_max, fabsf(static_cast(src_vec[j]))); + } + } + abs_max = BlockAllReduce(abs_max); + scale = 440.0f / abs_max; + dequant_scale[expanded_dest_row] = abs_max; + for (int tid = threadIdx.x * VecSize; tid < cols; + tid += blockDim.x * VecSize) { + Load(&data_smem[tid], &src_vec); + using StoreT = AlignedVector; + StoreT dest_vec; + for (int j = 0; j < VecSize; j++) { + float quant_value = scale * static_cast(src_vec[j]); + dest_vec[j] = static_cast(quant_value); + } + Store(dest_vec, &dest_row_ptr[tid]); + } + return; + } else { + scale = w4a8_in_scale ? 
w4a8_in_scale[expert_idx] : -1; + } + } for (int tid = threadIdx.x * VecSize; tid < cols; tid += blockDim.x * VecSize) { // dest_row_ptr[tid] = source_row_ptr[tid]; @@ -1228,41 +1274,35 @@ void initialize_moe_routing_kernelLauncher( const int *expert_idx_per_token, const float *w4a8_in_scale, int* expanded_source_row_to_expanded_dest_row, + float * dequant_scale, const int64_t num_rows, const int64_t active_rows, const int64_t cols, const int64_t k, cudaStream_t stream) { - const int threads = std::min(cols, int64_t(1024)); + constexpr int threads = 256; constexpr int max_pack_size = 16 / sizeof(T); const auto config_initialize = Get1DBlocksAnd2DGridsMoe(num_rows * k); - if (cols % max_pack_size == 0) { - initialize_moe_routing_kernel - <<>>( - unpermuted_input, - permuted_output, - expanded_dest_row_to_expanded_source_row, - expert_idx_per_token, - w4a8_in_scale, - expanded_source_row_to_expanded_dest_row, - num_rows, - k * active_rows, - cols, - num_rows * k); - } else { - initialize_moe_routing_kernel - <<>>( + const int smem_size = cols * sizeof(float); + auto kernel = &initialize_moe_routing_kernel; + if (cols % max_pack_size != 0) { + kernel = &initialize_moe_routing_kernel; + } + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + } + kernel<<>>( unpermuted_input, permuted_output, expanded_dest_row_to_expanded_source_row, expert_idx_per_token, w4a8_in_scale, expanded_source_row_to_expanded_dest_row, + dequant_scale, num_rows, k * active_rows, cols, num_rows * k); - } } // ============================== Infer GEMM sizes diff --git a/custom_ops/gpu_ops/moe/moe_dispatch.cu b/custom_ops/gpu_ops/moe/moe_dispatch.cu index bc18ece456..495b522250 100644 --- a/custom_ops/gpu_ops/moe/moe_dispatch.cu +++ b/custom_ops/gpu_ops/moe/moe_dispatch.cu @@ -33,7 +33,7 @@ void MoeDispatchKernel( const int hidden_size, const int expert_num, paddle::Tensor *permute_input, paddle::Tensor *tokens_expert_prefix_sum, paddle::Tensor *permute_indices_per_token, paddle::Tensor *topk_weight, - paddle::Tensor *topk_idx, paddle::Tensor *expert_idx_per_token) { + paddle::Tensor *topk_idx, paddle::Tensor *expert_idx_per_token, paddle::Tensor *dequant_scale) { using namespace phi; if (num_rows == 0){ @@ -120,22 +120,34 @@ void MoeDispatchKernel( initialize_moe_routing_kernelLauncher( input.data(), permute_input->data(), permuted_rows_, expert_idx_per_token->data(), w4a8_in_scale->data(), - permute_indices_per_token->data(), num_rows, num_rows, + permute_indices_per_token->data(), nullptr, + num_rows, num_rows, hidden_size, moe_topk, stream); } else if (permute_input->dtype() == paddle::DataType::FLOAT8_E4M3FN) { initialize_moe_routing_kernelLauncher( input.data(), permute_input->data(), permuted_rows_, expert_idx_per_token->data(), w4a8_in_scale->data(), - permute_indices_per_token->data(), num_rows, num_rows, + permute_indices_per_token->data(), nullptr, + num_rows, num_rows, hidden_size, moe_topk, stream); } } else { - initialize_moe_routing_kernelLauncher( - input.data(), permute_input->data(), permuted_rows_, - expert_idx_per_token->data(), nullptr, - permute_indices_per_token->data(), num_rows, num_rows, + if (permute_input->dtype() == paddle::DataType::FLOAT8_E4M3FN) { + initialize_moe_routing_kernelLauncher( + input.data(), permute_input->data(), + permuted_rows_, expert_idx_per_token->data(), + nullptr, + permute_indices_per_token->data(), dequant_scale->data(), + num_rows, num_rows, hidden_size, moe_topk, stream); + } else { + 
initialize_moe_routing_kernelLauncher( + input.data(), permute_input->data(), permuted_rows_, + expert_idx_per_token->data(), nullptr, + permute_indices_per_token->data(), nullptr, num_rows, num_rows, + hidden_size, moe_topk, stream); + } } compute_total_rows_before_expert( @@ -170,10 +182,20 @@ std::vector MoeExpertDispatch( } else if (moe_quant_type == "w4afp8") { permute_input_dtype = paddle::DataType::FLOAT8_E4M3FN; } + } else { + if (moe_quant_type == "w4afp8") { + permute_input_dtype = paddle::DataType::FLOAT8_E4M3FN; + } } auto permute_input = GetEmptyTensor({moe_topk * num_rows, hidden_size}, permute_input_dtype, place); + int dequant_scale_size = 1; + if (moe_quant_type == "w4afp8" && !w4a8_in_scale) { + dequant_scale_size = moe_topk * num_rows; + } + + auto dequant_scale = GetEmptyTensor({dequant_scale_size}, paddle::DataType::FLOAT32, place); // correspond to the weighted coefficients of the results from each expert. auto topk_weight = GetEmptyTensor({num_rows, moe_topk}, paddle::DataType::FLOAT32, place); @@ -194,7 +216,8 @@ std::vector MoeExpertDispatch( permute_indices_per_token, topk_weight, topk_idx, - expert_idx_per_token}; + expert_idx_per_token, + dequant_scale}; } switch (input_type) { @@ -203,14 +226,14 @@ std::vector MoeExpertDispatch( input, gating_output, gating_correction_bias, w4a8_in_scale, moe_topk, group_moe, topk_only_mode, num_rows, hidden_size, expert_num, &permute_input, &tokens_expert_prefix_sum, &permute_indices_per_token, - &topk_weight, &topk_idx, &expert_idx_per_token); + &topk_weight, &topk_idx, &expert_idx_per_token, &dequant_scale); break; case paddle::DataType::FLOAT16: MoeDispatchKernel( input, gating_output, gating_correction_bias, w4a8_in_scale, moe_topk, group_moe, topk_only_mode, num_rows, hidden_size, expert_num, &permute_input, &tokens_expert_prefix_sum, &permute_indices_per_token, - &topk_weight, &topk_idx, &expert_idx_per_token); + &topk_weight, &topk_idx, &expert_idx_per_token, &dequant_scale); break; default: PD_THROW("Unsupported data type for MoeDispatchKernel"); @@ -220,7 +243,8 @@ std::vector MoeExpertDispatch( permute_indices_per_token, topk_weight, topk_idx, - expert_idx_per_token}; + expert_idx_per_token, + dequant_scale}; } std::vector> MoeExpertDispatchInferShape( @@ -311,7 +335,7 @@ PD_BUILD_STATIC_OP(moe_expert_dispatch) paddle::Optional("w4a8_in_scale")}) .Outputs({"permute_input", "tokens_expert_prefix_sum", "permute_indices_per_token", "topk_weight", "topk_idx", - "expert_idx_per_token"}) + "expert_idx_per_token", "dequant_scale"}) .Attrs({"moe_topk:int", "group_moe:bool", "moe_quant_type:std::string", "topk_only_mode:bool"}) .SetKernelFn(PD_KERNEL(MoeExpertDispatch)) .SetInferShapeFn(PD_INFER_SHAPE(MoeExpertDispatchInferShape)) diff --git a/custom_ops/gpu_ops/moe/moe_ffn.cu b/custom_ops/gpu_ops/moe/moe_ffn.cu index fedb8b2c97..66b2054cfb 100644 --- a/custom_ops/gpu_ops/moe/moe_ffn.cu +++ b/custom_ops/gpu_ops/moe/moe_ffn.cu @@ -27,6 +27,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input, const paddle::Tensor& tokens_expert_prefix_sum, const paddle::Tensor& up_gate_proj_weight, const paddle::Tensor& down_proj_weight, + const paddle::optional& up_proj_in_scale, const paddle::optional& up_gate_proj_bias, const paddle::optional& up_gate_proj_scale, const paddle::optional& down_proj_scale, @@ -180,12 +181,12 @@ void MoeFFNKernel(const paddle::Tensor& permute_input, typedef typename traits_fp8::data_t data_t_fp8; paddle::Tensor weight_scale_tensor = *const_cast(up_gate_proj_scale.get_ptr()); const int 
weight_scale_group_size = weight_scale_tensor.dims().size() == 2 ? hidden_size : weight_scale_tensor.dims()[3]; - float* row_scale = nullptr; + const float* input_dequant_scale = up_proj_in_scale ? up_proj_in_scale.get().data() : nullptr; DisPatchW4AFp8GemmWrapper( reinterpret_cast(permute_input.data()), reinterpret_cast(up_gate_proj_weight.data()), const_cast(tokens_expert_prefix_sum.data()), - row_scale, + input_dequant_scale, weight_scale_tensor.data(), reinterpret_cast(fc1_out), used_in_ep_low_latency ? num_max_tokens_per_expert : 0, @@ -304,7 +305,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input, } else if (quant_method == "w4afp8") { data_t *ffn2_shift = nullptr; data_t *ffn2_smooth = nullptr; - float* row_scale = nullptr; + float* input_dequant_scale = nullptr; Allocator::AllocationPtr fp8_act_out; fp8_act_out = allocator->Allocate( SizeOf(paddle::DataType::INT8) * act_out_tensor.numel()); @@ -335,7 +336,7 @@ void MoeFFNKernel(const paddle::Tensor& permute_input, reinterpret_cast(fp8_act_out->ptr()), reinterpret_cast(down_proj_weight.data()), const_cast(tokens_expert_prefix_sum.data()), - row_scale, + input_dequant_scale, weight_scale_tensor.data(), reinterpret_cast(ffn_out_data), used_in_ep_low_latency ? num_max_tokens_per_expert : 0, @@ -368,6 +369,7 @@ paddle::Tensor MoeExpertFFNFunc( const paddle::Tensor& tokens_expert_prefix_sum, const paddle::Tensor& up_gate_proj_weight, const paddle::Tensor& down_proj_weight, + const paddle::optional& up_proj_in_scale, const paddle::optional& up_gate_proj_bias, const paddle::optional& up_gate_proj_scale, const paddle::optional& down_proj_scale, @@ -389,6 +391,7 @@ const auto t_type = (quant_method == "w4a8") ? up_gate_proj_scale.get().dtype() tokens_expert_prefix_sum, up_gate_proj_weight, down_proj_weight, + up_proj_in_scale, up_gate_proj_bias, up_gate_proj_scale, down_proj_scale, @@ -405,6 +408,7 @@ const auto t_type = (quant_method == "w4a8") ? 
up_gate_proj_scale.get().dtype() tokens_expert_prefix_sum, up_gate_proj_weight, down_proj_weight, + up_proj_in_scale, up_gate_proj_bias, up_gate_proj_scale, down_proj_scale, @@ -427,6 +431,7 @@ std::vector MoeExpertFFN( const paddle::Tensor& tokens_expert_prefix_sum, const paddle::Tensor& up_gate_proj_weight, const paddle::Tensor& down_proj_weight, + const paddle::optional& up_proj_in_scale, const paddle::optional& up_gate_proj_bias, const paddle::optional& up_gate_proj_scale, const paddle::optional& down_proj_scale, @@ -439,6 +444,7 @@ std::vector MoeExpertFFN( tokens_expert_prefix_sum, up_gate_proj_weight, down_proj_weight, + up_proj_in_scale, up_gate_proj_bias, up_gate_proj_scale, down_proj_scale, @@ -455,6 +461,7 @@ std::vector> MoeExpertFFNInferShape( const std::vector& tokens_expert_prefix_sum_shape, const std::vector& up_gate_proj_weight_shape, const std::vector& down_proj_weight_shape, + const paddle::optional>& up_proj_in_scale_shape, const paddle::optional>& up_gate_proj_bias_shape, const paddle::optional>& up_gate_proj_scale_shape, const paddle::optional>& down_proj_scale_shape, @@ -472,6 +479,7 @@ std::vector MoeExpertFFNInferDtype( const paddle::DataType &tokens_expert_prefix_sum_dtype, const paddle::DataType &up_gate_proj_weight_dtype, const paddle::DataType &down_proj_weight_dtype, + const paddle::optional &up_proj_in_scale_dtype, const paddle::optional &up_gate_proj_bias_dtype, const paddle::optional &up_gate_proj_scale_dtype, const paddle::optional &down_proj_scale_dtype, @@ -545,6 +553,7 @@ PD_BUILD_STATIC_OP(moe_expert_ffn) "tokens_expert_prefix_sum", "up_gate_proj_weight", "down_proj_weight", + paddle::Optional("up_proj_in_scale"), paddle::Optional("up_gate_proj_bias"), paddle::Optional("up_gate_proj_scale"), paddle::Optional("down_proj_scale"), diff --git a/custom_ops/gpu_ops/w4afp8_gemm/mainloop_fwd.h b/custom_ops/gpu_ops/w4afp8_gemm/mainloop_fwd.h index 2050bf862f..b650371415 100644 --- a/custom_ops/gpu_ops/w4afp8_gemm/mainloop_fwd.h +++ b/custom_ops/gpu_ops/w4afp8_gemm/mainloop_fwd.h @@ -119,6 +119,7 @@ struct CollectiveMainloopFwd { LayoutT layout_C; const float *weight_scale; LayoutTScale layout_Scale; + const float *input_scale; const int64_t * tokens; }; @@ -131,6 +132,7 @@ struct CollectiveMainloopFwd { TMA_Scale tma_load_Scale; ElementOutput * ptr_C; const float *weight_scale; + const float *input_scale; const int64_t * tokens; }; @@ -162,7 +164,7 @@ struct CollectiveMainloopFwd { return { args.layout_A, args.layout_B, args.layout_Scale, tma_load_A, tma_load_B, tma_load_Scale, - args.ptr_C, args.weight_scale, args.tokens}; + args.ptr_C, args.weight_scale, args.input_scale, args.tokens}; } CUTLASS_DEVICE @@ -181,6 +183,7 @@ struct CollectiveMainloopFwd { SharedStorage& shared_storage, TiledMma tiled_mma, const float *weight_scale, + const float *input_scale, const int64_t tokens, const int64_t pre_fix_tokens, const int bidm, @@ -191,24 +194,48 @@ struct CollectiveMainloopFwd { using packHalf = typename PackedHalf::Type; Tensor tOrO_out = make_tensor(tOrO.layout()); - if constexpr (WeightScaleGroup == K) { - #pragma unroll - for (int i = 0; i < size(tOrO); i+=4) { - tOrO[i] = (tOrO[i]) * weight_scale[0]; - tOrO[i + 1] = tOrO[i + 1] * weight_scale[0]; - tOrO[i + 2] = tOrO[i + 2] * weight_scale[1]; - tOrO[i + 3] = tOrO[i + 3] * weight_scale[1]; - *reinterpret_cast(&tOrO_out[i]) = packHalf(tOrO[i], tOrO[i + 2]); - *reinterpret_cast(&tOrO_out[i + 2]) = packHalf(tOrO[i + 1], tOrO[i + 3]); + if (input_scale != nullptr) { + const int lane_id = tidx % 4 * 2; + if 
constexpr (WeightScaleGroup == K) { + #pragma unroll + for (int i = 0; i < size(tOrO); i+=4) { + const int scale_idx = i * 2 + lane_id; + tOrO[i] = tOrO[i] * weight_scale[0] * input_scale[scale_idx]; + tOrO[i + 1] = tOrO[i + 1] * weight_scale[0] * input_scale[scale_idx + 1]; + tOrO[i + 2] = tOrO[i + 2] * weight_scale[1] * input_scale[scale_idx]; + tOrO[i + 3] = tOrO[i + 3] * weight_scale[1] * input_scale[scale_idx + 1]; + *reinterpret_cast(&tOrO_out[i]) = packHalf(tOrO[i], tOrO[i + 2]); + *reinterpret_cast(&tOrO_out[i + 2]) = packHalf(tOrO[i + 1], tOrO[i + 3]); + } + } else { + #pragma unroll + for (int i = 0; i < size(tOrO); i+=4) { + const int scale_idx = i * 2 + lane_id; + *reinterpret_cast(&tOrO_out[i]) = packHalf(float(tOrO[i]) * input_scale[scale_idx], float(tOrO[i + 2]) * input_scale[scale_idx]); + *reinterpret_cast(&tOrO_out[i + 2]) = packHalf(float(tOrO[i + 1]) * input_scale[scale_idx + 1], float(tOrO[i + 3]) * input_scale[scale_idx + 1]); + } } } else { - #pragma unroll - for (int i = 0; i < size(tOrO); i+=4) { - *reinterpret_cast(&tOrO_out[i]) = packHalf(float(tOrO[i]), float(tOrO[i + 2])); - *reinterpret_cast(&tOrO_out[i + 2]) = packHalf(float(tOrO[i + 1]), float(tOrO[i + 3])); + if constexpr (WeightScaleGroup == K) { + #pragma unroll + for (int i = 0; i < size(tOrO); i+=4) { + tOrO[i] = (tOrO[i]) * weight_scale[0]; + tOrO[i + 1] = tOrO[i + 1] * weight_scale[0]; + tOrO[i + 2] = tOrO[i + 2] * weight_scale[1]; + tOrO[i + 3] = tOrO[i + 3] * weight_scale[1]; + *reinterpret_cast(&tOrO_out[i]) = packHalf(tOrO[i], tOrO[i + 2]); + *reinterpret_cast(&tOrO_out[i + 2]) = packHalf(tOrO[i + 1], tOrO[i + 3]); + } + } else { + #pragma unroll + for (int i = 0; i < size(tOrO); i+=4) { + *reinterpret_cast(&tOrO_out[i]) = packHalf(float(tOrO[i]), float(tOrO[i + 2])); + *reinterpret_cast(&tOrO_out[i + 2]) = packHalf(float(tOrO[i + 1]), float(tOrO[i + 3])); + } } } + uint16_t *smem_c = reinterpret_cast(shared_storage.smem_c.data()); uint32_t * reg_data = reinterpret_cast(tOrO_out.data()); diff --git a/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu b/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu index 1d307bbd7c..0b72b3454f 100644 --- a/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu +++ b/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.cu @@ -52,6 +52,7 @@ void DisPatchW4AFp8Gemm( const cutlass::float_e4m3_t* weight, const int64_t * tokens, const float * weight_scale, + const float * input_dequant_scale, OutputType * out, const int64_t token_padding_size, const int64_t max_tokens, @@ -69,6 +70,7 @@ void DisPatchW4AFp8Gemm( input, out, weight_scale, + input_dequant_scale, tokens, max_tokens, stream) @@ -82,6 +84,7 @@ std::vector W4AFp8Gemm( const paddle::Tensor& weight, const paddle::Tensor& tokens, // If tokenpadding=0, this tensor represents the prefix sum of tensors, otherwise it represents the number of tokens in each group const paddle::Tensor& weight_scale, + const paddle::Tensor& input_dequant_scale, const int64_t token_padding_size, const int64_t max_tokens, const bool is_bfloat16) { @@ -106,6 +109,7 @@ std::vector W4AFp8Gemm( reinterpret_cast(weight.data()), tokens.data(), weight_scale.data(), + input_dequant_scale.data(), reinterpret_cast(out_data), token_padding_size, max_tokens, @@ -127,6 +131,7 @@ std::vector W4AFp8Gemm( reinterpret_cast(weight.data()), tokens.data(), weight_scale.data(), + input_dequant_scale.data(), reinterpret_cast(out_data), token_padding_size, max_tokens, @@ -147,7 +152,7 @@ void DisPatchW4AFp8GemmWrapper( const InputType* input, const InputType* weight, const int64_t* 
total_rows_before_expert, - const float* row_scale, + const float* input_dequant_scale, const float* weight_scale, OutputType * out, const int64_t token_padding_size, @@ -164,6 +169,7 @@ void DisPatchW4AFp8GemmWrapper( reinterpret_cast(weight), total_rows_before_expert, weight_scale, + input_dequant_scale, reinterpret_cast(out), token_padding_size, max_tokens, @@ -200,7 +206,7 @@ template void DisPatchW4AFp8GemmWrapper<__nv_fp8_e4m3, __nv_bfloat16>( const __nv_fp8_e4m3* input, const __nv_fp8_e4m3* weight, const int64_t * tokens, - const float * row_scale, + const float * input_dequant_scale, const float * weight_scale, __nv_bfloat16 * out, const int64_t token_padding_size, @@ -216,7 +222,7 @@ template void DisPatchW4AFp8GemmWrapper<__nv_fp8_e4m3, half>( const __nv_fp8_e4m3* input, const __nv_fp8_e4m3* weight, const int64_t * tokens, - const float * row_scale, + const float * input_dequant_scale, const float * weight_scale, half * out, const int64_t token_padding_size, diff --git a/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.h b/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.h index 1c1db0e12c..b8c393ae1d 100644 --- a/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.h +++ b/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm.h @@ -25,6 +25,7 @@ std::vector W4AFp8Gemm( const paddle::Tensor& weight, const paddle::Tensor& tokens, // If tokenpadding=0, this tensor represents the prefix sum of tensors, otherwise it represents the number of tokens in each group const paddle::Tensor& weight_scale, + const paddle::Tensor& input_dequant_scale, const int64_t token_padding_size, const int64_t max_tokens, const bool is_bfloat16); @@ -34,7 +35,7 @@ void DisPatchW4AFp8GemmWrapper( const InputType* input, const InputType* weight, const int64_t * tokens, - const float * row_scale, + const float * input_dequant_scale, const float * weight_scale, OutputType * out, const int64_t token_padding_size, diff --git a/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_kernel.hpp b/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_kernel.hpp index 1c7250f6e1..425d648e61 100644 --- a/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_kernel.hpp +++ b/custom_ops/gpu_ops/w4afp8_gemm/w4afp8_gemm_kernel.hpp @@ -100,6 +100,10 @@ void __global__ __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp return; } + const bool is_need_input_scale = mainloop_params.input_scale != nullptr; + + float* input_scale = is_need_input_scale ? 
reinterpret_cast(shared_memory + sizeof(typename Ktraits::SharedStorage)) : nullptr; + if (warp_group_idx == 0) { cutlass::arch::warpgroup_reg_dealloc(); PipelineState smem_pipe_write = cutlass::make_producer_start_state(); @@ -122,6 +126,20 @@ void __global__ __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp const int mma_tidx = tidx - NumCopyThreads; + if (is_need_input_scale) { + if constexpr (TokenPackSize == 0) { + const int input_scale_idx = pre_fix_tokens + bidn * kBlockN; + if (mma_tidx < tokens) { + reinterpret_cast(input_scale)[mma_tidx] = reinterpret_cast(mainloop_params.input_scale + input_scale_idx)[mma_tidx]; + } + } else { + const int input_scale_idx = bidb * TokenPackSize + bidn * kBlockN; + if (mma_tidx < kBlockN / 4) { + reinterpret_cast(input_scale)[mma_tidx] = reinterpret_cast(mainloop_params.input_scale + input_scale_idx)[mma_tidx]; + } + } + } + float2 weight_scale; if constexpr (WeightScaleGroup == K) { @@ -156,6 +174,7 @@ void __global__ __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp shared_storage, tiled_mma, reinterpret_cast(&weight_scale), + input_scale, tokens, pre_fix_tokens, bidm, @@ -194,7 +213,7 @@ auto get_scale_layout(const int Rows, const int Cols) { template -void run_gemm(const InputType * A, const InputType * B, OutputType * C, const float *weight_scale, const int64_t * tokens, const int max_tokens, cudaStream_t stream) { +void run_gemm(const InputType * A, const InputType * B, OutputType * C, const float *weight_scale, const float * input_dequant_scale, const int64_t * tokens, const int max_tokens, cudaStream_t stream) { using ElementOutput = typename Kernel_traits::ElementOutput; using Element = typename Kernel_traits::Element; @@ -217,13 +236,14 @@ void run_gemm(const InputType * A, const InputType * B, OutputType * C, const fl get_gmem_layout(M, TokenPackSize == 0 ? 
max_tokens : TokenPackSize), weight_scale, get_scale_layout(M_nums, K_scale_nums * Kernel_traits::kBlockM), + input_dequant_scale, tokens }); void *kernel; kernel = (void *)w4afp8_gemm_kernel; - int smem_size = sizeof(typename Kernel_traits::SharedStorage); + int smem_size = sizeof(typename Kernel_traits::SharedStorage) + Kernel_traits::kBlockN * sizeof(float); if (smem_size >= 48 * 1024) { cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); diff --git a/custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py b/custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py index 9cf502236e..802942eecd 100644 --- a/custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py +++ b/custom_ops/utils/auto_gen_w4afp8_gemm_kernel.py @@ -37,6 +37,7 @@ const cutlass::float_e4m3_t * input, {cutlass_type} * out, const float *weight_scale, + const float * input_dequant_scale, const int64_t *tokens, const int64_t max_tokens, cudaStream_t stream); @@ -54,6 +55,7 @@ const cutlass::float_e4m3_t * input, {cutlass_type} * out, const float *weight_scale, + const float * input_dequant_scale, const int64_t *tokens, const int64_t max_tokens, cudaStream_t stream) {{ @@ -78,7 +80,7 @@ {cutlass_type}>; run_gemm - (weight, input, out, weight_scale, tokens, max_tokens, stream); + (weight, input, out, weight_scale, input_dequant_scale, tokens, max_tokens, stream); }} """ diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py index cc02bf7f18..d207bd690e 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py @@ -70,6 +70,7 @@ def compute_ffn( expert_idx_per_token: paddle.Tensor, used_in_ep_low_latency: bool = False, estimate_total_token_nums: int = -1, + dequant_scale: paddle.Tensor = None, ): """ Paddle Cutlass compute Fused MoE. 
@@ -93,6 +94,7 @@ def compute_ffn( token_nums_per_expert, getattr(layer, self.added_weight_attrs[0]), getattr(layer, self.added_weight_attrs[1]), + dequant_scale, None, (layer.up_gate_proj_weight_scale if hasattr(layer, "up_gate_proj_weight_scale") else None), (layer.down_proj_weight_scale if hasattr(layer, "down_proj_weight_scale") else None), @@ -267,13 +269,12 @@ def apply_tp( topk_weights, topk_idx, expert_idx_per_token, + dequant_scale, ) = moe_expert_dispatch( x, gate_out, layer.gate_correction_bias, - ( - layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None - ), # if set, permute_input will be int8_t + None, # if set, permute_input will be int8_t layer.top_k, False, self.moe_quant_type, @@ -287,7 +288,9 @@ def apply_tp( else: expert_idx_per_token = expert_idx_per_token.cast("int64") - ffn_out = self.compute_ffn(layer, permute_input, token_nums_per_expert, expert_idx_per_token) + ffn_out = self.compute_ffn( + layer, permute_input, token_nums_per_expert, expert_idx_per_token, False, -1, dequant_scale + ) # reduce 中会做 topk 个 weight 的 norm 和 routed_scaling_factor fused_moe_out = moe_expert_reduce( @@ -885,9 +888,12 @@ def _permute_weight_scale(weight_scale: paddle.Tensor): return weight_scale def _process_weight_scale(name: str, weight_scales: list[paddle.Tensor], processed_in_scale: paddle.Tensor): - processed_weight_scale = ( - paddle.stack(weight_scales, axis=0) / (448 * 7 * 2 ** (-9)) / processed_in_scale[:, None] - ) + if name == "up_gate_proj_weight_scale": + processed_weight_scale = paddle.stack(weight_scales, axis=0) / (448 * 7 * 2 ** (-9)) + else: + processed_weight_scale = ( + paddle.stack(weight_scales, axis=0) / (448 * 7 * 2 ** (-9)) / processed_in_scale[:, None] + ) if len(processed_weight_scale.shape) == 3: if name == "up_gate_proj_weight_scale" and processed_weight_scale.shape[-1] * 128 != layer.hidden_size: From d49c35ef79270cbe0589a1e0aa8329171061ac60 Mon Sep 17 00:00:00 2001 From: yangjianfengo1 Date: Mon, 29 Sep 2025 13:49:21 +0800 Subject: [PATCH 6/7] =?UTF-8?q?ffn2=20=E6=94=AF=E6=8C=81=E5=8A=A8=E6=80=81?= =?UTF-8?q?=E9=87=8F=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- custom_ops/gpu_ops/moe/fused_moe_op.h | 194 ++++++++++-------- custom_ops/gpu_ops/moe/moe_ffn.cu | 73 +++++-- .../layers/moe/fused_moe_cutlass_backend.py | 42 ++-- 3 files changed, 181 insertions(+), 128 deletions(-) diff --git a/custom_ops/gpu_ops/moe/fused_moe_op.h b/custom_ops/gpu_ops/moe/fused_moe_op.h index f8e755ef15..b3687ea9a7 100644 --- a/custom_ops/gpu_ops/moe/fused_moe_op.h +++ b/custom_ops/gpu_ops/moe/fused_moe_op.h @@ -150,101 +150,114 @@ __forceinline__ __device__ OutType QuantHelperFunc(const InType input, template __global__ void masked_quantize_moe_input_kernel(const T* permuted_inputs, -const int64_t* expert_idx_per_token, -const float* quant_scales, -const float quant_max_bound, -const float quant_min_bound, -const int64_t token_num, -const int64_t dim, -float* permuted_input_row_sum, -const int64_t* recv_expert_count, -const int num_max_tokens_per_expert, -OutT* out) { -using LoadT = AlignedVector; -using LoadOutT = AlignedVector; -LoadT input_vec; -LoadOutT output_vec; -float scale_factor = -7.0f / 512.0f; -using vec_t = typename BytesToType::Type; -for (int token_idx = blockIdx.x; token_idx < token_num; token_idx += gridDim.x) { - const auto token_idx_in_expert = token_idx % num_max_tokens_per_expert; - const auto expert_id = token_idx / num_max_tokens_per_expert; - if (token_idx_in_expert >= 
recv_expert_count[expert_id]) { - auto next_expert_start_idx = (expert_id + 1) * num_max_tokens_per_expert; - auto num_iters_to_next_expert = (next_expert_start_idx - token_idx - 1) / gridDim.x; - token_idx += num_iters_to_next_expert * gridDim.x; - continue; - } - int64_t expert_idx = expert_idx_per_token[token_idx]; - float quant_scale = quant_scales[expert_idx]; - float thread_row_sum = 0.0f; - for(int idx = threadIdx.x; idx < dim / VecSize; idx += blockDim.x) { - int64_t offset = token_idx * dim + idx * VecSize; - Load(&permuted_inputs[offset], &input_vec); - #pragma unroll - for (int i = 0; i < VecSize; i++) { - output_vec[i] = QuantHelperFunc(input_vec[i], quant_scale, quant_max_bound, quant_min_bound); - thread_row_sum += static_cast(output_vec[i]); + const int64_t* expert_idx_per_token, + const int64_t token_num, + const int64_t dim, + float* input_dequant_scale, + const int64_t* recv_expert_count, + const int num_max_tokens_per_expert, + OutT* out) { + using LoadT = AlignedVector; + using LoadOutT = AlignedVector; + LoadT input_vec; + LoadOutT output_vec; + using vec_t = typename BytesToType::Type; + extern __shared__ char smem_[]; + for (int token_idx = blockIdx.x; token_idx < token_num; token_idx += gridDim.x) { + const auto token_idx_in_expert = token_idx % num_max_tokens_per_expert; + const auto expert_id = token_idx / num_max_tokens_per_expert; + if (token_idx_in_expert >= recv_expert_count[expert_id]) { + auto next_expert_start_idx = (expert_id + 1) * num_max_tokens_per_expert; + auto num_iters_to_next_expert = (next_expert_start_idx - token_idx - 1) / gridDim.x; + token_idx += num_iters_to_next_expert * gridDim.x; + continue; + } + int64_t expert_idx = expert_idx_per_token[token_idx]; + float abs_max = 0.0f; + for(int idx = threadIdx.x; idx < dim / VecSize; idx += blockDim.x) { + int64_t offset = token_idx * dim + idx * VecSize; + #pragma unroll + for (int i = 0; i < VecSize; i++) { + float res = static_cast(input_vec[i]); + abs_max = fmax(abs_max, fabs(res)); + } + Store(input_vec, reinterpret_cast(smem_) + idx * VecSize); + } + abs_max = BlockAllReduce(abs_max); + input_dequant_scale[token_idx] = abs_max; + float quant_scale = 440.0f / abs_max; + for(int idx = threadIdx.x; idx < dim / VecSize; idx += blockDim.x) { + int64_t offset = token_idx * dim + idx * VecSize; + Load(reinterpret_cast(smem_) + idx * VecSize, &input_vec); + #pragma unroll + for (int i = 0; i < VecSize; i++) { + float res = static_cast(input_vec[i]); + output_vec[i] = static_cast(res * quant_scale); + } + *(reinterpret_cast(&out[offset])) = *(reinterpret_cast(&output_vec)); + } } - *(reinterpret_cast(&out[offset])) = *(reinterpret_cast(&output_vec)); - } - float block_row_sum = BlockAllReduce(thread_row_sum); - permuted_input_row_sum[token_idx] = block_row_sum * scale_factor; - } } template __global__ void quantize_moe_input_kernel(const T* permuted_inputs, -const int64_t* expert_idx_per_token, -const float* quant_scales, -const float quant_max_bound, -const float quant_min_bound, -const int64_t token_num, -const int64_t dim, -float* permuted_input_row_sum, -const int64_t* recv_expert_count, -const int num_max_tokens_per_expert, -OutT* out) { -using LoadT = AlignedVector; -using LoadOutT = AlignedVector; -LoadT input_vec; -LoadOutT output_vec; -using vec_t = typename BytesToType::Type; -float scale_factor = -7.0f / 512.0f; -for (int token_idx = blockIdx.x; token_idx < token_num; token_idx += gridDim.x) { - int64_t expert_idx = expert_idx_per_token[token_idx]; - float quant_scale = 
quant_scales[expert_idx]; - float thread_row_sum = 0.0f; - for(int idx = threadIdx.x; idx < dim / VecSize; idx += blockDim.x) { - int64_t offset = token_idx * dim + idx * VecSize; - Load(&permuted_inputs[offset], &input_vec); - #pragma unroll - for (int i = 0; i < VecSize; i++) { - output_vec[i] = QuantHelperFunc(input_vec[i], quant_scale, quant_max_bound, quant_min_bound); - thread_row_sum += static_cast(output_vec[i]); + const int64_t* expert_idx_per_token, + const int64_t token_num, + const int64_t dim, + float* input_dequant_scale, + const int64_t* recv_expert_count, + const int num_max_tokens_per_expert, + OutT* out) { + using LoadT = AlignedVector; + using LoadOutT = AlignedVector; + LoadT input_vec; + LoadOutT output_vec; + using vec_t = typename BytesToType::Type; + + extern __shared__ char smem_[]; + + for (int token_idx = blockIdx.x; token_idx < token_num; token_idx += gridDim.x) { + int64_t expert_idx = expert_idx_per_token[token_idx]; + float abs_max = 0.0f; + for(int idx = threadIdx.x; idx < dim / VecSize; idx += blockDim.x) { + int64_t offset = token_idx * dim + idx * VecSize; + Load(&permuted_inputs[offset], &input_vec); + #pragma unroll + for (int i = 0; i < VecSize; i++) { + float res = static_cast(input_vec[i]); + abs_max = fmax(abs_max, fabs(res)); + } + Store(input_vec, reinterpret_cast(smem_) + idx * VecSize); + } + abs_max = BlockAllReduce(abs_max); + input_dequant_scale[token_idx] = abs_max; + float quant_scale = 440.0f / abs_max; + + for(int idx = threadIdx.x; idx < dim / VecSize; idx += blockDim.x) { + int64_t offset = token_idx * dim + idx * VecSize; + Load(reinterpret_cast(smem_) + idx * VecSize, &input_vec); + #pragma unroll + for (int i = 0; i < VecSize; i++) { + float res = static_cast(input_vec[i]); + output_vec[i] = static_cast(res * quant_scale); + } + *(reinterpret_cast(&out[offset])) = *(reinterpret_cast(&output_vec)); + } } - *(reinterpret_cast(&out[offset])) = *(reinterpret_cast(&output_vec)); - } - float block_row_sum = BlockAllReduce(thread_row_sum); - permuted_input_row_sum[token_idx] = block_row_sum * scale_factor; - } } template void quantize_moe_input( - const T* permuted_inputs, - const int64_t* expert_idx_per_token, - const float* quant_scales, - const float quant_max_bound, - const float quant_min_bound, - const int64_t token_num, - const int64_t dim, - float* permuted_input_row_sum, - const int64_t* recv_expert_count, - const int num_max_tokens_per_expert, - bool used_in_ep_low_latency, - OutT* out, - cudaStream_t stream) { + const T* permuted_inputs, + const int64_t* expert_idx_per_token, + const int64_t token_num, + const int64_t dim, + float* input_quant_scale, + const int64_t* recv_expert_count, + const int num_max_tokens_per_expert, + bool used_in_ep_low_latency, + OutT* out, + cudaStream_t stream) { constexpr int VecSize = 16 / sizeof(T); constexpr int threads_per_block = 128; const int dev_id = 0; @@ -258,15 +271,16 @@ void quantize_moe_input( const int num_blocks_per_wave = sm_count * act_blocks_per_sm; dim3 grid; grid.x = min(static_cast(num_blocks_per_wave), token_num); - kernel<<>>( + const int smem_size = dim * sizeof(T); + if (smem_size >= 48 * 1024) { + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + } + kernel<<>>( permuted_inputs, expert_idx_per_token, - quant_scales, - quant_max_bound, - quant_min_bound, token_num, dim, - permuted_input_row_sum, + input_quant_scale, recv_expert_count, num_max_tokens_per_expert, out); diff --git a/custom_ops/gpu_ops/moe/moe_ffn.cu 
b/custom_ops/gpu_ops/moe/moe_ffn.cu index 66b2054cfb..fe48861bf7 100644 --- a/custom_ops/gpu_ops/moe/moe_ffn.cu +++ b/custom_ops/gpu_ops/moe/moe_ffn.cu @@ -310,24 +310,61 @@ void MoeFFNKernel(const paddle::Tensor& permute_input, fp8_act_out = allocator->Allocate( SizeOf(paddle::DataType::INT8) * act_out_tensor.numel()); - MoeFastHardamardWrapper( - act_out_tensor.data(), - expert_idx_per_token ? expert_idx_per_token.get().data() : nullptr, - const_cast(tokens_expert_prefix_sum.data()), - ffn2_shift, - ffn2_smooth, - down_proj_in_scale ? const_cast(down_proj_in_scale.get_ptr())->data() : nullptr, - 1, - 448.0f, - -448.0f, - expanded_active_expert_rows, - inter_size / 2, - num_max_tokens_per_expert, - used_in_ep_low_latency, - hadamard_block_size, - reinterpret_cast(fp8_act_out->ptr()), - stream - ); + if (down_proj_in_scale) { + MoeFastHardamardWrapper( + act_out_tensor.data(), + expert_idx_per_token ? expert_idx_per_token.get().data() : nullptr, + const_cast(tokens_expert_prefix_sum.data()), + ffn2_shift, + ffn2_smooth, + down_proj_in_scale ? const_cast(down_proj_in_scale.get_ptr())->data() : nullptr, + 1, + 448.0f, + -448.0f, + expanded_active_expert_rows, + inter_size / 2, + num_max_tokens_per_expert, + used_in_ep_low_latency, + hadamard_block_size, + reinterpret_cast(fp8_act_out->ptr()), + stream + ); + } else { + Allocator::AllocationPtr ffn2_input_dequant_scale; + ffn2_input_dequant_scale = allocator->Allocate( + sizeof(float) * expanded_active_expert_rows); + input_dequant_scale = reinterpret_cast(ffn2_input_dequant_scale->ptr()); + MoeFastHardamardWrapper( + act_out_tensor.data(), + expert_idx_per_token ? expert_idx_per_token.get().data() : nullptr, + const_cast(tokens_expert_prefix_sum.data()), + ffn2_shift, // ffn2_shift->data(), + ffn2_smooth, // ffn2_smooth->data(), + nullptr, + 1, + 448.0f, + -448.0f, + expanded_active_expert_rows, + inter_size / 2, + num_max_tokens_per_expert, + used_in_ep_low_latency, + hadamard_block_size, + act_out_tensor.data(), + stream + ); + + quantize_moe_input(act_out_tensor.data(), + expert_idx_per_token ? expert_idx_per_token.get().data() : nullptr, + expanded_active_expert_rows, + inter_size / 2, + input_dequant_scale, + const_cast(tokens_expert_prefix_sum.data()), + num_max_tokens_per_expert, + used_in_ep_low_latency, + reinterpret_cast(fp8_act_out->ptr()), + stream + ); + } paddle::Tensor weight_scale_tensor = *const_cast(down_proj_scale.get_ptr()); const int weight_scale_group_size = weight_scale_tensor.dims().size() == 2 ? inter_size / 2 : weight_scale_tensor.dims()[3]; diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py index d207bd690e..2325e31934 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py @@ -274,13 +274,16 @@ def apply_tp( x, gate_out, layer.gate_correction_bias, - None, # if set, permute_input will be int8_t + (layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None), layer.top_k, False, self.moe_quant_type, topk_only_mode=False, ) + if hasattr(layer, "up_gate_proj_in_scale"): + dequant_scale = None + if self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8": # only w4a8 need expert_idx_per_token # Other need not this tensor, so we make it None. 
@@ -821,16 +824,16 @@ def create_w4afp8_scale_weights(self, layer: nn.Layer, weight_key_map: dict): ) # in_scales - for in_scale_name in ["up_gate_proj_in_scale", "down_proj_in_scale"]: - setattr( - layer, - in_scale_name, - layer.create_parameter( - shape=[layer.num_local_experts], - dtype="float32", - default_initializer=paddle.nn.initializer.Constant(0), - ), - ) + # for in_scale_name in ["up_gate_proj_in_scale", "down_proj_in_scale"]: + # setattr( + # layer, + # in_scale_name, + # layer.create_parameter( + # shape=[layer.num_local_experts], + # dtype="float32", + # default_initializer=paddle.nn.initializer.Constant(0), + # ), + # ) # weight_scales setattr( @@ -888,12 +891,12 @@ def _permute_weight_scale(weight_scale: paddle.Tensor): return weight_scale def _process_weight_scale(name: str, weight_scales: list[paddle.Tensor], processed_in_scale: paddle.Tensor): - if name == "up_gate_proj_weight_scale": - processed_weight_scale = paddle.stack(weight_scales, axis=0) / (448 * 7 * 2 ** (-9)) - else: + if processed_in_scale is not None: processed_weight_scale = ( paddle.stack(weight_scales, axis=0) / (448 * 7 * 2 ** (-9)) / processed_in_scale[:, None] ) + else: + processed_weight_scale = paddle.stack(weight_scales, axis=0) / (448 * 7 * 2 ** (-9)) if len(processed_weight_scale.shape) == 3: if name == "up_gate_proj_weight_scale" and processed_weight_scale.shape[-1] * 128 != layer.hidden_size: @@ -980,16 +983,15 @@ def _process_weight_scale(name: str, weight_scales: list[paddle.Tensor], process scale_tensor = _extract_scale_tensor(layer, state_dict, scale_key_template, expert_idx) scale_weight_map[name].append(scale_tensor) - # 3. Process scale tensor and set to layer - in_scales = [] - for in_scale_name in ["up_gate_proj_in_scale", "down_proj_in_scale"]: - in_scales.append(_process_in_scale(in_scale_name, scale_weight_map[in_scale_name])) - for i, weight_scale_name in enumerate(["up_gate_proj_weight_scale", "down_proj_weight_scale"]): + in_scale_name = weight_scale_name.replace("_weight_scale", "_in_scale") + in_scale = None + if hasattr(layer, in_scale_name) and in_scale_name in scale_weight_map.keys(): + in_scale = _process_in_scale(in_scale_name, scale_weight_map[in_scale_name]) _process_weight_scale( weight_scale_name, scale_weight_map[weight_scale_name], - in_scales[i], + in_scale, ) From 44ca44ee126d3946fbbefcc98ce918aa93f23792 Mon Sep 17 00:00:00 2001 From: yangjianfengo1 Date: Mon, 29 Sep 2025 13:59:16 +0800 Subject: [PATCH 7/7] code style --- .../model_executor/layers/moe/fused_moe_cutlass_backend.py | 1 + fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py | 1 + 2 files changed, 2 insertions(+) diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py index 2325e31934..c36536f013 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py @@ -249,6 +249,7 @@ def apply_tp( topk_weights, topk_idx, expert_idx_per_token, + dequant_scale, ) = moe_expert_dispatch( x, gate_out, diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py index f9f717d313..f390749e47 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_wint2_backend.py @@ -275,6 +275,7 @@ def apply( topk_weights, topk_idx, expert_idx_per_token, + dequant_scale, ) = 
moe_expert_dispatch( x, gate_out,