@@ -536,8 +536,8 @@ void help()
     "- \"num_tokens\" - The total number of tokens to benchmark\n"
     "- \"bias\" - If bias should be used, 0 = no bias, 1 = bias\n"
     "- \"do_final_scale\" - If final scales should be applied, 0 = no scale, 1 = scale\n"
-    "- \"act_fn\" - The activation function to use, 0 = identity, 1 = relu, 2 = gelu, 3 = silu, 4 = geglu, 5 = "
-    "swiglu\n"
+    "- \"act_fn\" - The activation function to use, 1 = identity, 2 = gelu, 3 = relu, 4 = silu, 5 = swiglu, 6 = "
+    "geglu, 7 = swiglu_bias, 8 = relu2\n"
     "- \"tactic_id1, tactic_id2\"\n"
     "The config for the CUTLASS GEMM. tactic_idX sets the tactic for the corresponding GEMM"
     "Valid tactics are:\n"
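With the renumbering, 0 is no longer a valid act_fn value (it now maps to InvalidType in the updated enum), so anything consuming this flag should validate before casting. A minimal hedged sketch of such a check; parseActFn and the range-check policy are illustrative only and not part of this PR:

#include <cstdint>
#include <stdexcept>

// Mirrors the renumbered enum from common.h (values copied from this PR).
enum class ActivationType : int32_t
{
    InvalidType = 0,
    Identity = 1,
    Gelu = 2,
    Relu = 3,
    Silu = 4,
    Swiglu = 5,
    Geglu = 6,
    SwigluBias = 7,
    Relu2 = 8,
};

// Hypothetical helper: rejects 0 and out-of-range values before casting.
ActivationType parseActFn(int32_t v)
{
    if (v < static_cast<int32_t>(ActivationType::Identity) || v > static_cast<int32_t>(ActivationType::Relu2))
        throw std::invalid_argument("act_fn must be in [1, 8]");
    return static_cast<ActivationType>(v);
}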
55 changes: 38 additions & 17 deletions cpp/tensorrt_llm/kernels/cuteDslKernels/moeUtils.cu
@@ -31,6 +31,7 @@ namespace
 {
 using ElemCopyType = uint4;
 using SFCopyType = uint32_t;
+using ActivationType = tensorrt_llm::kernels::cutlass_kernels::ActivationType;
 
 template <typename T>
 auto constexpr bitsPerElem()
@@ -385,23 +386,43 @@ void moeActivation(InputType const* input, OutputType* output, float const* glob
     int32_t const blocks = std::min(smCount, max_num_permuted_tokens);
     int32_t const threads = kThreadsPerBlock;
 
-    auto kernel_array
-        = std::array{&moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
-              cutlass_kernels::IdentityAdaptor<cutlass::epilogue::thread::GELU>, kThreadsPerBlock>,
-            &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
-                cutlass_kernels::IdentityAdaptor<cutlass::epilogue::thread::ReLu>, kThreadsPerBlock>,
-            &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
-                cutlass_kernels::IdentityAdaptor<cutlass::epilogue::thread::SiLu>, kThreadsPerBlock>,
-            &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
-                cutlass_kernels::GLUAdaptor<cutlass::epilogue::thread::SiLu>, kThreadsPerBlock>,
-            &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
-                cutlass_kernels::GLUAdaptor<cutlass::epilogue::thread::GELU>, kThreadsPerBlock>,
-            &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize, cutlass_kernels::SwigluBiasAdaptor,
-                kThreadsPerBlock>,
-            &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
-                cutlass_kernels::IdentityAdaptor<cutlass::epilogue::thread::Identity>, kThreadsPerBlock>};
-
-    auto kernel = kernel_array[static_cast<int32_t>(activation_params.activation_type)];
+    auto get_act_kernel = [](ActivationType activation_type) -> void (*)(InputType const* input, OutputType* output,
+                              float const* global_sf, SFType* output_sf,
+                              int32_t const* tile_idx_to_mn_limit,
+                              int32_t const* num_non_exiting_tiles,
+                              int32_t const interm_size, int32_t const tile_size)
+    {
+        switch (activation_type)
+        {
+        case ActivationType::Identity:
+            return &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
+                cutlass_kernels::IdentityAdaptor<cutlass::epilogue::thread::Identity>, kThreadsPerBlock>;
+        case ActivationType::Gelu:
+            return &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
+                cutlass_kernels::IdentityAdaptor<cutlass::epilogue::thread::GELU>, kThreadsPerBlock>;
+        case ActivationType::Geglu:
+            return &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
+                cutlass_kernels::GLUAdaptor<cutlass::epilogue::thread::GELU>, kThreadsPerBlock>;
+        case ActivationType::Relu:
+            return &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
+                cutlass_kernels::IdentityAdaptor<cutlass::epilogue::thread::ReLu>, kThreadsPerBlock>;
+        case ActivationType::Silu:
+            return &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
+                cutlass_kernels::IdentityAdaptor<cutlass::epilogue::thread::SiLu>, kThreadsPerBlock>;
+        case ActivationType::Swiglu:
+            return &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize,
+                cutlass_kernels::GLUAdaptor<cutlass::epilogue::thread::SiLu>, kThreadsPerBlock>;
+        case ActivationType::SwigluBias:
+            return &moeActivationKernel<InputType, OutputType, SFType, kSFVecSize, cutlass_kernels::SwigluBiasAdaptor,
+                kThreadsPerBlock>;
+        case ActivationType::Relu2:
+            // Unsupported activation type
+            break;
+        }
+        TLLM_CHECK_WITH_INFO(false, "Unsupported activation type: %d", int(activation_type));
+        return nullptr;
+    };
+    auto kernel = get_act_kernel(activation_params.activation_type);
 
     cudaLaunchConfig_t config;
     config.gridDim = blocks;
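The switch above replaces indexing a std::array of kernel pointers by the raw enum value. Once the enum values are renumbered and no longer form a dense 0-based sequence, array indexing silently selects the wrong kernel (or reads out of bounds), whereas named cases either match or fail loudly. A standalone sketch of the pattern, with illustrative names and a deliberately non-contiguous enum:

#include <cmath>
#include <cstdio>

// Non-contiguous enum, as after this PR's renumbering (values illustrative).
enum class Act
{
    Invalid = 0,
    Relu = 3,
    Silu = 4,
};

using UnaryFn = float (*)(float);

static float relu(float x)
{
    return x > 0.f ? x : 0.f;
}

static float silu(float x)
{
    return x / (1.f + std::exp(-x));
}

// Switch dispatch: each enumerator is matched by name, so renumbering can never
// silently pick the wrong function the way fn_list[static_cast<int>(act)] could.
UnaryFn getActFn(Act act)
{
    switch (act)
    {
    case Act::Relu: return &relu;
    case Act::Silu: return &silu;
    case Act::Invalid: break;
    }
    std::fprintf(stderr, "unsupported activation: %d\n", static_cast<int>(act));
    return nullptr;
}

int main()
{
    UnaryFn fn = getActFn(Act::Silu);
    std::printf("silu(1.0) = %f\n", fn ? fn(1.0f) : 0.f);
    return 0;
}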
18 changes: 9 additions & 9 deletions cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h
@@ -23,15 +23,15 @@ namespace tensorrt_llm::kernels::cutlass_kernels
 // cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu::doActivationKernel().
 enum class ActivationType
 {
-    Gelu = 0,
-    Relu,
-    Silu,
-    Swiglu,
-    Geglu,
-    SwigluBias,
-    Identity,
-    Relu2,
-    InvalidType
+    InvalidType = 0,
+    Identity = 1,
+    Gelu = 2,
+    Relu = 3,
+    Silu = 4,
+    Swiglu = 5,
+    Geglu = 6,
+    SwigluBias = 7,
+    Relu2 = 8,
 };
 
 } // namespace tensorrt_llm::kernels::cutlass_kernels
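Two properties of the new numbering are worth spelling out: the explicit values keep the C++ enum, the Python IntEnum, and the benchmark's act_fn flag bitwise-identical across the language boundary, and InvalidType = 0 means a zero-initialized value is invalid rather than silently meaning Gelu (the old value 0). A small sketch of the second point; MoeConfig is a made-up stand-in, not a type from this PR:

#include <cassert>

enum class ActivationType
{
    InvalidType = 0,
    Identity = 1,
    Gelu = 2,
    // ... remaining values as in common.h
};

// Hypothetical config struct, just to illustrate the default.
struct MoeConfig
{
    ActivationType act{}; // value-initialization yields InvalidType, not Gelu
};

int main()
{
    MoeConfig cfg{};
    assert(cfg.act == ActivationType::InvalidType); // forgotten-field bugs now fail fast
    return 0;
}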
54 changes: 32 additions & 22 deletions cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu
@@ -2244,29 +2244,39 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8
 {
     // IMPORTANT: Keep the order of the activation functions in the same order as the ActivationType enum in
     // common.h
-    auto fn = [&](auto block_scaling_type)
+    auto fn
+        = [&](auto block_scaling_type) -> void (*)(T*, GemmOutputType const*, float const*, ScaleBiasType const*,
+            bool, int64_t const*, int, int64_t, float const*, bool,
+            TmaWarpSpecializedGroupedGemmInput::ElementSF*, ActivationParams)
     {
-        auto fn_list = std::array{
-            &doActivationKernel<T, GemmOutputType, ScaleBiasType, IdentityAdaptor<cutlass::epilogue::thread::GELU>,
-                decltype(block_scaling_type)::value>, // Gelu
-            &doActivationKernel<T, GemmOutputType, ScaleBiasType, IdentityAdaptor<cutlass::epilogue::thread::ReLu>,
-                decltype(block_scaling_type)::value>, // Relu
-            &doActivationKernel<T, GemmOutputType, ScaleBiasType, IdentityAdaptor<cutlass::epilogue::thread::SiLu>,
-                decltype(block_scaling_type)::value>, // Silu
-            &doActivationKernel<T, GemmOutputType, ScaleBiasType, GLUAdaptor<cutlass::epilogue::thread::SiLu>,
-                decltype(block_scaling_type)::value>, // Swiglu
-            &doActivationKernel<T, GemmOutputType, ScaleBiasType, GLUAdaptor<cutlass::epilogue::thread::GELU>,
-                decltype(block_scaling_type)::value>, // Geglu
-            &doActivationKernel<T, GemmOutputType, ScaleBiasType, SwigluBiasAdaptor,
-                decltype(block_scaling_type)::value>, // SwigluBias
-            &doActivationKernel<T, GemmOutputType, ScaleBiasType,
-                IdentityAdaptor<cutlass::epilogue::thread::Identity>,
-                decltype(block_scaling_type)::value>, // Identity
-            &doActivationKernel<T, GemmOutputType, ScaleBiasType, IdentityAdaptor<cutlass::epilogue::thread::Relu2>,
-                decltype(block_scaling_type)::value> // Relu2
-
-        };
-        return fn_list[static_cast<int>(activation_type.activation_type)];
+        switch (activation_type.activation_type)
+        {
+        case ActivationType::Identity:
+            return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                IdentityAdaptor<cutlass::epilogue::thread::Identity>, decltype(block_scaling_type)::value>;
+        case ActivationType::Gelu:
+            return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                IdentityAdaptor<cutlass::epilogue::thread::GELU>, decltype(block_scaling_type)::value>;
+        case ActivationType::Relu:
+            return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                IdentityAdaptor<cutlass::epilogue::thread::ReLu>, decltype(block_scaling_type)::value>;
+        case ActivationType::Silu:
+            return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                IdentityAdaptor<cutlass::epilogue::thread::SiLu>, decltype(block_scaling_type)::value>;
+        case ActivationType::Swiglu:
+            return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                GLUAdaptor<cutlass::epilogue::thread::SiLu>, decltype(block_scaling_type)::value>;
+        case ActivationType::Geglu:
+            return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                GLUAdaptor<cutlass::epilogue::thread::GELU>, decltype(block_scaling_type)::value>;
+        case ActivationType::SwigluBias:
+            return &doActivationKernel<T, GemmOutputType, ScaleBiasType, SwigluBiasAdaptor,
+                decltype(block_scaling_type)::value>;
+        case ActivationType::Relu2:
+            return &doActivationKernel<T, GemmOutputType, ScaleBiasType,
+                IdentityAdaptor<cutlass::epilogue::thread::Relu2>, decltype(block_scaling_type)::value>;
+        default: TLLM_CHECK_WITH_INFO(false, "Invalid activation type"); return nullptr;
+        }
     };
     auto NVFP4 = tensorrt_llm::common::ConstExprWrapper<TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType,
         TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NVFP4>{};
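Note the explicit trailing return type the new lambda gains here: with deduced return types, the nullptr returned after TLLM_CHECK_WITH_INFO would deduce std::nullptr_t, which conflicts with the function-pointer type deduced by the case labels, so the lambda would not compile. A minimal reproduction of the issue; KernelFn, kernelA, and pickKernel are illustrative names only:

using KernelFn = void (*)(float const*, float*, int);

static void kernelA(float const*, float*, int) {}

KernelFn pickKernel(bool supported)
{
    // Without the "-> KernelFn" annotation this lambda is ill-formed:
    // one return deduces KernelFn, the other deduces std::nullptr_t.
    auto get = [](bool ok) -> KernelFn
    {
        if (ok)
            return &kernelA;
        return nullptr; // fine only because the return type is spelled out
    };
    return get(supported);
}

int main()
{
    return pickKernel(true) == &kernelA ? 0 : 1;
}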
18 changes: 9 additions & 9 deletions tensorrt_llm/_torch/utils.py
@@ -34,15 +34,15 @@
 # IMPORTANT: Keep the same order of activation functions in this enum and the enum in
 # cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h
 class ActivationType(IntEnum):
-    Gelu = 0
-    Relu = 1
-    Silu = 2
-    Swiglu = 3
-    Geglu = 4
-    SwigluBias = 5
-    Identity = 6
-    Relu2 = 7
-    InvalidType = 8
+    InvalidType = 0
+    Identity = 1
+    Gelu = 2
+    Relu = 3
+    Silu = 4
+    Swiglu = 5
+    Geglu = 6
+    SwigluBias = 7
+    Relu2 = 8
 
 
 def set_torch_compiling(enable: bool):
18 changes: 10 additions & 8 deletions tensorrt_llm/layers/moe.py
@@ -20,6 +20,7 @@
 import tensorrt as trt
 import torch
 
+from tensorrt_llm._torch.utils import ActivationType
 from tensorrt_llm._utils import (get_init_params, str_dtype_to_torch,
                                  str_dtype_to_trt)
 from tensorrt_llm.layers.lora import LoraParams
@@ -49,14 +50,15 @@
 
 activation_str_to_int_map = {
     # [WARNING] Keep the below in sync with cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h
-    "gelu": 0,
-    "gelu_new": 0,
-    "relu": 1,
-    "silu": 2,
-    "swiglu": 3,
-    "geglu": 4,
-    "swiglu_bias": 5,
-    "identity": 6,
+    "gelu": int(ActivationType.Gelu),
+    "gelu_new": int(ActivationType.Gelu),
+    "relu": int(ActivationType.Relu),
+    "silu": int(ActivationType.Silu),
+    "swiglu": int(ActivationType.Swiglu),
+    "geglu": int(ActivationType.Geglu),
+    "swiglu_bias": int(ActivationType.SwigluBias),
+    "identity": int(ActivationType.Identity),
+    "relu2": int(ActivationType.Relu2),
 }
 
 