diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index f3957e9717..8c6aaecd18 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -46,6 +46,7 @@ enum class Fp8QuantizationType { DeepSeekFp8, MxFp8, PerTensorFp8, + PerChannelFp8, }; inline std::string fp8QuantizationTypeToString(Fp8QuantizationType quantization_type) { @@ -59,6 +60,8 @@ inline std::string fp8QuantizationTypeToString(Fp8QuantizationType quantization_ return "MxFp8"; case Fp8QuantizationType::PerTensorFp8: return "PerTensorFp8"; + case Fp8QuantizationType::PerChannelFp8: + return "PerChannelFp8"; } } @@ -769,6 +772,206 @@ class Fp8PerTensorLauncher : public FusedMoeLauncher { } }; +class Fp8PerChannelLauncher : public FusedMoeLauncher { + public: + static constexpr std::array mSupportedTileNums = {8, 16, 32, 64, 128}; + + Fp8PerChannelLauncher(TensorView const& routing_logits, Optional const& routing_bias, + TensorView const& hidden_states, TensorView const& gemm1_weights, + TensorView const& gemm1_per_channel_weight_scale, + TensorView const& gemm1_per_channel_gate_weight_scale, + TensorView const& gemm2_weights, + TensorView const& gemm2_per_channel_weight_scale) + : FusedMoeLauncher(Optional(routing_logits), routing_bias, hidden_states, + gemm1_weights, Optional(), Optional(), + gemm2_weights, Optional()), + gemm1_per_channel_weight_scale_(gemm1_per_channel_weight_scale), + gemm1_per_channel_gate_weight_scale_(gemm1_per_channel_gate_weight_scale), + gemm2_per_channel_weight_scale_(gemm2_per_channel_weight_scale), + use_routing_scales_on_input(false) {} + + void init(std::unique_ptr&& args, + int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, + int64_t weight_layout, bool use_routing_scales_on_input_param, + ActivationType activation_type) { + this->use_routing_scales_on_input = use_routing_scales_on_input_param; + + auto dtype = hidden_states.dtype(); + if (dtype == dl_float16) { + mDtypeAct = btg::Dtype::Fp16; + } else if (dtype == dl_bfloat16) { + mDtypeAct = btg::Dtype::Bfloat16; + } else if (dtype == dl_float8_e4m3fn) { + mDtypeAct = btg::Dtype::E4m3; + } else { + TVM_FFI_LOG_AND_THROW(NotImplementedError) + << "Unsupported input dtype for FP8 per-channel MoE."; + } + mDtypeWeights = btg::Dtype::E4m3; + + FusedMoeLauncher::init_common(std::move(args), tile_tokens_dim, routing_method_type, + use_shuffled_weight, weight_layout, activation_type); + } + + void check_routing() const override { FusedMoeLauncher::check_routing_common(); } + + void prepare_routing() override { + FusedMoeLauncher::prepare_routing_common(); + + auto dtype = hidden_states.dtype(); + if (dtype == dl_float16) { + args->mDtypeElt = btg::Dtype::Fp16; + } else if (dtype == dl_bfloat16) { + args->mDtypeElt = btg::Dtype::Bfloat16; + } else if (dtype == dl_float8_e4m3fn) { + args->mDtypeElt = btg::Dtype::E4m3; + } else { + TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported input dtype for MoE."; + } + + args->mDtypeOut = btg::Dtype::Bfloat16; + args->mUseDeepSeekFp8 = false; + + auto const routing_bias_dtype = + routing_bias.has_value() ? routing_bias.value().dtype() : dl_bfloat16; + mRoutingBiasDtype = routing_bias_dtype == dl_bfloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32; + + expert_weights = + alloc_tensor({args->num_tokens, args->top_k}, dl_bfloat16, hidden_states.device()); + + workspace.expert_weights = expert_weights.data_ptr(); + if (static_cast(routing_method_type) == RoutingMethodType::Llama4) { + workspace.token_scales = expert_weights.data_ptr(); // Consumed by permuteGemm1 kernel + } + } + + void check_moe() const override { + FusedMoeLauncher::check_moe_common(); + + TVM_FFI_ICHECK_EQ(gemm1_per_channel_weight_scale_.dtype(), dl_float32) + << "gemm1_per_channel_weight_scale must be float32."; + TVM_FFI_ICHECK_EQ(gemm1_per_channel_weight_scale_.ndim(), 2) + << "gemm1_per_channel_weight_scale must be 2D [local_num_experts, 2*intermediate_size]."; + TVM_FFI_ICHECK_EQ(gemm1_per_channel_weight_scale_.size(0), args->local_num_experts) + << "gemm1_per_channel_weight_scale dim 0 must match local_num_experts."; + TVM_FFI_ICHECK_EQ(gemm1_per_channel_weight_scale_.size(1), 2 * args->intermediate_size) + << "gemm1_per_channel_weight_scale dim 1 must be 2*intermediate_size."; + + TVM_FFI_ICHECK_EQ(gemm1_per_channel_gate_weight_scale_.dtype(), dl_float32) + << "gemm1_per_channel_gate_weight_scale must be float32."; + TVM_FFI_ICHECK_EQ(gemm1_per_channel_gate_weight_scale_.ndim(), 2) + << "gemm1_per_channel_gate_weight_scale must be 2D [local_num_experts, " + "2*intermediate_size]."; + TVM_FFI_ICHECK_EQ(gemm1_per_channel_gate_weight_scale_.size(0), args->local_num_experts) + << "gemm1_per_channel_gate_weight_scale dim 0 must match local_num_experts."; + TVM_FFI_ICHECK_EQ(gemm1_per_channel_gate_weight_scale_.size(1), 2 * args->intermediate_size) + << "gemm1_per_channel_gate_weight_scale dim 1 must be 2*intermediate_size."; + + TVM_FFI_ICHECK_EQ(gemm2_per_channel_weight_scale_.dtype(), dl_float32) + << "gemm2_per_channel_weight_scale must be float32."; + TVM_FFI_ICHECK_EQ(gemm2_per_channel_weight_scale_.ndim(), 2) + << "gemm2_per_channel_weight_scale must be 2D [local_num_experts, hidden_size]."; + TVM_FFI_ICHECK_EQ(gemm2_per_channel_weight_scale_.size(0), args->local_num_experts) + << "gemm2_per_channel_weight_scale dim 0 must match local_num_experts."; + TVM_FFI_ICHECK_EQ(gemm2_per_channel_weight_scale_.size(1), args->hidden_size) + << "gemm2_per_channel_weight_scale dim 1 must match hidden_size."; + + TVM_FFI_ICHECK(hidden_states.dtype() == dl_float8_e4m3fn || + hidden_states.dtype() == dl_float16 || hidden_states.dtype() == dl_bfloat16) + << "FP8 per-channel MoE: hidden_states must be float8_e4m3fn, float16, or bfloat16."; + TVM_FFI_ICHECK_EQ(gemm1_weights.dtype(), dl_float8_e4m3fn) + << "FP8 per-channel MoE: gemm1_weights must be float8_e4m3fn."; + TVM_FFI_ICHECK_EQ(gemm2_weights.dtype(), dl_float8_e4m3fn) + << "FP8 per-channel MoE: gemm2_weights must be float8_e4m3fn."; + } + + void prepare_moe(int64_t& moe_tactic) override { + FusedMoeLauncher::prepare_moe_common(moe_tactic); + + int32_t max_num_padded_tokens_gemm1 = workspace.total_max_padded_tokens + args->num_experts; + int32_t max_num_padded_tokens_gemm2 = workspace.total_max_padded_tokens; + + gemm1_output = alloc_tensor({max_num_padded_tokens_gemm1, 2 * args->intermediate_size}, + dl_uint8, hidden_states.device()); + gemm1_output_scale = + alloc_tensor({2 * args->intermediate_size / 128, max_num_padded_tokens_gemm1}, dl_float32, + hidden_states.device()); + + activation_output = alloc_tensor({max_num_padded_tokens_gemm1, args->intermediate_size}, + dl_uint8, hidden_states.device()); + activation_output_scale = + alloc_tensor({args->intermediate_size / 128, max_num_padded_tokens_gemm1}, dl_float32, + hidden_states.device()); + + gemm2_output = alloc_tensor({max_num_padded_tokens_gemm2, args->hidden_size}, dl_bfloat16, + hidden_states.device()); + + workspace.hidden_states_scale_linear = nullptr; + workspace.gemm1_output = gemm1_output.data_ptr(); + workspace.gemm1_output_scale = static_cast(gemm1_output_scale.data_ptr()); + workspace.activation_output = activation_output.data_ptr(); + workspace.activation_output_scale = static_cast(activation_output_scale.data_ptr()); + workspace.gemm2_output = gemm2_output.data_ptr(); + workspace.gemm2_output_scale = nullptr; + + if (args->output == nullptr) { + output = + alloc_tensor({args->num_tokens, args->hidden_size}, dl_bfloat16, hidden_states.device()); + args->output = output.data_ptr(); + } + args->output_scale = nullptr; + + args->output1_scales_scalar = nullptr; + args->output1_scales_gate_scalar = nullptr; + args->output2_scales_scalar = nullptr; + + args->gemm1_per_channel_weight_scale = + static_cast(gemm1_per_channel_weight_scale_.data_ptr()); + args->gemm1_per_channel_gate_weight_scale = + static_cast(gemm1_per_channel_gate_weight_scale_.data_ptr()); + args->gemm2_per_channel_weight_scale = + static_cast(gemm2_per_channel_weight_scale_.data_ptr()); + } + + private: + bool use_routing_scales_on_input; + TensorView gemm1_per_channel_weight_scale_; + TensorView gemm1_per_channel_gate_weight_scale_; + TensorView gemm2_per_channel_weight_scale_; + Tensor gemm1_output_scale; + Tensor activation_output_scale; + + public: + static Array> getValidConfigs(int64_t top_k, int64_t hidden_size, + int64_t intermediate_size, int64_t num_local_experts, + int64_t num_tokens, int64_t act_type, + bool use_shuffled_weight, int64_t weight_layout, + btg::Dtype dtype_act, btg::Dtype dtype_weights) { + Array> valid_configs; + + std::vector supported_tile_nums(mSupportedTileNums.begin(), mSupportedTileNums.end()); + std::set selected_tile_nums = + computeSelectedTileN(supported_tile_nums, num_tokens, top_k, num_local_experts); + + for (int32_t tile_N : selected_tile_nums) { + auto moe_runner = std::make_unique( + dtype_act, dtype_weights, + false, // useDeepSeekFp8 + tile_N, static_cast(act_type), use_shuffled_weight, + static_cast(weight_layout)); + + auto cfgs = moe_runner->getValidConfigIndices(top_k, hidden_size, intermediate_size, + num_local_experts, num_tokens); + + for (auto cfg : cfgs) { + valid_configs.push_back({tile_N, cfg}); + } + } + + return valid_configs; + } +}; + class Fp8BlockScaleLauncher : public FusedMoeLauncher { public: static constexpr std::array mBaseSupportedTileNums = {8, 16, 32, 64, 128}; @@ -1774,6 +1977,93 @@ Array trtllm_fp8_per_tensor_scale_moe( return selected_launcher->run(config, enable_pdl, use_routing_scales_on_input); } +Array trtllm_fp8_per_channel_scale_moe( + TensorView routing_logits, Optional routing_bias, TensorView hidden_states, + TensorView gemm1_weights, TensorView gemm1_per_channel_weight_scale, + TensorView gemm1_per_channel_gate_weight_scale, TensorView gemm2_weights, + TensorView gemm2_per_channel_weight_scale, TensorView output, int64_t num_experts, + int64_t top_k, Optional n_group, Optional topk_group, + int64_t intermediate_size, int64_t local_expert_offset, int64_t local_num_experts, + Optional routed_scaling_factor, bool use_routing_scales_on_input, + int64_t routing_method_type, bool do_finalize, bool enable_pdl, Array config_index, + int64_t activation_type) { + // Basic type validation + auto dtype = hidden_states.dtype(); + auto activation = static_cast(activation_type); + if (use_routing_scales_on_input) { + TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_bfloat16) << "routing_logits must be bfloat16."; + } else if (static_cast(routing_method_type) == RoutingMethodType::DeepSeekV3) { + TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_float32) << "routing_logits must be float."; + } else { + TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_bfloat16) << "routing_logits must be bfloat16."; + } + TVM_FFI_ICHECK(dtype == dl_float8_e4m3fn || dtype == dl_float16 || dtype == dl_bfloat16) + << "FP8 per-channel MoE: hidden_states must be float8_e4m3fn, float16, or bfloat16."; + TVM_FFI_ICHECK_EQ(gemm1_weights.dtype(), dl_float8_e4m3fn) + << "FP8 per-channel MoE: gemm1_weights must be float8_e4m3fn."; + TVM_FFI_ICHECK_EQ(gemm2_weights.dtype(), dl_float8_e4m3fn) + << "FP8 per-channel MoE: gemm2_weights must be float8_e4m3fn."; + TVM_FFI_ICHECK_EQ(gemm1_per_channel_weight_scale.dtype(), dl_float32) + << "FP8 per-channel MoE: gemm1_per_channel_weight_scale must be float32."; + TVM_FFI_ICHECK_EQ(gemm1_per_channel_gate_weight_scale.dtype(), dl_float32) + << "FP8 per-channel MoE: gemm1_per_channel_gate_weight_scale must be float32."; + TVM_FFI_ICHECK_EQ(gemm2_per_channel_weight_scale.dtype(), dl_float32) + << "FP8 per-channel MoE: gemm2_per_channel_weight_scale must be float32."; + + auto const num_tokens = hidden_states.size(0); + auto const hidden_size = hidden_states.size(1); + + bool use_shuffled_weight = true; + int64_t weight_layout = 0; // MajorK + + // Calculate supported tile sizes + std::vector mSupportedTileN(Fp8PerChannelLauncher::mSupportedTileNums.begin(), + Fp8PerChannelLauncher::mSupportedTileNums.end()); + std::set selected_tile_nums = + computeSelectedTileN(mSupportedTileN, num_tokens, top_k, local_num_experts); + + // Create a map of launchers for each tile size + std::unordered_map> launchers_map; + + for (int32_t curr_tile_N : selected_tile_nums) { + auto args = std::make_unique(); + args->num_tokens = num_tokens; + args->num_experts = num_experts; + args->hidden_size = hidden_size; + args->hidden_size_output = args->hidden_size; + args->top_k = top_k; + args->n_group = n_group.value_or(0); + args->topk_group = topk_group.value_or(0); + args->local_expert_offset = local_expert_offset; + args->local_num_experts = local_num_experts; + args->intermediate_size = intermediate_size; + args->routed_scaling_factor = routed_scaling_factor.value_or(1.0); + args->do_finalize = do_finalize; + args->output = output.data_ptr(); + args->output_scale = nullptr; + + auto launcher = std::make_unique( + routing_logits, routing_bias, hidden_states, gemm1_weights, gemm1_per_channel_weight_scale, + gemm1_per_channel_gate_weight_scale, gemm2_weights, gemm2_per_channel_weight_scale); + launcher->init(std::move(args), curr_tile_N, routing_method_type, use_shuffled_weight, + weight_layout, use_routing_scales_on_input, activation); + + launchers_map[curr_tile_N] = std::move(launcher); + } + + // Extract tile_N and config from config_index + int64_t tile_N = config_index[0]; + int64_t config = config_index[1]; + + if (tile_N == -1 || config == -1) { + tile_N = *selected_tile_nums.begin(); + } + + auto& selected_launcher = launchers_map.at(tile_N); + + return selected_launcher->run(config, enable_pdl, use_routing_scales_on_input); +} + Array trtllm_fp8_block_scale_moe( Optional routing_logits, TensorView expert_indices, TensorView expert_weights, Optional routing_bias, TensorView hidden_states, TensorView hidden_states_scale, @@ -2155,6 +2445,11 @@ Array> trtllm_get_valid_moe_configs( return Fp8BlockScaleLauncher::getValidConfigs( top_k, hidden_size, intermediate_size, num_local_experts, num_tokens, use_shuffled_weight, weight_layout, dtype_weights, quantization_type); + } else if (quantization_type == Fp8QuantizationType::PerChannelFp8) { + // FP8 per-channel scale + return Fp8PerChannelLauncher::getValidConfigs( + top_k, hidden_size, intermediate_size, num_local_experts, num_tokens, act_type, + use_shuffled_weight, weight_layout, dtype_act, dtype_weights); } else { // FP8 per-tensor scale return Fp8PerTensorLauncher::getValidConfigs( @@ -2188,6 +2483,7 @@ namespace trtllm_cubin_loader { TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_bf16_moe, trtllm_bf16_moe); TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_fp8_per_tensor_scale_moe, trtllm_fp8_per_tensor_scale_moe); +TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_fp8_per_channel_scale_moe, trtllm_fp8_per_channel_scale_moe); TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_fp8_block_scale_moe, trtllm_fp8_block_scale_moe); TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_fp4_block_scale_moe, trtllm_fp4_block_scale_moe); TVM_FFI_DLL_EXPORT_TYPED_FUNC(trtllm_mxint4_block_scale_moe, trtllm_mxint4_block_scale_moe); diff --git a/csrc/trtllm_fused_moe_runner.cu b/csrc/trtllm_fused_moe_runner.cu index 4091019efc..b7177bcbcf 100644 --- a/csrc/trtllm_fused_moe_runner.cu +++ b/csrc/trtllm_fused_moe_runner.cu @@ -288,13 +288,14 @@ void Runner::run(void* hiddenState, void* hiddenStateScale, void* weights, void* int32_t* permutedIdxToTokenIdx, int32_t* ptrNumNonExitingCtas, int32_t* ptrTotalNumPaddedTokens, int32_t* ptrCtaIdxXyToBatchIdx, int32_t* ptrCtaIdxXyToMnLimit, void* bmm1Workspace, bool useRoutingScalesOnInput, - int device, cudaStream_t stream, int32_t configIndex, bool enable_pdl) { + int device, cudaStream_t stream, int32_t configIndex, bool enable_pdl, + float* perChannelWeightScale) { auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); int32_t intermediateSizeFactor = (isGatedActivation(mActType) ? 2 : 1); mRunner.run(numTokens, intermediateSizeFactor * intermediateSize, hiddenSize, {}, numTokens, numExperts, maxNumCtasInBatchDim, hiddenState, hiddenStateScale, weights, - weightsScale, expertWeights, /* perTokensSfB */ nullptr, outputScalesScalar, + weightsScale, expertWeights, perChannelWeightScale, outputScalesScalar, outputScalesGateScalar, ptrBias, ptrAlpha, ptrBeta, ptrClampLimit, output, outputScale, permutedIdxToTokenIdx, ptrTotalNumPaddedTokens, ptrCtaIdxXyToBatchIdx, ptrCtaIdxXyToMnLimit, ptrNumNonExitingCtas, bmm1Workspace, stream, device, @@ -381,14 +382,16 @@ void Runner::run(void* permutedHiddenState, void* permutedHiddenStateScale, void int32_t numExperts, int32_t numTokens, int32_t* ptrNumNonExitingCtas, int32_t* ptrTotalNumPaddedTokens, int32_t* ptrCtaIdxXyToBatchIdx, int32_t* ptrCtaIdxXyToMnLimit, void* bmm2Workspace, int device, - cudaStream_t stream, int32_t configIndex, bool enable_pdl) { + cudaStream_t stream, int32_t configIndex, bool enable_pdl, + float* perChannelWeightScale) { auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); mRunner.run( numTokens, hiddenSize, intermediateSize, {}, numTokens, numExperts, maxNumCtasInBatchDim, permutedHiddenState, permutedHiddenStateScale, weights, weightsScale, /* perTokensSfA */ nullptr, - /* perTokensSfB */ nullptr, outputScalesScalar, /* outputScalesGateScalar */ nullptr, ptrBias, + /* perTokensSfB */ perChannelWeightScale, outputScalesScalar, + /* outputScalesGateScalar */ nullptr, ptrBias, /* ptrAlpha */ nullptr, /* ptrBeta */ nullptr, /* clampLimit */ nullptr, output, outputScale, /* permutedIdxToTokenIdx */ nullptr, ptrTotalNumPaddedTokens, ptrCtaIdxXyToBatchIdx, ptrCtaIdxXyToMnLimit, ptrNumNonExitingCtas, bmm2Workspace, stream, device, configIndex, @@ -584,16 +587,16 @@ void Runner::run(MoERunnerArgs const& args, MoEWorkspace const& workspace, int d auto const& config = mPassingConfigs[configIndex]; - mPermuteGemm1.run(args.hidden_states, hidden_states_scale_linear, args.gemm1_weights, - args.gemm1_weights_scale, workspace.token_scales, args.output1_scales_scalar, - args.output1_scales_gate_scalar, args.gemm1_bias, args.gemm1_alpha, - args.gemm1_beta, args.gemm1_clamp_limit, workspace.gemm1_output, - workspace.gemm1_output_scale, args.top_k, args.hidden_size, - args.intermediate_size, args.local_num_experts, args.num_tokens, - workspace.permuted_idx_to_token_idx, workspace.num_non_exiting_ctas, - workspace.total_num_padded_tokens, workspace.cta_idx_xy_to_batch_idx, - workspace.cta_idx_xy_to_mn_limit, workspace.bmm1_workspace, - args.mUseRoutingScalesOnInput, device, stream, config.gemm1Config, enable_pdl); + mPermuteGemm1.run( + args.hidden_states, hidden_states_scale_linear, args.gemm1_weights, args.gemm1_weights_scale, + workspace.token_scales, args.output1_scales_scalar, args.output1_scales_gate_scalar, + args.gemm1_bias, args.gemm1_alpha, args.gemm1_beta, args.gemm1_clamp_limit, + workspace.gemm1_output, workspace.gemm1_output_scale, args.top_k, args.hidden_size, + args.intermediate_size, args.local_num_experts, args.num_tokens, + workspace.permuted_idx_to_token_idx, workspace.num_non_exiting_ctas, + workspace.total_num_padded_tokens, workspace.cta_idx_xy_to_batch_idx, + workspace.cta_idx_xy_to_mn_limit, workspace.bmm1_workspace, args.mUseRoutingScalesOnInput, + device, stream, config.gemm1Config, enable_pdl, args.gemm1_per_channel_weight_scale); // We do not fuse activation with FC1 for DeepSeek FP8 due to the weights shuffling constraint. void* gemm2_input = workspace.gemm1_output; @@ -613,7 +616,7 @@ void Runner::run(MoERunnerArgs const& args, MoEWorkspace const& workspace, int d args.local_num_experts, args.num_tokens, workspace.num_non_exiting_ctas, workspace.total_num_padded_tokens, workspace.cta_idx_xy_to_batch_idx, workspace.cta_idx_xy_to_mn_limit, workspace.bmm2_workspace, device, stream, - config.gemm2Config, enable_pdl); + config.gemm2Config, enable_pdl, args.gemm2_per_channel_weight_scale); // Run finalize if (args.do_finalize) { diff --git a/flashinfer/fused_moe/__init__.py b/flashinfer/fused_moe/__init__.py index e2b4cab3d6..1cb1117b63 100644 --- a/flashinfer/fused_moe/__init__.py +++ b/flashinfer/fused_moe/__init__.py @@ -31,6 +31,7 @@ trtllm_fp4_block_scale_routed_moe, trtllm_fp8_block_scale_moe, trtllm_fp8_block_scale_routed_moe, + trtllm_fp8_per_channel_scale_moe, trtllm_fp8_per_tensor_scale_moe, trtllm_bf16_moe, trtllm_bf16_routed_moe, @@ -71,6 +72,7 @@ "trtllm_fp4_block_scale_routed_moe", "trtllm_fp8_block_scale_moe", "trtllm_fp8_block_scale_routed_moe", + "trtllm_fp8_per_channel_scale_moe", "trtllm_fp8_per_tensor_scale_moe", "trtllm_mxint4_block_scale_moe", "fused_topk_deepseek", diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 7e0760e7b2..da53f82d10 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -191,6 +191,10 @@ class Fp8QuantizationType(IntEnum): DeepSeekFp8 = 1 # MxFp8 x MxFp8 MxFp8 = 2 + # Per-tensor FP8 + PerTensorFp8 = 3 + # Per-channel FP8 + PerChannelFp8 = 4 @functools.cache @@ -1191,6 +1195,33 @@ def forward( [-1, -1] if tactic == -1 else tactic, self.fp8_quantization_type, ) + elif self.fp8_quantization_type == Fp8QuantizationType.PerChannelFp8: + # FP8 per-channel scale + moe_op.trtllm_fp8_per_channel_scale_moe( + routing_logits, + kwargs["routing_bias"], + hidden_states, + kwargs["gemm1_weights"], + kwargs["gemm1_per_channel_weight_scale"], + kwargs["gemm1_per_channel_gate_weight_scale"], + kwargs["gemm2_weights"], + kwargs["gemm2_per_channel_weight_scale"], + output, + kwargs["num_experts"], + self.top_k, + kwargs["n_group"], + kwargs["topk_group"], + self.intermediate_size, + kwargs["local_expert_offset"], + self.num_local_experts, + kwargs["routed_scaling_factor"], + kwargs["use_routing_scales_on_input"], + kwargs["routing_method_type"], + kwargs["do_finalize"], + kwargs["enable_pdl"], + [-1, -1] if tactic == -1 else tactic, + self.activation_type, + ) else: # FP8 per tensor scale moe_op.trtllm_fp8_per_tensor_scale_moe( @@ -1626,6 +1657,156 @@ def _fake_trtllm_fp8_per_tensor_scale_moe( return [hidden_states.new_empty([seq_len, hidden_size], dtype=torch.bfloat16)] + @register_custom_op( + "flashinfer::trtllm_fp8_per_channel_scale_moe", + mutates_args=(""), + ) + def trtllm_fp8_per_channel_scale_moe_op( + routing_logits: torch.Tensor, + routing_bias: Optional[torch.Tensor], + hidden_states: torch.Tensor, + gemm1_weights: torch.Tensor, + gemm1_per_channel_weight_scale: torch.Tensor, + gemm1_per_channel_gate_weight_scale: torch.Tensor, + gemm2_weights: torch.Tensor, + gemm2_per_channel_weight_scale: torch.Tensor, + num_experts: int, + top_k: int, + n_group: Optional[int], + topk_group: Optional[int], + intermediate_size: int, + local_expert_offset: int, + local_num_experts: int, + routed_scaling_factor: Optional[float], + use_routing_scales_on_input: bool, + routing_method_type: int = 0, + do_finalize: bool = True, + enable_pdl: Optional[bool] = None, + tune_max_num_tokens: int = 8192, + activation_type: int = ActivationType.Swiglu.value, + ) -> List[torch.Tensor]: + if enable_pdl is None: + enable_pdl = device_support_pdl(hidden_states.device) + tuner = AutoTuner.get() + MoERunner.refine_tuning_config(tune_max_num_tokens) + + num_tokens = hidden_states.shape[0] + hidden_size = hidden_states.shape[-1] + + output = torch.empty( + num_tokens, hidden_size, dtype=torch.bfloat16, device=hidden_states.device + ) + topk_ids = torch.empty( + num_tokens, top_k, dtype=torch.int32, device=hidden_states.device + ) + expert_weights = torch.empty( + num_tokens, top_k, dtype=routing_logits.dtype, device=hidden_states.device + ) + + dtype_act = DtypeTrtllmGen.E4m3 + dtype_weights = DtypeTrtllmGen.E4m3 + + moe_runner = MoERunner( + top_k=top_k, + num_local_experts=local_num_experts, + dtype_act=dtype_act, + dtype_weights=dtype_weights, + fp8_quantization_type=Fp8QuantizationType.PerChannelFp8, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + weight_layout=WeightLayout.MajorK, + use_shuffled_weight=True, + activation_type=activation_type, + ) + + inputs = [output, routing_logits, topk_ids, expert_weights, hidden_states] + + _, tactic = tuner.choose_one( + "flashinfer::trtllm_fp8_per_channel_scale_moe", + [moe_runner], + MoERunner.tuning_config_no_hidden_states_scales, + inputs, + routing_bias=routing_bias, + gemm1_weights=gemm1_weights, + gemm1_per_channel_weight_scale=gemm1_per_channel_weight_scale, + gemm1_per_channel_gate_weight_scale=gemm1_per_channel_gate_weight_scale, + gemm2_weights=gemm2_weights, + gemm2_per_channel_weight_scale=gemm2_per_channel_weight_scale, + num_experts=num_experts, + n_group=n_group, + topk_group=topk_group, + local_expert_offset=local_expert_offset, + local_num_experts=local_num_experts, + routed_scaling_factor=routed_scaling_factor, + use_routing_scales_on_input=use_routing_scales_on_input, + routing_method_type=routing_method_type, + do_finalize=do_finalize, + enable_pdl=enable_pdl, + activation_type=activation_type, + ) + intermediate_output = moe_op.trtllm_fp8_per_channel_scale_moe( + routing_logits, + routing_bias, + hidden_states, + gemm1_weights, + gemm1_per_channel_weight_scale, + gemm1_per_channel_gate_weight_scale, + gemm2_weights, + gemm2_per_channel_weight_scale, + output, + num_experts, + top_k, + n_group, + topk_group, + intermediate_size, + local_expert_offset, + local_num_experts, + routed_scaling_factor, + use_routing_scales_on_input, + routing_method_type, + do_finalize, + enable_pdl, + [-1, -1] if tactic == -1 else tactic, + activation_type, + ) + if do_finalize: + return [output] + else: + return [ + torch.from_dlpack(intermediate_output[0]), + torch.from_dlpack(intermediate_output[1]), + torch.from_dlpack(intermediate_output[2]), + ] + + @register_fake_op("flashinfer::trtllm_fp8_per_channel_scale_moe") + def _fake_trtllm_fp8_per_channel_scale_moe( + routing_logits: torch.Tensor, + routing_bias: Optional[torch.Tensor], + hidden_states: torch.Tensor, + gemm1_weights: torch.Tensor, + gemm1_per_channel_weight_scale: torch.Tensor, + gemm1_per_channel_gate_weight_scale: torch.Tensor, + gemm2_weights: torch.Tensor, + gemm2_per_channel_weight_scale: torch.Tensor, + num_experts: int, + top_k: int, + n_group: Optional[int], + topk_group: Optional[int], + intermediate_size: int, + local_expert_offset: int, + local_num_experts: int, + routed_scaling_factor: Optional[float], + use_routing_scales_on_input: bool, + routing_method_type: int = 0, + do_finalize: bool = True, + enable_pdl: Optional[bool] = None, + activation_type: int = ActivationType.Swiglu.value, + ): + seq_len = hidden_states.shape[0] + hidden_size = hidden_states.shape[1] + + return [hidden_states.new_empty([seq_len, hidden_size], dtype=torch.bfloat16)] + @register_custom_op( "flashinfer::trtllm_fp8_block_scale_moe", mutates_args=(""), @@ -2228,6 +2409,7 @@ def _fake_trtllm_mxint4_block_scale_moe( return SimpleNamespace( trtllm_bf16_moe=trtllm_bf16_moe_op, trtllm_fp8_per_tensor_scale_moe=trtllm_fp8_per_tensor_scale_moe_op, + trtllm_fp8_per_channel_scale_moe=trtllm_fp8_per_channel_scale_moe_op, trtllm_fp8_block_scale_moe=trtllm_fp8_block_scale_moe_op, trtllm_fp4_block_scale_moe=trtllm_fp4_block_scale_moe_op, trtllm_mxint4_block_scale_moe=trtllm_mxint4_block_scale_moe_op, @@ -2518,6 +2700,95 @@ def trtllm_fp8_per_tensor_scale_moe( return result +@flashinfer_api +def trtllm_fp8_per_channel_scale_moe( + routing_logits: torch.Tensor, + routing_bias: Optional[torch.Tensor], + hidden_states: torch.Tensor, + gemm1_weights: torch.Tensor, + gemm1_per_channel_weight_scale: torch.Tensor, + gemm1_per_channel_gate_weight_scale: torch.Tensor, + gemm2_weights: torch.Tensor, + gemm2_per_channel_weight_scale: torch.Tensor, + num_experts: int, + top_k: int, + n_group: Optional[int], + topk_group: Optional[int], + intermediate_size: int, + local_expert_offset: int, + local_num_experts: int, + routed_scaling_factor: Optional[float], + use_routing_scales_on_input: bool, + routing_method_type: int = 0, + do_finalize: bool = True, + enable_pdl: Optional[bool] = None, + tune_max_num_tokens: int = 8192, + activation_type: int = ActivationType.Swiglu.value, +) -> Union[List[torch.Tensor], torch.Tensor]: + """FP8 per-channel scale MoE operation. + + Args: + routing_logits: [seq_len, num_experts] tensor of routing logits + routing_bias: [num_experts] tensor of routing bias + hidden_states: [seq_len, hidden_size] tensor of input hidden states + gemm1_weights: [num_experts, 2*intermediate_size, hidden_size] FP8 first layer weights + gemm1_per_channel_weight_scale: [local_num_experts, 2*intermediate_size] per-channel scales for gemm1 + gemm1_per_channel_gate_weight_scale: [local_num_experts, 2*intermediate_size] per-channel gate scales for gemm1 + gemm2_weights: [num_experts, hidden_size, intermediate_size] FP8 second layer weights + gemm2_per_channel_weight_scale: [local_num_experts, hidden_size] per-channel scales for gemm2 + num_experts: Total number of experts + top_k: Number of experts to route to per token + n_group: Number of expert groups + topk_group: Number of groups to consider for top-k routing + intermediate_size: Size of intermediate layer + local_expert_offset: Offset of local experts in global expert space + local_num_experts: Number of experts handled by this device + routed_scaling_factor: Scaling factor for routing + use_routing_scales_on_input: Whether to use routing scales on input + routing_method_type: Type of routing method to use (default: 0) + do_finalize: Whether to finalize the output (default: True). + enable_pdl: Whether to enable Programmatic Dependent Launch (PDL). Auto-enabled for >= sm90. + tune_max_num_tokens(int): Maximum number of tokens for tuning. (default: 8192) + activation_type (int): Type of activation function (default: 3 - Swiglu) + + Returns: + when do_finalize=True, returns the final MoE output. + otherwise, returns the intermediate results (gemm2_output, expert_weights, expanded_idx_to_permuted_idx). + """ + result = get_trtllm_moe_sm100_module().trtllm_fp8_per_channel_scale_moe( + routing_logits, + routing_bias, + hidden_states, + gemm1_weights, + gemm1_per_channel_weight_scale, + gemm1_per_channel_gate_weight_scale, + gemm2_weights, + gemm2_per_channel_weight_scale, + num_experts, + top_k, + n_group, + topk_group, + intermediate_size, + local_expert_offset, + local_num_experts, + routed_scaling_factor, + use_routing_scales_on_input, + routing_method_type, + do_finalize, + enable_pdl, + tune_max_num_tokens, + activation_type, + ) + + if do_finalize: + logger.warning_once( + "the single torch.Tensor return type is deprecated and will be replaced with List[torch.Tensor] in the v0.8.0." + ) + return result[0] + else: + return result + + @flashinfer_api def trtllm_fp8_block_scale_moe( routing_logits: torch.Tensor, diff --git a/include/flashinfer/trtllm/fused_moe/runner.h b/include/flashinfer/trtllm/fused_moe/runner.h index 46617e5dbd..88f969dab5 100644 --- a/include/flashinfer/trtllm/fused_moe/runner.h +++ b/include/flashinfer/trtllm/fused_moe/runner.h @@ -209,7 +209,8 @@ class Runner { int32_t* permutedIdxToTokenIdx, int32_t* ptrNumNonExitingCtas, int32_t* ptrTotalNumPaddedTokens, int32_t* ptrCtaIdxXyToBatchIdx, int32_t* ptrCtaIdxXyToMnLimit, void* bmm1Workspace, bool useRoutingScalesOnInput, - int device, cudaStream_t stream, int32_t configIndex, bool enable_pdl); + int device, cudaStream_t stream, int32_t configIndex, bool enable_pdl, + float* perChannelWeightScale = nullptr); private: batchedGemm::trtllm::gen::Dtype mDtypeAct; @@ -248,7 +249,7 @@ class Runner { int32_t numExperts, int32_t numTokens, int32_t* ptrNumNonExitingCtas, int32_t* ptrTotalNumPaddedTokens, int32_t* ptrCtaIdxXyToBatchIdx, int32_t* ptrCtaIdxXyToMnLimit, void* bmm2Workspace, int device, cudaStream_t stream, - int32_t configIndex, bool enable_pdl); + int32_t configIndex, bool enable_pdl, float* perChannelWeightScale = nullptr); private: batchedGemm::trtllm::gen::Dtype mDtypeAct; @@ -313,6 +314,13 @@ struct MoERunnerArgs { float* output1_scales_gate_scalar = nullptr; float* output2_scales_scalar = nullptr; + // Per-channel weight scales + // Passed to TrtllmGenBatchedGemmRunner as perTokensSfB (which maps to kernel mPtrPerTokenSfA + // due to transposeMmaOutput=true) + float* gemm1_per_channel_weight_scale = nullptr; // [2*intermediate_size] for gated acts + float* gemm1_per_channel_gate_weight_scale = nullptr; // [2*intermediate_size] for gated acts + float* gemm2_per_channel_weight_scale = nullptr; // [hidden_size] + // Output: void* output = nullptr; float* output_scale = nullptr; diff --git a/tests/moe/test_trtllm_gen_fused_moe.py b/tests/moe/test_trtllm_gen_fused_moe.py index 887472ddbd..7c8e4dbfdc 100644 --- a/tests/moe/test_trtllm_gen_fused_moe.py +++ b/tests/moe/test_trtllm_gen_fused_moe.py @@ -39,6 +39,7 @@ convert_to_block_layout, trtllm_fp4_block_scale_moe, trtllm_fp8_block_scale_moe, + trtllm_fp8_per_channel_scale_moe, trtllm_fp8_per_tensor_scale_moe, trtllm_bf16_moe, trtllm_mxint4_block_scale_moe, @@ -1367,6 +1368,210 @@ def get_tolerances(self): return {"atol": 0.1, "rtol": 0.85, "percent": 0.925} +# ==================================================================================== +# FP8 Per-Channel Implementation +# ==================================================================================== + + +class FP8PerChannelMoe(Moe): + """FP8 MoE implementation with per-channel weight quantization scaling.""" + + @property + def quant_mode(self) -> QuantMode: + return QuantMode.FP8_PER_CHANNEL + + def quantize_weights(self, gemm1_weights, gemm2_weights, hidden_states_sample): + """Quantize weights to FP8 with per-channel scales (max abs per output column).""" + hidden_states_global_scale = calculate_fp8_global_scale_factor( + hidden_states_sample + ) + + # Per-channel quantization: one scale per output column of each expert + gemm1_weights_quant, gemm1_per_channel_scales = quant_fp8_per_channel_batches( + gemm1_weights + ) + gemm2_weights_quant, gemm2_per_channel_scales = quant_fp8_per_channel_batches( + gemm2_weights + ) + + return { + "hidden_states_scale_global": hidden_states_global_scale, + "gemm1_weights": gemm1_weights_quant, + "gemm1_scales": None, + "gemm1_scales_global": None, + "gemm1_per_channel_scales": gemm1_per_channel_scales, + "gemm2_weights": gemm2_weights_quant, + "gemm2_scales": None, + "gemm2_scales_global": None, + "gemm2_per_channel_scales": gemm2_per_channel_scales, + } + + def quantize_inputs(self, hidden_states, hidden_states_scale_global): + """Quantize hidden states to FP8 per-tensor using pre-computed global scale.""" + hidden_states_quant, _ = quant_fp8_per_tensor( + hidden_states, hidden_states_scale_global + ) + + return { + "hidden_states": hidden_states_quant, + "hidden_states_scale": None, + } + + def prepare_static_weights_for_kernel( + self, + args_dequant, + args, + gemm1_weights_orig, + gemm2_weights_orig, + hidden_size, + intermediate_size, + num_experts, + weight_processing, + ): + """Prepare quantized weights for kernel (done offline with weights).""" + epilogue_tile_m = 128 + + # Reorder rows of W1 for fused gated activation + gemm1_weights_fp8_interleaved = [] + for i in range(num_experts): + if is_gated_activation(args.activation_type): + weights = reorder_rows_for_gated_act_gemm(args.gemm1_weights[i].clone()) + else: + weights = args.gemm1_weights[i].clone() + gemm1_weights_fp8_interleaved.append(weights) + + gemm1_weights_fp8_interleaved = torch.stack( + gemm1_weights_fp8_interleaved + ).reshape( + num_experts, + (2 if is_gated_activation(args.activation_type) else 1) * intermediate_size, + hidden_size, + ) + + # Shuffle weights for transposed mma output + gemm1_weights_fp8_shuffled = [] + gemm2_weights_fp8_shuffled = [] + for i in range(num_experts): + gemm1_weights_fp8_shuffled.append( + shuffle_matrix_a( + gemm1_weights_fp8_interleaved[i].view(torch.uint8), epilogue_tile_m + ) + ) + gemm2_weights_fp8_shuffled.append( + shuffle_matrix_a( + args.gemm2_weights[i].view(torch.uint8), epilogue_tile_m + ) + ) + + gemm1_weights_fp8_shuffled = torch.stack(gemm1_weights_fp8_shuffled).view( + torch.float8_e4m3fn + ) + gemm2_weights_fp8_shuffled = torch.stack(gemm2_weights_fp8_shuffled).view( + torch.float8_e4m3fn + ) + + # NOTE: verify this + # Per-channel scales: [local_num_experts, output_dim] + # For GEMM1: the per-channel scale incorporates the hidden_states global scale + # scaleC = c_global_sf / (weight_channel_scale * hidden_states_global_scale) + # For per-channel, we pass the combined scale as both scale_c and scale_gate + gemm1_per_channel_scales = ( + args.gemm1_per_channel_scales + ) # [num_experts, 2*intermediate_size] + gemm2_per_channel_scales = ( + args.gemm2_per_channel_scales + ) # [num_experts, hidden_size] + + if is_gated_activation(args.activation_type): + # Reorder the per-channel scales to match the interleaved weight layout + gemm1_per_channel_scales_reordered = [] + for i in range(num_experts): + scales_2d = gemm1_per_channel_scales[i].reshape(2, intermediate_size) + reordered = torch.empty_like(scales_2d.reshape(-1)) + reordered[0::2] = scales_2d[0] # activation weights + reordered[1::2] = scales_2d[1] # gate weights + gemm1_per_channel_scales_reordered.append(reordered) + gemm1_per_channel_scales = torch.stack(gemm1_per_channel_scales_reordered) + + # Compute the combined scales for the kernel + # scale_c_fc1 = c_global_sf / (per_channel_weight_scale * hidden_states_global_scale) + scale_c_fc1 = args_dequant.c_global_sf / ( + gemm1_per_channel_scales * args.hidden_states_scale_global + ) + scale_gate_fc1 = 1.0 / ( + gemm1_per_channel_scales * args.hidden_states_scale_global + ) + scale_c_fc2 = 1.0 / (args_dequant.c_global_sf * gemm2_per_channel_scales) + + return { + "gemm1_weights": gemm1_weights_fp8_shuffled, + "gemm2_weights": gemm2_weights_fp8_shuffled, + "scale_c_fc1": scale_c_fc1, + "scale_gate_fc1": scale_gate_fc1, + "scale_c_fc2": scale_c_fc2, + } + + def call_moe( + self, static_data, hidden_states_orig, hidden_states_scale_global, **kwargs + ): + """Call MoE with runtime input quantization + kernel execution.""" + expert_logits = kwargs["expert_logits"] + routing_bias = kwargs["routing_bias"] + num_experts = kwargs["num_experts"] + top_k = kwargs["top_k"] + n_groups = kwargs["n_groups"] + top_k_groups = kwargs["top_k_groups"] + intermediate_size = kwargs["intermediate_size"] + routed_scaling = kwargs["routed_scaling"] + routing_method_type = kwargs["routing_method_type"] + enable_autotune = kwargs.get("enable_autotune", True) + activation_type = kwargs["activation_type"] + + # Quantize to FP8 per-tensor using pre-computed global scale factor + hidden_states_fp8, _ = quant_fp8_per_tensor( + hidden_states_orig, hidden_states_scale_global + ) + + with autotune(enable_autotune): + output = trtllm_fp8_per_channel_scale_moe( + ( + expert_logits.to(torch.bfloat16) + if routing_method_type == RoutingMethodType.Llama4 + else expert_logits + ), + routing_bias, + hidden_states_fp8, + static_data["gemm1_weights"], + static_data["scale_c_fc1"], + static_data["scale_gate_fc1"], + static_data["gemm2_weights"], + static_data["scale_c_fc2"], + num_experts, + top_k, + n_groups, + top_k_groups, + intermediate_size, + 0, + num_experts, + routed_scaling, + routing_method_type + == RoutingMethodType.Llama4, # Use_routing_scales_on_input + routing_method_type, + tune_max_num_tokens=TUNE_MAX_NUM_TOKENS, + activation_type=activation_type, + ) + + return output.to(torch.float) + + def compute_reference(self, args): + """FP8 per-channel reference implementation.""" + return run_moe_reference_per_channel_scale_fp8(args) + + def get_tolerances(self): + """Get FP8 per-channel accuracy tolerances.""" + return {"atol": 0.1, "rtol": 0.85, "percent": 0.925} + + # ==================================================================================== # BF16 Implementation # ==================================================================================== @@ -1537,6 +1742,8 @@ def get_moe_impl(quant_mode: QuantMode): return FP8BlockScaleMoe(fp8_quantization_type=QuantMode.FP8_BLOCK_SCALE_MXFP8) elif quant_mode == QuantMode.FP8_PER_TENSOR: return FP8PerTensorMoe() + elif quant_mode == QuantMode.FP8_PER_CHANNEL: + return FP8PerChannelMoe() else: return FP4Moe(quant_mode) @@ -1567,6 +1774,8 @@ def __init__( activation_type, gemm1_bias=None, gemm2_bias=None, + gemm1_per_channel_scales=None, + gemm2_per_channel_scales=None, ): self.num_tokens = num_tokens self.num_experts = num_experts @@ -1589,6 +1798,8 @@ def __init__( self.activation_type = activation_type self.gemm1_bias = gemm1_bias self.gemm2_bias = gemm2_bias + self.gemm1_per_channel_scales = gemm1_per_channel_scales + self.gemm2_per_channel_scales = gemm2_per_channel_scales class moe_args_dequant: @@ -1993,6 +2204,26 @@ def quant_fp8_per_tensor_batches(a): return result_a_quant, result_a_scales +def quant_fp8_per_channel(a): + """FP8 per-channel scale quantization.""" + max_abs = a.float().abs().nan_to_num().amax(dim=-1) + per_channel_scales = 448.0 / max_abs.clamp(min=1e-12) + a_fp8 = (a.float() * per_channel_scales.unsqueeze(-1)).to(torch.float8_e4m3fn) + return a_fp8, per_channel_scales + + +def quant_fp8_per_channel_batches(a): + """Batch FP8 per-channel scale quantization.""" + num_batches = a.size(0) + a_quant = [] + a_scales = [] + for i in range(num_batches): + a_fp8, scales = quant_fp8_per_channel(a[i]) + a_quant.append(a_fp8) + a_scales.append(scales) + return torch.stack(a_quant), torch.stack(a_scales) + + def quant_dequant_per_tensor_fp8(a): """FP8 per-tensor quantize-dequantize roundtrip function with centralized global scale factor calculation.""" # Use centralized global scale factor calculation @@ -2156,7 +2387,10 @@ def run_moe_dequant(args, quant_mode: QuantMode): ) activation_output = activation_output.to(torch.float) args.c_global_sf = c_global_sf - elif quant_mode == QuantMode.FP8_PER_TENSOR: + elif ( + quant_mode == QuantMode.FP8_PER_TENSOR + or quant_mode == QuantMode.FP8_PER_CHANNEL + ): activation_output, c_global_sf = quant_dequant_per_tensor_fp8( activation_output.to(torch.bfloat16) ) @@ -2431,6 +2665,46 @@ def run_moe_reference_per_tensor_scale_fp8(args): return run_moe_dequant(args_dequant, QuantMode.FP8_PER_TENSOR), args_dequant +def run_moe_reference_per_channel_scale_fp8(args): + """FP8 per-channel reference implementation.""" + hidden_states_dequant = ( + args.hidden_states.to(torch.float) / args.hidden_states_scale_global + ) + + # Dequantize weights using per-channel scales: w_float = w_fp8 / per_channel_scale + gemm1_weights_dequant = {} + for i in range(args.num_experts): + gemm1_weights_dequant[i] = args.gemm1_weights[i].to( + torch.float + ) / args.gemm1_per_channel_scales[i].unsqueeze(-1) + + gemm2_weights_dequant = {} + for i in range(args.num_experts): + gemm2_weights_dequant[i] = args.gemm2_weights[i].to( + torch.float + ) / args.gemm2_per_channel_scales[i].unsqueeze(-1) + + args_dequant = moe_args_dequant( + args.num_tokens, + args.num_experts, + args.hidden_size, + args.intermediate_size, + args.top_k, + args.padding, + hidden_states_dequant, + args.expert_logits, + gemm1_weights_dequant, + gemm2_weights_dequant, + args.permute_info, + args.use_routing_scales_on_input, + args.activation_type, + gemm1_bias=args.gemm1_bias, + gemm2_bias=args.gemm2_bias, + ) + + return run_moe_dequant(args_dequant, QuantMode.FP8_PER_CHANNEL), args_dequant + + def run_moe_reference_bf16(args): """BF16 reference implementation.""" @@ -2730,6 +3004,8 @@ def run_moe_test( activation_type, gemm1_bias=gemm1_bias, gemm2_bias=gemm2_bias, + gemm1_per_channel_scales=quant_data.get("gemm1_per_channel_scales"), + gemm2_per_channel_scales=quant_data.get("gemm2_per_channel_scales"), ) # Compute reference output @@ -2794,6 +3070,7 @@ def run_moe_test( id="FP8_Block_MxFp8", ), pytest.param(FP8PerTensorMoe(), id="FP8_Tensor"), + pytest.param(FP8PerChannelMoe(), id="FP8_PerChannel"), pytest.param(FP4Moe(quant_mode=QuantMode.FP4_NVFP4_NVFP4), id="NvFP4xNvFP4"), pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_MXFP8), id="MxFP4xMxFP8"), pytest.param(FP4Moe(quant_mode=QuantMode.FP4_MXFP4_Bf16), id="MxFP4xBf16"), @@ -2965,6 +3242,7 @@ def test_renormalize_routing( "moe_impl", [ pytest.param(FP8PerTensorMoe(), id="FP8_PerTensor"), + pytest.param(FP8PerChannelMoe(), id="FP8_PerChannel"), pytest.param( FP8BlockScaleMoe(fp8_quantization_type=QuantMode.FP8_BLOCK_SCALE_DEEPSEEK), id="FP8_Block_DeepSeek", @@ -3234,6 +3512,7 @@ def test_topk_routing( "moe_impl", [ pytest.param(FP8PerTensorMoe(), id="FP8_Tensor"), + pytest.param(FP8PerChannelMoe(), id="FP8_PerChannel"), ], ) @pytest.mark.parametrize( diff --git a/tests/moe/utils.py b/tests/moe/utils.py index 7c8339cecf..4c821693c9 100644 --- a/tests/moe/utils.py +++ b/tests/moe/utils.py @@ -33,11 +33,13 @@ class QuantMode(IntEnum): FP8_PER_TENSOR = 6 BF16 = 7 MXINT4_BF16_BF16 = 8 + FP8_PER_CHANNEL = 9 NON_GATED_ACTIVATION_SUPPORTED_QUANT_MODES = [ QuantMode.FP4_NVFP4_NVFP4, QuantMode.FP8_PER_TENSOR, + QuantMode.FP8_PER_CHANNEL, ]