diff --git a/cpp/tensorrt_llm/CMakeLists.txt b/cpp/tensorrt_llm/CMakeLists.txt
index 7ebf07122f1..a9e4a007290 100644
--- a/cpp/tensorrt_llm/CMakeLists.txt
+++ b/cpp/tensorrt_llm/CMakeLists.txt
@@ -189,6 +189,7 @@ set(TRTLLM_LINK_LIBS
   fb_gemm_src
   gemm_swiglu_sm90_src
   cutlass_src
+  cute_dsl_src
   layers_src
   runtime_src
   testing_src
diff --git a/cpp/tensorrt_llm/kernels/CMakeLists.txt b/cpp/tensorrt_llm/kernels/CMakeLists.txt
index 74680318170..7cf669de18b 100644
--- a/cpp/tensorrt_llm/kernels/CMakeLists.txt
+++ b/cpp/tensorrt_llm/kernels/CMakeLists.txt
@@ -22,6 +22,8 @@ file(GLOB_RECURSE SRC_CU *.cu)
 # selectiveScan trtllmGenKernels folder
 list(FILTER SRC_CPP EXCLUDE REGEX "cutlass_kernels/.*")
 list(FILTER SRC_CU EXCLUDE REGEX "cutlass_kernels/.*")
+list(FILTER SRC_CPP EXCLUDE REGEX "cuteDslKernels/.*")
+list(FILTER SRC_CU EXCLUDE REGEX "cuteDslKernels/.*")
 list(FILTER SRC_CPP EXCLUDE REGEX "flashMLA/.*")
 list(FILTER SRC_CU EXCLUDE REGEX "flashMLA/.*")
 list(FILTER SRC_CPP EXCLUDE REGEX "contextFusedMultiHeadAttention/.*")
@@ -75,6 +77,7 @@ target_include_directories(
 add_cuda_architectures(kernels_src 89)
 
 add_subdirectory(cutlass_kernels)
+add_subdirectory(cuteDslKernels)
 add_subdirectory(flashMLA)
 add_subdirectory(contextFusedMultiHeadAttention)
 add_subdirectory(decoderMaskedMultiheadAttention)
diff --git a/cpp/tensorrt_llm/kernels/cuteDslKernels/CMakeLists.txt b/cpp/tensorrt_llm/kernels/cuteDslKernels/CMakeLists.txt
new file mode 100644
index 00000000000..a718c76c076
--- /dev/null
+++ b/cpp/tensorrt_llm/kernels/cuteDslKernels/CMakeLists.txt
@@ -0,0 +1,23 @@
+#
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
+# All rights reserved. SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+#
+
+file(GLOB_RECURSE SRC_CPP *.cpp)
+file(GLOB_RECURSE SRC_CU *.cu)
+
+add_library(cute_dsl_src OBJECT ${SRC_CPP} ${SRC_CU})
+set_property(TARGET cute_dsl_src PROPERTY POSITION_INDEPENDENT_CODE ON)
+set_property(TARGET cute_dsl_src PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
diff --git a/cpp/tensorrt_llm/kernels/cuteDslKernels/moeUtils.cu b/cpp/tensorrt_llm/kernels/cuteDslKernels/moeUtils.cu
new file mode 100644
index 00000000000..32a54662ff1
--- /dev/null
+++ b/cpp/tensorrt_llm/kernels/cuteDslKernels/moeUtils.cu
@@ -0,0 +1,439 @@
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "tensorrt_llm/common/assert.h"
+#include "tensorrt_llm/common/cudaUtils.h"
+#include "tensorrt_llm/common/envUtils.h"
+#include "tensorrt_llm/kernels/cuteDslKernels/moeUtils.h"
+#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cuh"
+#include "tensorrt_llm/kernels/quantization.cuh"
+#include "tensorrt_llm/kernels/quantization.h"
+
+#include <cstdint>
+#include <type_traits>
+
+namespace tensorrt_llm::kernels::cute_dsl
+{
+namespace
+{
+using ElemCopyType = uint4;
+using SFCopyType = uint32_t;
+
+template <typename T>
+auto constexpr bitsPerElem()
+{
+#ifdef ENABLE_FP4
+    return std::is_same_v<T, __nv_fp4_e2m1> ? 4 : cute::sizeof_bits_v<T>;
+#else
+    return cute::sizeof_bits_v<T>;
+#endif
+}
+
+template <typename InputType>
+auto constexpr elemPerCopy()
+{
+    return bitsPerElem<ElemCopyType>() / bitsPerElem<InputType>();
+}
+
+template <typename SFType>
+auto constexpr sfElemPerCopy()
+{
+    return bitsPerElem<SFCopyType>() / bitsPerElem<SFType>();
+}
+} // namespace
+
+template <typename InputType, typename SFType, int32_t kSFVecSize, int32_t kThreadsPerBlock>
+__global__ void moePermuteKernel(InputType const* input, InputType* permuted_output, SFType const* input_sf,
+    SFType* permuted_sf, int32_t const* tile_idx_to_mn_limit, int32_t const* permuted_idx_to_expanded_idx,
+    int32_t const* num_non_exiting_tiles, int32_t const hidden_size, int32_t const top_k, int32_t const tile_size)
+{
+    int32_t constexpr kElemPerCopy = elemPerCopy<InputType>();
+    int32_t constexpr kSFElemPerCopy = sfElemPerCopy<SFType>();
+    // Need int64_t to prevent overflow when computing pointer offsets.
+    int64_t const kCopyPerToken = hidden_size / kElemPerCopy;
+
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+    asm volatile("griddepcontrol.wait;");
+#endif
+
+    int32_t const num_tokens = num_non_exiting_tiles[0] * tile_size;
+    for (int32_t permuted_idx = blockIdx.x; permuted_idx < num_tokens; permuted_idx += gridDim.x)
+    {
+        int32_t const tile_idx = permuted_idx / tile_size;
+        if (permuted_idx >= tile_idx_to_mn_limit[tile_idx])
+        {
+            continue;
+        }
+        int32_t const expanded_idx = permuted_idx_to_expanded_idx[permuted_idx];
+        int32_t const token_idx = expanded_idx / top_k;
+
+        auto const* src_ptr = reinterpret_cast<ElemCopyType const*>(input) + token_idx * kCopyPerToken;
+        auto* dst_ptr = reinterpret_cast<ElemCopyType*>(permuted_output) + permuted_idx * kCopyPerToken;
+        for (int32_t i = threadIdx.x; i < kCopyPerToken; i += kThreadsPerBlock)
+        {
+            dst_ptr[i] = src_ptr[i];
+        }
+
+#ifdef ENABLE_FP4
+        if constexpr (std::is_same_v<InputType, __nv_fp4_e2m1>)
+        {
+            int32_t const sf_hidden_size = hidden_size / kSFVecSize;
+            int64_t const kSFCopyPerToken = sf_hidden_size / kSFElemPerCopy;
+            auto const* sf_src_ptr = reinterpret_cast<SFCopyType const*>(input_sf);
+            auto* sf_dst_ptr = reinterpret_cast<SFCopyType*>(permuted_sf);
+            for (int32_t i = threadIdx.x; i < kSFCopyPerToken; i += kThreadsPerBlock)
+            {
+                // input_sf is not swizzled, while permuted_sf is swizzled.
+                int64_t const src_offset = token_idx * kSFCopyPerToken + i;
+                int64_t const dst_offset = get_sf_out_offset_128x4(/* batchIdx= */ std::nullopt, permuted_idx,
+                                               i * kSFElemPerCopy, /* numRows= */ std::nullopt, sf_hidden_size)
+                    / kSFElemPerCopy;
+
+                sf_dst_ptr[dst_offset] = sf_src_ptr[src_offset];
+            }
+        }
+#endif
+    }
+
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+    asm volatile("griddepcontrol.launch_dependents;");
+#endif
+}
+
+template <typename InputType, typename SFType>
+void moePermute(InputType const* input, InputType* permuted_output, SFType const* input_sf, SFType* permuted_sf,
+    int32_t const* tile_idx_to_mn_limit, int32_t const* permuted_idx_to_expanded_idx,
+    int32_t const* num_non_exiting_tiles, int32_t const max_num_permuted_tokens, int32_t const hidden_size,
+    int32_t const top_k, int32_t const tile_size, cudaStream_t stream)
+{
+    int32_t constexpr kThreadsPerBlock = 256;
+    int32_t constexpr kSFVecSize = 16;
+    int32_t constexpr kElemPerCopy = elemPerCopy<InputType>();
+    TLLM_CHECK_WITH_INFO(hidden_size % kElemPerCopy == 0, "hidden_size must be divisible by %d.", kElemPerCopy);
+
+#ifdef ENABLE_FP4
+    if constexpr (std::is_same_v<InputType, __nv_fp4_e2m1>)
+    {
+        int32_t constexpr kSFMAlignment = 128;
+        int32_t constexpr kSFKAlignment = 4;
+        int32_t constexpr kSFElemPerCopy = sfElemPerCopy<SFType>();
+        static_assert(kSFElemPerCopy == kSFKAlignment);
+        TLLM_CHECK_WITH_INFO(max_num_permuted_tokens % kSFMAlignment == 0,
+            "max_num_permuted_tokens must be divisible by %d.", kSFMAlignment);
+        TLLM_CHECK_WITH_INFO(hidden_size % (kSFVecSize * kSFKAlignment) == 0, "hidden_size must be divisible by %d.",
+            kSFVecSize * kSFKAlignment);
+        TLLM_CHECK_WITH_INFO(input_sf != nullptr, "input_sf is required for NVFP4.");
+        TLLM_CHECK_WITH_INFO(permuted_sf != nullptr, "permuted_sf is required for NVFP4.");
+    }
+#endif
+
+    static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
+    int32_t const blocks = std::min(smCount, max_num_permuted_tokens);
+    int32_t const threads = kThreadsPerBlock;
+
+    auto kernel = &moePermuteKernel<InputType, SFType, kSFVecSize, kThreadsPerBlock>;
+
+    cudaLaunchConfig_t config;
+    config.gridDim = blocks;
+    config.blockDim = threads;
+    config.dynamicSmemBytes = 0;
+    config.stream = stream;
+    cudaLaunchAttribute attrs[1];
+    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
+    attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
+    config.numAttrs = 1;
+    config.attrs = attrs;
+    cudaLaunchKernelEx(&config, kernel, input, permuted_output, input_sf, permuted_sf, tile_idx_to_mn_limit,
+        permuted_idx_to_expanded_idx, num_non_exiting_tiles, hidden_size, top_k, tile_size);
+}
+
+#define INSTANTIATE_MOE_PERMUTE(InputType, SFType)                                                                    \
+    template void moePermute(InputType const* input, InputType* permuted_output,                                      \
+        SFType const* input_sf, SFType* permuted_sf, int32_t const* tile_idx_to_mn_limit,                             \
+        int32_t const* permuted_idx_to_expanded_idx, int32_t const* num_non_exiting_tiles,                            \
+        int32_t const max_num_permuted_tokens, int32_t const hidden_size, int32_t const top_k,                        \
+        int32_t const tile_size, cudaStream_t stream)
+
+INSTANTIATE_MOE_PERMUTE(half, uint8_t);
+#ifdef ENABLE_BF16
+INSTANTIATE_MOE_PERMUTE(__nv_bfloat16, uint8_t);
+#endif
+#ifdef ENABLE_FP8
+INSTANTIATE_MOE_PERMUTE(__nv_fp8_e4m3, uint8_t);
+#endif
+#ifdef ENABLE_FP4
+INSTANTIATE_MOE_PERMUTE(__nv_fp4_e2m1, uint8_t);
+#endif
+#undef INSTANTIATE_MOE_PERMUTE
+
+template <typename InputType, typename TopKScaleType, int32_t kThreadsPerBlock>
+__global__ void moeUnpermuteKernel(InputType const* permuted_input, InputType* output,
+    int32_t const* expanded_idx_to_permuted_idx, TopKScaleType const* topk_scales, int32_t const hidden_size,
+    int32_t const top_k)
+{
+    using AccumType = float;
+    int32_t constexpr kElemPerCopy = elemPerCopy<InputType>();
+    // Need int64_t to prevent overflow when computing pointer offsets.
+    int64_t const kCopyPerToken = hidden_size / kElemPerCopy;
+    InputType rmem[kElemPerCopy];
+    AccumType rmemAccum[kElemPerCopy];
+
+    int32_t const token_idx = blockIdx.x;
+
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+    asm volatile("griddepcontrol.wait;");
+#endif
+
+    auto* dst_ptr = reinterpret_cast<ElemCopyType*>(output) + token_idx * kCopyPerToken;
+    for (int32_t i = threadIdx.x; i < kCopyPerToken; i += kThreadsPerBlock)
+    {
+#pragma unroll
+        for (int32_t j = 0; j < kElemPerCopy; j++)
+        {
+            rmemAccum[j] = 0;
+        }
+        for (int32_t k = 0; k < top_k; k++)
+        {
+            int32_t const permuted_idx = expanded_idx_to_permuted_idx[token_idx * top_k + k];
+            if (permuted_idx < 0)
+            {
+                continue;
+            }
+            auto const* src_ptr = reinterpret_cast<ElemCopyType const*>(permuted_input) + permuted_idx * kCopyPerToken;
+            *reinterpret_cast<ElemCopyType*>(rmem) = src_ptr[i];
+            TopKScaleType const scale = topk_scales[token_idx * top_k + k];
+
+#pragma unroll
+            for (int32_t j = 0; j < kElemPerCopy; j++)
+            {
+                rmemAccum[j] += static_cast<AccumType>(rmem[j]) * static_cast<AccumType>(scale);
+            }
+        }
+#pragma unroll
+        for (int32_t j = 0; j < kElemPerCopy; j++)
+        {
+            rmem[j] = static_cast<InputType>(rmemAccum[j]);
+        }
+        dst_ptr[i] = *reinterpret_cast<ElemCopyType*>(rmem);
+    }
+
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+    asm volatile("griddepcontrol.launch_dependents;");
+#endif
+}
+
+template <typename InputType, typename TopKScaleType>
+void moeUnpermute(InputType const* permuted_input, InputType* output, int32_t const* expanded_idx_to_permuted_idx,
+    TopKScaleType const* topk_scales, int32_t const num_tokens, int32_t const hidden_size, int32_t const top_k,
+    cudaStream_t stream)
+{
+    int32_t constexpr kThreadsPerBlock = 256;
+    int32_t constexpr kElemPerCopy = elemPerCopy<InputType>();
+    TLLM_CHECK_WITH_INFO(hidden_size % kElemPerCopy == 0, "hidden_size must be divisible by %d.", kElemPerCopy);
+
+    int32_t const blocks = num_tokens;
+    int32_t const threads = kThreadsPerBlock;
+
+    auto kernel = &moeUnpermuteKernel<InputType, TopKScaleType, kThreadsPerBlock>;
+
+    cudaLaunchConfig_t config;
+    config.gridDim = blocks;
+    config.blockDim = threads;
+    config.dynamicSmemBytes = 0;
+    config.stream = stream;
+    cudaLaunchAttribute attrs[1];
+    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
+    attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
+    config.numAttrs = 1;
+    config.attrs = attrs;
+    cudaLaunchKernelEx(
+        &config, kernel, permuted_input, output, expanded_idx_to_permuted_idx, topk_scales, hidden_size, top_k);
+}
+
+#define INSTANTIATE_MOE_UNPERMUTE(InputType, TopKScaleType)                                                           \
+    template void moeUnpermute(InputType const* permuted_input, InputType* output,                                    \
+        int32_t const* expanded_idx_to_permuted_idx, TopKScaleType const* topk_scales, int32_t const num_tokens,      \
+        int32_t const hidden_size, int32_t const top_k, cudaStream_t stream)
+
+INSTANTIATE_MOE_UNPERMUTE(half, float);
+INSTANTIATE_MOE_UNPERMUTE(half, half);
+#ifdef ENABLE_BF16
+INSTANTIATE_MOE_UNPERMUTE(__nv_bfloat16, float);
+INSTANTIATE_MOE_UNPERMUTE(__nv_bfloat16, __nv_bfloat16);
+#endif
+#undef INSTANTIATE_MOE_UNPERMUTE
+
+template <typename InputType, typename OutputType, typename SFType, int32_t kSFVecSize, typename ActFn,
+    int32_t kThreadsPerBlock>
+__global__ void moeActivationKernel(InputType const* input, OutputType* output, float const* global_sf,
+    SFType* output_sf, int32_t const* tile_idx_to_mn_limit, int32_t const* num_non_exiting_tiles,
+    int32_t const interm_size, int32_t const tile_size)
+{
+    using ComputeType = float;
+#ifdef ENABLE_FP4
+    using ElemOutputCopyType = std::conditional_t<std::is_same_v<OutputType, __nv_fp4_e2m1>, uint32_t, ElemCopyType>;
+#else
+    using ElemOutputCopyType = ElemCopyType;
+#endif
+    int32_t constexpr kElemPerCopy = elemPerCopy<InputType>();
+    // Need int64_t to prevent overflow when computing pointer offsets.
+    int64_t const kCopyPerToken = interm_size / kElemPerCopy;
+    InputType rmem[kElemPerCopy];
+    InputType rmemGate[kElemPerCopy];
+    ActFn act{};
+
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+    asm volatile("griddepcontrol.wait;");
+#endif
+
+    float global_sf_val = global_sf == nullptr ? 1.0f : global_sf[0];
+
+    int32_t const num_tokens = num_non_exiting_tiles[0] * tile_size;
+    for (int32_t permuted_idx = blockIdx.x; permuted_idx < num_tokens; permuted_idx += gridDim.x)
+    {
+        int32_t const tile_idx = permuted_idx / tile_size;
+        if (permuted_idx >= tile_idx_to_mn_limit[tile_idx])
+        {
+            continue;
+        }
+        auto const* src_ptr
+            = reinterpret_cast<ElemCopyType const*>(input) + permuted_idx * kCopyPerToken * (ActFn::IS_GLU ? 2 : 1);
+        auto* dst_ptr = reinterpret_cast<ElemOutputCopyType*>(output) + permuted_idx * kCopyPerToken;
+        for (int32_t i = threadIdx.x; i < kCopyPerToken; i += kThreadsPerBlock)
+        {
+            *reinterpret_cast<ElemCopyType*>(rmem) = src_ptr[i];
+            if constexpr (ActFn::IS_GLU)
+            {
+                *reinterpret_cast<ElemCopyType*>(rmemGate) = src_ptr[i + kCopyPerToken];
+#pragma unroll
+                for (int32_t j = 0; j < kElemPerCopy; j++)
+                {
+                    rmem[j] = static_cast<InputType>(
+                        act(static_cast<ComputeType>(rmemGate[j]), static_cast<ComputeType>(rmem[j])));
+                }
+            }
+            else
+            {
+#pragma unroll
+                for (int32_t j = 0; j < kElemPerCopy; j++)
+                {
+                    rmem[j] = static_cast<InputType>(act(static_cast<ComputeType>(rmem[j])));
+                }
+            }
+
+#ifdef ENABLE_FP4
+            if constexpr (std::is_same_v<OutputType, __nv_fp4_e2m1>)
+            {
+                auto* sf_dst_ptr = cvt_quant_get_sf_out_offset<SFType, kSFVecSize / kElemPerCopy>(
+                    /* batchIdx= */ std::nullopt, permuted_idx, i, /*numRows=*/std::nullopt,
+                    interm_size / kSFVecSize, output_sf, QuantizationSFLayout::SWIZZLED);
+                dst_ptr[i] = cvt_warp_fp16_to_fp4<InputType, kSFVecSize, false>(
+                    *reinterpret_cast<PackedVec<InputType>*>(rmem), global_sf_val, sf_dst_ptr);
+            }
+            else
+#endif
+            {
+                dst_ptr[i] = *reinterpret_cast<ElemOutputCopyType*>(rmem);
+            }
+        }
+    }
+
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+    asm volatile("griddepcontrol.launch_dependents;");
+#endif
+}
+
+template <typename InputType, typename OutputType, typename SFType>
+void moeActivation(InputType const* input, OutputType* output, float const* global_sf, SFType* output_sf,
+    int32_t const* tile_idx_to_mn_limit, int32_t const* num_non_exiting_tiles,
+    cutlass_kernels::ActivationParams activation_params, int32_t const max_num_permuted_tokens,
+    int32_t const interm_size, int32_t const tile_size, cudaStream_t stream)
+{
+    int32_t constexpr kThreadsPerBlock = 256;
+    int32_t constexpr kSFVecSize = 16;
+    int32_t constexpr kElemPerCopy = elemPerCopy<InputType>();
+    TLLM_CHECK_WITH_INFO(interm_size % kElemPerCopy == 0, "interm_size must be divisible by %d.", kElemPerCopy);
+
+#ifdef ENABLE_FP4
+    if constexpr (std::is_same_v<OutputType, __nv_fp4_e2m1>)
+    {
+        int32_t constexpr kSFMAlignment = 128;
+        int32_t constexpr kSFKAlignment = 4;
+        TLLM_CHECK_WITH_INFO(max_num_permuted_tokens % kSFMAlignment == 0,
+            "max_num_permuted_tokens must be divisible by %d.", kSFMAlignment);
+        TLLM_CHECK_WITH_INFO(interm_size % (kSFVecSize * kSFKAlignment) == 0, "interm_size must be divisible by %d.",
+            kSFVecSize * kSFKAlignment);
+        TLLM_CHECK_WITH_INFO(global_sf != nullptr, "global_sf is required for NVFP4.");
+        TLLM_CHECK_WITH_INFO(output_sf != nullptr, "output_sf is required for NVFP4.");
+    }
+#endif
+
+    static int32_t const smCount = tensorrt_llm::common::getMultiProcessorCount();
+    int32_t const blocks = std::min(smCount, max_num_permuted_tokens);
+    int32_t const threads = kThreadsPerBlock;
+
+    // Entry order must match cutlass_kernels::ActivationType:
+    // Gelu, Relu, Silu, Swiglu, Geglu, SwigluBias, Identity.
+    auto kernel_array
+        = std::array{&moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
+                         cutlass_kernels::IdentityAdaptor<cutlass::epilogue::thread::GELU>, kThreadsPerBlock>,
+            &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
+                cutlass_kernels::IdentityAdaptor<cutlass::epilogue::thread::ReLu>, kThreadsPerBlock>,
+            &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
+                cutlass_kernels::IdentityAdaptor<cutlass::epilogue::thread::SiLu>, kThreadsPerBlock>,
+            &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
+                cutlass_kernels::GLUAdaptor<cutlass::epilogue::thread::SiLu>, kThreadsPerBlock>,
+            &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
+                cutlass_kernels::GLUAdaptor<cutlass::epilogue::thread::GELU>, kThreadsPerBlock>,
+            &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize, cutlass_kernels::SwigluBiasAdaptor,
+                kThreadsPerBlock>,
+            &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
+                cutlass_kernels::IdentityAdaptor<cutlass::epilogue::thread::Identity>, kThreadsPerBlock>};
+
+    auto kernel = kernel_array[static_cast<int32_t>(activation_params.activation_type)];
+
+    cudaLaunchConfig_t config;
+    config.gridDim = blocks;
+    config.blockDim = threads;
+    config.dynamicSmemBytes = 0;
+    config.stream = stream;
+    cudaLaunchAttribute attrs[1];
+    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
+    attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
+    config.numAttrs = 1;
+    config.attrs = attrs;
+    cudaLaunchKernelEx(&config, kernel, input, output, global_sf, output_sf, tile_idx_to_mn_limit,
+        num_non_exiting_tiles, interm_size, tile_size);
+}
+
+#define INSTANTIATE_MOE_ACTIVATION(InputType, OutputType, SFType)                                                     \
+    template void moeActivation(InputType const* input, OutputType* output,                                           \
+        float const* global_sf, SFType* output_sf, int32_t const* tile_idx_to_mn_limit,                               \
+        int32_t const* num_non_exiting_tiles, cutlass_kernels::ActivationParams activation_params,                    \
+        int32_t const max_num_permuted_tokens, int32_t const interm_size, int32_t const tile_size,                    \
+        cudaStream_t stream)
+
+INSTANTIATE_MOE_ACTIVATION(half, half, uint8_t);
+#ifdef ENABLE_BF16
+INSTANTIATE_MOE_ACTIVATION(__nv_bfloat16, __nv_bfloat16, uint8_t);
+#endif
+#ifdef ENABLE_FP4
+INSTANTIATE_MOE_ACTIVATION(half, __nv_fp4_e2m1, uint8_t);
+#ifdef ENABLE_BF16
+INSTANTIATE_MOE_ACTIVATION(__nv_bfloat16, __nv_fp4_e2m1, uint8_t);
+#endif
+#endif
+#undef INSTANTIATE_MOE_ACTIVATION
+
+} // namespace tensorrt_llm::kernels::cute_dsl
diff --git a/cpp/tensorrt_llm/kernels/cuteDslKernels/moeUtils.h b/cpp/tensorrt_llm/kernels/cuteDslKernels/moeUtils.h
new file mode 100644
index 00000000000..0659b4c78f6
--- /dev/null
+++ b/cpp/tensorrt_llm/kernels/cuteDslKernels/moeUtils.h
@@ -0,0 +1,41 @@
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+#include "tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h"
+#include <cstdint>
+#include <cuda_runtime_api.h>
+
+namespace tensorrt_llm::kernels::cute_dsl
+{
+template <typename InputType, typename SFType>
+void moePermute(InputType const* input, InputType* permuted_output, SFType const* input_sf, SFType* permuted_sf,
+    int32_t const* tile_idx_to_mn_limit, int32_t const* permuted_idx_to_expanded_idx,
+    int32_t const* num_non_exiting_tiles, int32_t const max_num_permuted_tokens, int32_t const hidden_size,
+    int32_t const top_k, int32_t const tile_size, cudaStream_t stream);
+
+template <typename InputType, typename TopKScaleType>
+void moeUnpermute(InputType const* permuted_input, InputType* output, int32_t const* expanded_idx_to_permuted_idx,
+    TopKScaleType const* topk_scales, int32_t const num_tokens, int32_t const hidden_size, int32_t const top_k,
+    cudaStream_t stream);
+
+template <typename InputType, typename OutputType, typename SFType>
+void moeActivation(InputType const* input, OutputType* output, float const* global_sf, SFType* output_sf,
+    int32_t const* tile_idx_to_mn_limit, int32_t const* num_non_exiting_tiles,
+    cutlass_kernels::ActivationParams activation_params, int32_t const max_num_permuted_tokens,
+    int32_t const interm_size, int32_t const tile_size, cudaStream_t stream);
+
+} // namespace tensorrt_llm::kernels::cute_dsl
diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu
index 3d9ee19d654..901ecbfff64 100644
--- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu
+++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu
@@ -37,7 +37,6 @@
 #include "cutlass/util/packed_stride.hpp"
 
 #include "cutlass/array.h"
-#include "cutlass/epilogue/thread/activation.h"
 #include "cutlass/numeric_conversion.h"
 #include "cutlass/numeric_types.h"
 
@@ -52,6 +51,7 @@
 #include "tensorrt_llm/common/dataType.h"
 #include "tensorrt_llm/common/envUtils.h"
 #include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"
+#include "tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cuh"
 #include "tensorrt_llm/kernels/moe_utils.cuh"
 #include "tensorrt_llm/kernels/preQuantScaleKernel.h"
 #include "tensorrt_llm/kernels/quantization.cuh"
@@ -1344,7 +1344,7 @@ __host__ __device__ constexpr static U arrayConvert(T const& input)
     return converter(input);
 }
 
-// Duplicated and permutes rows for MoE. In addition, reverse the permutation map to help with finalizing routing.
+// Duplicates and permutes rows for MoE.
 
 // "expanded_x_row" simply means that the number of values is num_rows x k. It is "expanded" since we will have to
 // duplicate some rows in the input matrix to match the dimensions. Duplicates will always get routed to separate
@@ -1937,56 +1937,6 @@ INSTANTIATE_FINALIZE_MOE_ROUTING(float, float, float);
 INSTANTIATE_FINALIZE_MOE_ROUTING(__nv_bfloat16, __nv_bfloat16, __nv_bfloat16);
 #endif
 
-// ============================== Activation Adaptors =================================
-template