Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
296 changes: 296 additions & 0 deletions csrc/trtllm_fused_moe_kernel_launcher.cu
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ enum class Fp8QuantizationType {
DeepSeekFp8,
MxFp8,
PerTensorFp8,
PerChannelFp8,
};

inline std::string fp8QuantizationTypeToString(Fp8QuantizationType quantization_type) {
Expand All @@ -59,6 +60,8 @@ inline std::string fp8QuantizationTypeToString(Fp8QuantizationType quantization_
return "MxFp8";
case Fp8QuantizationType::PerTensorFp8:
return "PerTensorFp8";
case Fp8QuantizationType::PerChannelFp8:
return "PerChannelFp8";
}
}

Expand Down Expand Up @@ -769,6 +772,206 @@ class Fp8PerTensorLauncher : public FusedMoeLauncher {
}
};

class Fp8PerChannelLauncher : public FusedMoeLauncher {
 public:
  // Token-tile sizes supported by the FP8 per-channel kernels.
  static constexpr std::array<int32_t, 5> mSupportedTileNums = {8, 16, 32, 64, 128};

  // Launcher for FP8 MoE with per-output-channel weight scales.
  // The three per-channel scale tensors are stored as non-owning views; the
  // caller must keep the underlying tensors alive while the launcher is used.
  Fp8PerChannelLauncher(TensorView const& routing_logits, Optional<TensorView> const& routing_bias,
                        TensorView const& hidden_states, TensorView const& gemm1_weights,
                        TensorView const& gemm1_per_channel_weight_scale,
                        TensorView const& gemm1_per_channel_gate_weight_scale,
                        TensorView const& gemm2_weights,
                        TensorView const& gemm2_per_channel_weight_scale)
      : FusedMoeLauncher(Optional<TensorView>(routing_logits), routing_bias, hidden_states,
                         gemm1_weights, Optional<TensorView>(), Optional<TensorView>(),
                         gemm2_weights, Optional<TensorView>()),
        // Members initialize in declaration order (use_routing_scales_on_input is
        // declared first); list them in that order to avoid -Wreorder warnings.
        use_routing_scales_on_input(false),
        gemm1_per_channel_weight_scale_(gemm1_per_channel_weight_scale),
        gemm1_per_channel_gate_weight_scale_(gemm1_per_channel_gate_weight_scale),
        gemm2_per_channel_weight_scale_(gemm2_per_channel_weight_scale) {}

void init(std::unique_ptr<tensorrt_llm::kernels::trtllmgen_moe::MoE::MoERunnerArgs>&& args,
int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight,
int64_t weight_layout, bool use_routing_scales_on_input_param,
ActivationType activation_type) {
this->use_routing_scales_on_input = use_routing_scales_on_input_param;

auto dtype = hidden_states.dtype();
if (dtype == dl_float16) {
mDtypeAct = btg::Dtype::Fp16;
} else if (dtype == dl_bfloat16) {
mDtypeAct = btg::Dtype::Bfloat16;
} else if (dtype == dl_float8_e4m3fn) {
mDtypeAct = btg::Dtype::E4m3;
} else {
TVM_FFI_LOG_AND_THROW(NotImplementedError)
<< "Unsupported input dtype for FP8 per-channel MoE.";
}
mDtypeWeights = btg::Dtype::E4m3;

FusedMoeLauncher::init_common(std::move(args), tile_tokens_dim, routing_method_type,
use_shuffled_weight, weight_layout, activation_type);
}

  // Routing validation: the per-channel FP8 path adds no constraints beyond the common checks.
  void check_routing() const override { FusedMoeLauncher::check_routing_common(); }

void prepare_routing() override {
FusedMoeLauncher::prepare_routing_common();

auto dtype = hidden_states.dtype();
if (dtype == dl_float16) {
args->mDtypeElt = btg::Dtype::Fp16;
} else if (dtype == dl_bfloat16) {
args->mDtypeElt = btg::Dtype::Bfloat16;
} else if (dtype == dl_float8_e4m3fn) {
args->mDtypeElt = btg::Dtype::E4m3;
} else {
TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported input dtype for MoE.";
}

args->mDtypeOut = btg::Dtype::Bfloat16;
args->mUseDeepSeekFp8 = false;

auto const routing_bias_dtype =
routing_bias.has_value() ? routing_bias.value().dtype() : dl_bfloat16;
mRoutingBiasDtype = routing_bias_dtype == dl_bfloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32;

expert_weights =
alloc_tensor({args->num_tokens, args->top_k}, dl_bfloat16, hidden_states.device());

workspace.expert_weights = expert_weights.data_ptr();
if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::Llama4) {
workspace.token_scales = expert_weights.data_ptr(); // Consumed by permuteGemm1 kernel
}
}

void check_moe() const override {
FusedMoeLauncher::check_moe_common();

TVM_FFI_ICHECK_EQ(gemm1_per_channel_weight_scale_.dtype(), dl_float32)
<< "gemm1_per_channel_weight_scale must be float32.";
TVM_FFI_ICHECK_EQ(gemm1_per_channel_weight_scale_.ndim(), 2)
<< "gemm1_per_channel_weight_scale must be 2D [local_num_experts, 2*intermediate_size].";
TVM_FFI_ICHECK_EQ(gemm1_per_channel_weight_scale_.size(0), args->local_num_experts)
<< "gemm1_per_channel_weight_scale dim 0 must match local_num_experts.";
TVM_FFI_ICHECK_EQ(gemm1_per_channel_weight_scale_.size(1), 2 * args->intermediate_size)
<< "gemm1_per_channel_weight_scale dim 1 must be 2*intermediate_size.";
Comment on lines +853 to +858
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The dimension checks and error messages for gemm1_per_channel_weight_scale_ hardcode the multiplier 2, which assumes a gated activation. It would be more robust to use the intermediate_size_factor member from the base class, which is correctly set based on the activation type. This will ensure correctness if non-gated activations are used with this launcher in the future.

    TVM_FFI_ICHECK_EQ(gemm1_per_channel_weight_scale_.ndim(), 2)
        << "gemm1_per_channel_weight_scale must be 2D [local_num_experts, "
           "intermediate_size_factor*intermediate_size].";
    TVM_FFI_ICHECK_EQ(gemm1_per_channel_weight_scale_.size(0), args->local_num_experts)
        << "gemm1_per_channel_weight_scale dim 0 must match local_num_experts.";
    TVM_FFI_ICHECK_EQ(gemm1_per_channel_weight_scale_.size(1),
                      intermediate_size_factor * args->intermediate_size)
        << "gemm1_per_channel_weight_scale dim 1 must match intermediate_size_factor * "
           "intermediate_size.";


TVM_FFI_ICHECK_EQ(gemm1_per_channel_gate_weight_scale_.dtype(), dl_float32)
<< "gemm1_per_channel_gate_weight_scale must be float32.";
TVM_FFI_ICHECK_EQ(gemm1_per_channel_gate_weight_scale_.ndim(), 2)
<< "gemm1_per_channel_gate_weight_scale must be 2D [local_num_experts, "
"2*intermediate_size].";
TVM_FFI_ICHECK_EQ(gemm1_per_channel_gate_weight_scale_.size(0), args->local_num_experts)
<< "gemm1_per_channel_gate_weight_scale dim 0 must match local_num_experts.";
TVM_FFI_ICHECK_EQ(gemm1_per_channel_gate_weight_scale_.size(1), 2 * args->intermediate_size)
<< "gemm1_per_channel_gate_weight_scale dim 1 must be 2*intermediate_size.";
Comment on lines +862 to +868
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Similar to the check for gemm1_per_channel_weight_scale_, the dimension checks and error messages for gemm1_per_channel_gate_weight_scale_ hardcode the multiplier 2. Using intermediate_size_factor would make this more robust and consistent.

    TVM_FFI_ICHECK_EQ(gemm1_per_channel_gate_weight_scale_.ndim(), 2)
        << "gemm1_per_channel_gate_weight_scale must be 2D [local_num_experts, "
           "intermediate_size_factor*intermediate_size].";
    TVM_FFI_ICHECK_EQ(gemm1_per_channel_gate_weight_scale_.size(0), args->local_num_experts)
        << "gemm1_per_channel_gate_weight_scale dim 0 must match local_num_experts.";
    TVM_FFI_ICHECK_EQ(gemm1_per_channel_gate_weight_scale_.size(1),
                      intermediate_size_factor * args->intermediate_size)
        << "gemm1_per_channel_gate_weight_scale dim 1 must match intermediate_size_factor * "
           "intermediate_size.";


TVM_FFI_ICHECK_EQ(gemm2_per_channel_weight_scale_.dtype(), dl_float32)
<< "gemm2_per_channel_weight_scale must be float32.";
TVM_FFI_ICHECK_EQ(gemm2_per_channel_weight_scale_.ndim(), 2)
<< "gemm2_per_channel_weight_scale must be 2D [local_num_experts, hidden_size].";
TVM_FFI_ICHECK_EQ(gemm2_per_channel_weight_scale_.size(0), args->local_num_experts)
<< "gemm2_per_channel_weight_scale dim 0 must match local_num_experts.";
TVM_FFI_ICHECK_EQ(gemm2_per_channel_weight_scale_.size(1), args->hidden_size)
<< "gemm2_per_channel_weight_scale dim 1 must match hidden_size.";

TVM_FFI_ICHECK(hidden_states.dtype() == dl_float8_e4m3fn ||
hidden_states.dtype() == dl_float16 || hidden_states.dtype() == dl_bfloat16)
<< "FP8 per-channel MoE: hidden_states must be float8_e4m3fn, float16, or bfloat16.";
TVM_FFI_ICHECK_EQ(gemm1_weights.dtype(), dl_float8_e4m3fn)
<< "FP8 per-channel MoE: gemm1_weights must be float8_e4m3fn.";
TVM_FFI_ICHECK_EQ(gemm2_weights.dtype(), dl_float8_e4m3fn)
<< "FP8 per-channel MoE: gemm2_weights must be float8_e4m3fn.";
}

  // Allocate intermediate buffers and wire up the workspace/args pointers for
  // the selected MoE tactic. Must run after routing preparation, since it reads
  // workspace.total_max_padded_tokens.
  void prepare_moe(int64_t& moe_tactic) override {
    FusedMoeLauncher::prepare_moe_common(moe_tactic);

    // GEMM1 output may carry up to num_experts extra padded rows; GEMM2 does not.
    int32_t max_num_padded_tokens_gemm1 = workspace.total_max_padded_tokens + args->num_experts;
    int32_t max_num_padded_tokens_gemm2 = workspace.total_max_padded_tokens;

    // dl_uint8 is used as raw byte storage for FP8 outputs.
    // NOTE(review): the 2 * intermediate_size width assumes a gated activation,
    // and 128 looks like the quantization group width — confirm against the kernel.
    gemm1_output = alloc_tensor({max_num_padded_tokens_gemm1, 2 * args->intermediate_size},
                                dl_uint8, hidden_states.device());
    gemm1_output_scale =
        alloc_tensor({2 * args->intermediate_size / 128, max_num_padded_tokens_gemm1}, dl_float32,
                     hidden_states.device());

    // Post-activation buffer: gating halves the channel count back to intermediate_size.
    activation_output = alloc_tensor({max_num_padded_tokens_gemm1, args->intermediate_size},
                                     dl_uint8, hidden_states.device());
    activation_output_scale =
        alloc_tensor({args->intermediate_size / 128, max_num_padded_tokens_gemm1}, dl_float32,
                     hidden_states.device());

    // Final GEMM2 output is bfloat16 (see prepare_routing: mDtypeOut = Bfloat16).
    gemm2_output = alloc_tensor({max_num_padded_tokens_gemm2, args->hidden_size}, dl_bfloat16,
                                hidden_states.device());

    // Per-channel path: no linear hidden-states scale and no gemm2 output scale.
    workspace.hidden_states_scale_linear = nullptr;
    workspace.gemm1_output = gemm1_output.data_ptr();
    workspace.gemm1_output_scale = static_cast<float*>(gemm1_output_scale.data_ptr());
    workspace.activation_output = activation_output.data_ptr();
    workspace.activation_output_scale = static_cast<float*>(activation_output_scale.data_ptr());
    workspace.gemm2_output = gemm2_output.data_ptr();
    workspace.gemm2_output_scale = nullptr;

    // Allocate an output tensor only if the caller did not provide one.
    if (args->output == nullptr) {
      output =
          alloc_tensor({args->num_tokens, args->hidden_size}, dl_bfloat16, hidden_states.device());
      args->output = output.data_ptr();
    }
    args->output_scale = nullptr;

    // Scalar (per-tensor) scale factors are unused on the per-channel path.
    args->output1_scales_scalar = nullptr;
    args->output1_scales_gate_scalar = nullptr;
    args->output2_scales_scalar = nullptr;

    // Hand the per-channel scale pointers to the runner.
    args->gemm1_per_channel_weight_scale =
        static_cast<float*>(gemm1_per_channel_weight_scale_.data_ptr());
    args->gemm1_per_channel_gate_weight_scale =
        static_cast<float*>(gemm1_per_channel_gate_weight_scale_.data_ptr());
    args->gemm2_per_channel_weight_scale =
        static_cast<float*>(gemm2_per_channel_weight_scale_.data_ptr());
  }

 private:
  // Whether routing scales are applied on the input; set in init().
  bool use_routing_scales_on_input;
  // Non-owning views of caller-provided fp32 per-channel weight scales.
  TensorView gemm1_per_channel_weight_scale_;
  TensorView gemm1_per_channel_gate_weight_scale_;
  TensorView gemm2_per_channel_weight_scale_;
  // Owned fp32 scale buffers for intermediate FP8 outputs (allocated in prepare_moe()).
  Tensor gemm1_output_scale;
  Tensor activation_output_scale;
public:
static Array<Array<int64_t>> getValidConfigs(int64_t top_k, int64_t hidden_size,
int64_t intermediate_size, int64_t num_local_experts,
int64_t num_tokens, int64_t act_type,
bool use_shuffled_weight, int64_t weight_layout,
btg::Dtype dtype_act, btg::Dtype dtype_weights) {
Array<Array<int64_t>> valid_configs;

std::vector<int32_t> supported_tile_nums(mSupportedTileNums.begin(), mSupportedTileNums.end());
std::set<int32_t> selected_tile_nums =
computeSelectedTileN(supported_tile_nums, num_tokens, top_k, num_local_experts);

for (int32_t tile_N : selected_tile_nums) {
auto moe_runner = std::make_unique<tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner>(
dtype_act, dtype_weights,
false, // useDeepSeekFp8
tile_N, static_cast<ActivationType>(act_type), use_shuffled_weight,
static_cast<batchedGemm::gemm::MatrixLayout>(weight_layout));

auto cfgs = moe_runner->getValidConfigIndices(top_k, hidden_size, intermediate_size,
num_local_experts, num_tokens);

for (auto cfg : cfgs) {
valid_configs.push_back({tile_N, cfg});
}
}

return valid_configs;
}
};

class Fp8BlockScaleLauncher : public FusedMoeLauncher {
public:
static constexpr std::array<int32_t, 5> mBaseSupportedTileNums = {8, 16, 32, 64, 128};
Expand Down Expand Up @@ -1774,6 +1977,93 @@ Array<Tensor> trtllm_fp8_per_tensor_scale_moe(
return selected_launcher->run(config, enable_pdl, use_routing_scales_on_input);
}

// FP8 MoE with per-channel weight scales: validates inputs, builds one launcher
// per feasible token-tile size, then runs the launcher selected by
// config_index = [tile_N, config].
Array<Tensor> trtllm_fp8_per_channel_scale_moe(
    TensorView routing_logits, Optional<TensorView> routing_bias, TensorView hidden_states,
    TensorView gemm1_weights, TensorView gemm1_per_channel_weight_scale,
    TensorView gemm1_per_channel_gate_weight_scale, TensorView gemm2_weights,
    TensorView gemm2_per_channel_weight_scale, TensorView output, int64_t num_experts,
    int64_t top_k, Optional<int64_t> n_group, Optional<int64_t> topk_group,
    int64_t intermediate_size, int64_t local_expert_offset, int64_t local_num_experts,
    Optional<double> routed_scaling_factor, bool use_routing_scales_on_input,
    int64_t routing_method_type, bool do_finalize, bool enable_pdl, Array<int64_t> config_index,
    int64_t activation_type) {
  // Basic type validation.
  auto dtype = hidden_states.dtype();
  auto activation = static_cast<ActivationType>(activation_type);
  if (use_routing_scales_on_input) {
    TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_bfloat16) << "routing_logits must be bfloat16.";
  } else if (static_cast<RoutingMethodType>(routing_method_type) == RoutingMethodType::DeepSeekV3) {
    TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_float32) << "routing_logits must be float.";
  } else {
    TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_bfloat16) << "routing_logits must be bfloat16.";
  }
  TVM_FFI_ICHECK(dtype == dl_float8_e4m3fn || dtype == dl_float16 || dtype == dl_bfloat16)
      << "FP8 per-channel MoE: hidden_states must be float8_e4m3fn, float16, or bfloat16.";
  TVM_FFI_ICHECK_EQ(gemm1_weights.dtype(), dl_float8_e4m3fn)
      << "FP8 per-channel MoE: gemm1_weights must be float8_e4m3fn.";
  TVM_FFI_ICHECK_EQ(gemm2_weights.dtype(), dl_float8_e4m3fn)
      << "FP8 per-channel MoE: gemm2_weights must be float8_e4m3fn.";
  TVM_FFI_ICHECK_EQ(gemm1_per_channel_weight_scale.dtype(), dl_float32)
      << "FP8 per-channel MoE: gemm1_per_channel_weight_scale must be float32.";
  TVM_FFI_ICHECK_EQ(gemm1_per_channel_gate_weight_scale.dtype(), dl_float32)
      << "FP8 per-channel MoE: gemm1_per_channel_gate_weight_scale must be float32.";
  TVM_FFI_ICHECK_EQ(gemm2_per_channel_weight_scale.dtype(), dl_float32)
      << "FP8 per-channel MoE: gemm2_per_channel_weight_scale must be float32.";

  auto const num_tokens = hidden_states.size(0);
  auto const hidden_size = hidden_states.size(1);

  // NOTE(review): the per-channel path currently fixes shuffled weights and the
  // MajorK layout — confirm this matches how the weights were preprocessed.
  bool use_shuffled_weight = true;
  int64_t weight_layout = 0;  // MajorK

  // Calculate the tile sizes feasible for this problem shape.
  std::vector<int32_t> mSupportedTileN(Fp8PerChannelLauncher::mSupportedTileNums.begin(),
                                       Fp8PerChannelLauncher::mSupportedTileNums.end());
  std::set<int32_t> selected_tile_nums =
      computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts);

  // Create a fully initialized launcher for each feasible tile size.
  std::unordered_map<int32_t, std::unique_ptr<Fp8PerChannelLauncher>> launchers_map;

  for (int32_t curr_tile_N : selected_tile_nums) {
    auto args = std::make_unique<tensorrt_llm::kernels::trtllmgen_moe::MoE::MoERunnerArgs>();
    args->num_tokens = num_tokens;
    args->num_experts = num_experts;
    args->hidden_size = hidden_size;
    args->hidden_size_output = args->hidden_size;
    args->top_k = top_k;
    args->n_group = n_group.value_or(0);
    args->topk_group = topk_group.value_or(0);
    args->local_expert_offset = local_expert_offset;
    args->local_num_experts = local_num_experts;
    args->intermediate_size = intermediate_size;
    args->routed_scaling_factor = routed_scaling_factor.value_or(1.0);
    args->do_finalize = do_finalize;
    args->output = output.data_ptr();
    args->output_scale = nullptr;

    auto launcher = std::make_unique<Fp8PerChannelLauncher>(
        routing_logits, routing_bias, hidden_states, gemm1_weights, gemm1_per_channel_weight_scale,
        gemm1_per_channel_gate_weight_scale, gemm2_weights, gemm2_per_channel_weight_scale);
    launcher->init(std::move(args), curr_tile_N, routing_method_type, use_shuffled_weight,
                   weight_layout, use_routing_scales_on_input, activation);

    launchers_map[curr_tile_N] = std::move(launcher);
  }

  // Extract tile_N and config from config_index.
  int64_t tile_N = config_index[0];
  int64_t config = config_index[1];

  // -1 entries are the "unset" sentinel: fall back to the smallest feasible tile.
  // (config == -1 is presumably resolved inside run() — confirm with the runner.)
  if (tile_N == -1 || config == -1) {
    tile_N = *selected_tile_nums.begin();
  }

  // Fail with a clear error instead of an opaque std::out_of_range if the caller
  // passes a tile_N that was not selected for this problem shape.
  TVM_FFI_ICHECK(launchers_map.count(tile_N) == 1)
      << "FP8 per-channel MoE: tile_N " << tile_N
      << " from config_index is not valid for this problem shape.";
  auto& selected_launcher = launchers_map.at(tile_N);

  return selected_launcher->run(config, enable_pdl, use_routing_scales_on_input);
}

Array<Tensor> trtllm_fp8_block_scale_moe(
Optional<TensorView> routing_logits, TensorView expert_indices, TensorView expert_weights,
Optional<TensorView> routing_bias, TensorView hidden_states, TensorView hidden_states_scale,
Expand Down Expand Up @@ -2155,6 +2445,11 @@ Array<Array<int64_t>> trtllm_get_valid_moe_configs(
return Fp8BlockScaleLauncher::getValidConfigs(
top_k, hidden_size, intermediate_size, num_local_experts, num_tokens, use_shuffled_weight,
weight_layout, dtype_weights, quantization_type);
} else if (quantization_type == Fp8QuantizationType::PerChannelFp8) {
// FP8 per-channel scale
return Fp8PerChannelLauncher::getValidConfigs(
top_k, hidden_size, intermediate_size, num_local_experts, num_tokens, act_type,
use_shuffled_weight, weight_layout, dtype_act, dtype_weights);
} else {
// FP8 per-tensor scale
return Fp8PerTensorLauncher::getValidConfigs(
Expand Down Expand Up @@ -2188,6 +2483,7 @@ namespace trtllm_cubin_loader {

TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_bf16_moe, trtllm_bf16_moe);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_fp8_per_tensor_scale_moe, trtllm_fp8_per_tensor_scale_moe);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_fp8_per_channel_scale_moe, trtllm_fp8_per_channel_scale_moe);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_fp8_block_scale_moe, trtllm_fp8_block_scale_moe);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_fp4_block_scale_moe, trtllm_fp4_block_scale_moe);
TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_mxint4_block_scale_moe, trtllm_mxint4_block_scale_moe);
Expand Down
Loading
Loading