diff --git a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh index 6c8789e9a..df1a0ea70 100644 --- a/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh +++ b/csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh @@ -251,15 +251,13 @@ __global__ void buildMinLatencyActiveExpertMapsKernel( #endif } -void buildMinLatencyActiveExpertMaps(int* num_active_experts_per_node, - float* experts_to_token_scores, int* active_expert_global_ids, - int64_t* expert_first_token_offset, - int const* token_selected_experts, - float const* token_final_scales, int64_t const num_tokens, - int const experts_per_token, int const start_expert, - int const end_expert, int const num_experts_per_node, - int const cluster_rank, int const cluster_size, - int const num_experts_smem, cudaStream_t const stream) { +void buildMinLatencyActiveExpertMaps( + int* num_active_experts_per_node, float* experts_to_token_scores, int* active_expert_global_ids, + int64_t* expert_first_token_offset, int const* token_selected_experts, + float const* token_final_scales, int64_t const num_tokens, int const experts_per_token, + int const start_expert, int const end_expert, int const num_experts_per_node, + int const cluster_rank, int const cluster_size, int const num_experts_smem, bool enable_pdl, + cudaStream_t const stream) { TLLM_CHECK_WITH_INFO(num_experts_per_node == (end_expert - start_expert), "num_experts_per_node must be equal to end_expert - start_expert"); @@ -277,7 +275,7 @@ void buildMinLatencyActiveExpertMaps(int* num_active_experts_per_node, config.stream = stream; cudaLaunchAttribute attrs[1]; attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL(); + attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; config.numAttrs = 1; config.attrs = attrs; cudaLaunchKernelEx(&config, buildMinLatencyActiveExpertMapsKernel, @@ -380,7 +378,7 @@ bool fusedBuildExpertMapsSortFirstTokenDispatch( int const* token_selected_experts, int* permuted_row_to_unpermuted_row, int* unpermuted_row_to_permuted_row, int64_t* expert_first_token_offset, int64_t const num_tokens, int const num_experts_per_node, int const experts_per_token, - int const start_expert, int const end_expert, cudaStream_t stream) { + int const start_expert, int const end_expert, bool enable_pdl, cudaStream_t stream) { TLLM_CHECK_WITH_INFO(num_experts_per_node == (end_expert - start_expert), "num_experts_per_node must be equal to end_expert - start_expert"); int const threads = BLOCK_SIZE; @@ -397,7 +395,7 @@ bool fusedBuildExpertMapsSortFirstTokenDispatch( config.stream = stream; cudaLaunchAttribute attrs[1]; attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL(); + attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; config.numAttrs = 1; config.attrs = attrs; @@ -431,7 +429,7 @@ bool fusedBuildExpertMapsSortFirstTokenBlockSize( int const* token_selected_experts, int* permuted_row_to_unpermuted_row, int* unpermuted_row_to_permuted_row, int64_t* expert_first_token_offset, int64_t const num_tokens, int const num_experts_per_node, int const experts_per_token, - int const start_expert, int const end_expert, cudaStream_t stream) { + int const start_expert, int const end_expert, bool enable_pdl, cudaStream_t stream) { int const block_size = num_tokens; if 
(num_tokens > 256) { TLLM_LOG_TRACE( @@ -451,7 +449,8 @@ bool fusedBuildExpertMapsSortFirstTokenBlockSize( return func(token_selected_experts, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, expert_first_token_offset, num_tokens, - num_experts_per_node, experts_per_token, start_expert, end_expert, stream); + num_experts_per_node, experts_per_token, start_expert, end_expert, enable_pdl, + stream); } template @@ -459,7 +458,7 @@ bool fusedBuildExpertMapsSortFirstTokenBlockSize( int const* token_selected_experts, int* permuted_row_to_unpermuted_row, int* unpermuted_row_to_permuted_row, int64_t* expert_first_token_offset, int64_t const num_tokens, int const num_experts_per_node, int const experts_per_token, - int const start_expert, int const end_expert, cudaStream_t stream) { + int const start_expert, int const end_expert, bool enable_pdl, cudaStream_t stream) { auto func = &fusedBuildExpertMapsSortFirstTokenBlockSize<1, LOG2_NUM_EXPERTS>; switch (experts_per_token) { case 1: { @@ -490,16 +489,15 @@ bool fusedBuildExpertMapsSortFirstTokenBlockSize( } return func(token_selected_experts, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, expert_first_token_offset, num_tokens, - num_experts_per_node, experts_per_token, start_expert, end_expert, stream); + num_experts_per_node, experts_per_token, start_expert, end_expert, enable_pdl, + stream); } -bool fusedBuildExpertMapsSortFirstToken(int const* token_selected_experts, - int* permuted_row_to_unpermuted_row, - int* unpermuted_row_to_permuted_row, - int64_t* expert_first_token_offset, - int64_t const num_tokens, int const num_experts_per_node, - int const experts_per_token, int const start_expert, - int const end_expert, cudaStream_t stream) { +bool fusedBuildExpertMapsSortFirstToken( + int const* token_selected_experts, int* permuted_row_to_unpermuted_row, + int* unpermuted_row_to_permuted_row, int64_t* expert_first_token_offset, + int64_t const num_tokens, int const num_experts_per_node, int const experts_per_token, + int const start_expert, int const end_expert, bool enable_pdl, cudaStream_t stream) { // We need enough bits to represent [0, num_experts_per_node+1] (inclusive) i.e. 
// num_experts_per_node + 2 values This is floor(log2(num_experts_per_node+1)) + 1 int expert_log = static_cast(log2(num_experts_per_node + 1)) + 1; @@ -517,7 +515,7 @@ bool fusedBuildExpertMapsSortFirstToken(int const* token_selected_experts, return funcs[expert_log - 1](token_selected_experts, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, expert_first_token_offset, num_tokens, num_experts_per_node, experts_per_token, start_expert, - end_expert, stream); + end_expert, enable_pdl, stream); } TLLM_LOG_TRACE("Experts per node %d does not have supported fused moe prologues", num_experts_per_node); @@ -589,7 +587,7 @@ void blockExpertPrefixSum(int const* token_selected_experts, int* blocked_expert int* blocked_row_to_unpermuted_row, int64_t const num_tokens, int64_t const num_experts_per_node, int64_t const num_experts_per_token, int64_t const num_tokens_per_block, int64_t const num_blocks_per_seq, - int const start_expert_id, cudaStream_t stream) { + int const start_expert_id, bool enable_pdl, cudaStream_t stream) { dim3 const blocks(num_experts_per_node, num_blocks_per_seq); dim3 const threads(num_tokens_per_block); @@ -600,7 +598,7 @@ void blockExpertPrefixSum(int const* token_selected_experts, int* blocked_expert config.stream = stream; cudaLaunchAttribute attrs[1]; attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL(); + attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; config.numAttrs = 1; config.attrs = attrs; @@ -705,7 +703,7 @@ __global__ void globalExpertPrefixSumKernel(int const* blocked_expert_counts, void globalExpertPrefixSum(int const* blocked_expert_counts, int* blocked_expert_counts_cumsum, int64_t* expert_first_token_offset, int64_t const num_experts_per_node, int64_t const num_tokens_per_block, int64_t const num_blocks_per_seq, - cudaStream_t stream) { + bool enable_pdl, cudaStream_t stream) { int64_t const num_elements = num_experts_per_node * num_blocks_per_seq; cudaLaunchConfig_t config; @@ -715,7 +713,7 @@ void globalExpertPrefixSum(int const* blocked_expert_counts, int* blocked_expert config.stream = stream; cudaLaunchAttribute attrs[1]; attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL(); + attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; config.numAttrs = 1; config.attrs = attrs; @@ -785,7 +783,7 @@ void mergeExpertPrefixSum(int const* blocked_expert_counts, int const* blocked_e int* permuted_token_selected_experts, int* permuted_row_to_unpermuted_row, int* unpermuted_row_to_permuted_row, int64_t const num_tokens, int64_t const num_experts_per_node, int64_t const num_tokens_per_block, - int64_t const num_blocks_per_seq, cudaStream_t stream) { + int64_t const num_blocks_per_seq, bool enable_pdl, cudaStream_t stream) { dim3 const blocks(num_experts_per_node, num_blocks_per_seq); dim3 const threads(num_tokens_per_block); @@ -796,7 +794,7 @@ void mergeExpertPrefixSum(int const* blocked_expert_counts, int const* blocked_e config.stream = stream; cudaLaunchAttribute attrs[1]; attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL(); + attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; config.numAttrs = 1; config.attrs = attrs; @@ -842,25 +840,27 @@ void threeStepBuildExpertMapsSortFirstToken( 
int64_t* expert_first_token_offset, int* blocked_expert_counts, int* blocked_expert_counts_cumsum, int* blocked_row_to_unpermuted_row, int64_t const num_tokens, int64_t const num_experts_per_node, int64_t const num_experts_per_token, - int const start_expert_id, cudaStream_t stream) { + int const start_expert_id, bool enable_pdl, cudaStream_t stream) { int64_t const num_tokens_per_block = computeNumTokensPerBlock(num_tokens, num_experts_per_node); int64_t const num_blocks_per_seq = tensorrt_llm::common::ceilDiv(num_tokens, num_tokens_per_block); blockExpertPrefixSum(token_selected_experts, blocked_expert_counts, blocked_row_to_unpermuted_row, num_tokens, num_experts_per_node, num_experts_per_token, - num_tokens_per_block, num_blocks_per_seq, start_expert_id, stream); + num_tokens_per_block, num_blocks_per_seq, start_expert_id, enable_pdl, + stream); sync_check_cuda_error(stream); globalExpertPrefixSum(blocked_expert_counts, blocked_expert_counts_cumsum, expert_first_token_offset, num_experts_per_node, num_tokens_per_block, - num_blocks_per_seq, stream); + num_blocks_per_seq, enable_pdl, stream); sync_check_cuda_error(stream); mergeExpertPrefixSum(blocked_expert_counts, blocked_expert_counts_cumsum, blocked_row_to_unpermuted_row, permuted_token_selected_experts, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, num_tokens, - num_experts_per_node, num_tokens_per_block, num_blocks_per_seq, stream); + num_experts_per_node, num_tokens_per_block, num_blocks_per_seq, enable_pdl, + stream); } // ============================== Infer GEMM sizes ================================= @@ -1614,17 +1614,15 @@ __global__ void expandInputRowsKernel( } template -void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input, - ExpandedActivationsType* permuted_output, - float const* unpermuted_scales, float* permuted_scales, - int const* permuted_row_to_unpermuted_row, - int64_t const num_rows, int64_t const hidden_size, int const k, - int const num_experts_per_node, QuantParams const& quant_params, - bool use_per_expert_act_scale, - int64_t* expert_first_token_offset, - TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, - void const* prequant_scales, cudaStream_t stream) { +void expandInputRowsKernelLauncher( + InputActivationsType const* unpermuted_input, ExpandedActivationsType* permuted_output, + float const* unpermuted_scales, float* permuted_scales, + int const* permuted_row_to_unpermuted_row, int64_t const num_rows, int64_t const hidden_size, + int const k, int const num_experts_per_node, QuantParams const& quant_params, + bool use_per_expert_act_scale, int64_t* expert_first_token_offset, + TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat, + TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, void const* prequant_scales, + bool enable_pdl, cudaStream_t stream) { #ifdef ENABLE_FP4 TLLM_CHECK_WITH_INFO( (std::is_same_v && fc1_act_sf_flat) || @@ -1699,7 +1697,7 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input, config.stream = stream; cudaLaunchAttribute attrs[1]; attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL(); + attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; config.numAttrs = 1; config.attrs = attrs; cudaLaunchKernelEx(&config, func, unpermuted_input, permuted_output, unpermuted_scales, @@ -1719,7 +1717,7 @@ 
void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input, int64_t* expert_first_token_offset, \ TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat, \ TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, void const* prequant_scales, \ - cudaStream_t stream) + bool enable_pdl, cudaStream_t stream) // Instantiate the data types that are used by the external pytorch op // INSTANTIATE_EXPAND_INPUT_ROWS(float, float); @@ -1910,7 +1908,8 @@ void finalizeMoeRoutingKernelLauncher( int const* permuted_row_to_unpermuted_row, int const* token_selected_experts, int64_t const* expert_first_token_offset, int64_t const num_rows, int64_t const cols, int64_t const experts_per_token, int64_t const num_experts_per_node, - MOEParallelismConfig parallelism_config, bool const enable_alltoall, cudaStream_t stream) { + MOEParallelismConfig parallelism_config, bool const enable_alltoall, bool enable_pdl, + cudaStream_t stream) { // Only add bias on rank 0 for tensor parallelism bool const is_rank_0 = parallelism_config.tp_rank == 0; ScaleBiasType const* bias_ptr = is_rank_0 ? bias : nullptr; @@ -1921,7 +1920,7 @@ void finalizeMoeRoutingKernelLauncher( config.stream = stream; cudaLaunchAttribute attrs[1]; attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL(); + attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; config.numAttrs = 1; config.attrs = attrs; @@ -1968,7 +1967,8 @@ void finalizeMoeRoutingKernelLauncher( int const* expanded_dest_row_to_expanded_source_row, int const* expert_for_source_row, \ int64_t const* expert_first_token_offset, int64_t const num_rows, int64_t const cols, \ int64_t const experts_per_token, int64_t const num_experts_per_node, \ - MOEParallelismConfig parallelism_config, bool const enable_alltoall, cudaStream_t stream); + MOEParallelismConfig parallelism_config, bool const enable_alltoall, bool enable_pdl, \ + cudaStream_t stream); // // Instantiate the data types that are used by the external pytorch op // INSTANTIATE_FINALIZE_MOE_ROUTING(half, half, half); @@ -2230,7 +2230,7 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8 int64_t const* expert_first_token_offset, int num_experts_per_node, int64_t inter_size, int64_t expanded_num_tokens, ActivationType activation_type, QuantParams const& quant_params, bool use_per_expert_act_scale, - TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_act_sf_flat, + TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_act_sf_flat, bool enable_pdl, cudaStream_t stream) { #ifdef ENABLE_FP4 constexpr int64_t min_num_tokens_alignment = @@ -2294,7 +2294,7 @@ void doActivation(T* output, GemmOutputType const* gemm_result, float const* fp8 config.stream = stream; cudaLaunchAttribute attrs[1]; attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL(); + attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; config.numAttrs = 1; config.attrs = attrs; cudaLaunchKernelEx(&config, fn, output, gemm_result, fp8_quant, bias, bias_is_broadcast, @@ -2824,7 +2824,8 @@ void CutlassMoeFCRunner(gemm_output), fc2_fp8_quant, fc1_expert_biases, bias_is_broadcast, expert_first_token_offset, num_experts_per_node, inter_size, expanded_num_rows, fc1_activation_type, quant_params, use_per_expert_act_scale, - nullptr, stream); + nullptr, enable_pdl, stream); 
sync_check_cuda_error(stream); } @@ -2854,12 +2855,11 @@ void CutlassMoeFCRunner(gemm_output), final_output, fc2_expert_biases, unpermuted_final_scales, unpermuted_row_to_permuted_row, permuted_row_to_unpermuted_row, token_selected_experts, expert_first_token_offset, num_rows, hidden_size, k, - num_experts_per_node, parallelism_config, enable_alltoall, stream); + num_experts_per_node, parallelism_config, enable_alltoall, enable_pdl, stream); } template (gemm_output), fc2_fp8_quant, fc1_expert_biases, bias_is_broadcast, expert_first_token_offset, num_experts_per_node, inter_size, expanded_num_rows, fc1_activation_type, quant_params, use_per_expert_act_scale, - fc2_fp4_act_flat, stream); + fc2_fp4_act_flat, enable_pdl, stream); sync_check_cuda_error(stream); } else if (use_fp8) { @@ -3043,7 +3044,7 @@ void CutlassMoeFCRunner(intermediate_result), fc2_fp8_quant, fc1_expert_biases, bias_is_broadcast, expert_first_token_offset, num_experts_per_node, inter_size, expanded_num_rows, fc1_activation_type, quant_params, use_per_expert_act_scale, - nullptr, stream); + nullptr, enable_pdl, stream); sync_check_cuda_error(stream); } else if (!is_gated_activation) { @@ -3152,7 +3153,7 @@ void CutlassMoeFCRunner(gemm_output), nullptr, static_cast(fc2_lora), false, expert_first_token_offset, num_experts_per_node, hidden_size, expanded_num_rows, - ActivationType::Identity, {}, false, nullptr, stream); + ActivationType::Identity, {}, false, nullptr, enable_pdl, stream); sync_check_cuda_error(stream); } @@ -3247,13 +3249,13 @@ void CutlassMoeFCRunner(gemm_output), final_output, fc2_expert_biases, unpermuted_final_scales, unpermuted_row_to_permuted_row, permuted_row_to_unpermuted_row, token_selected_experts, expert_first_token_offset, num_rows, hidden_size, k, - num_experts_per_node, parallelism_config, enable_alltoall, stream); + num_experts_per_node, parallelism_config, enable_alltoall, enable_pdl, stream); } else if (!using_tma_ws_gemm2) { finalizeMoeRoutingKernelLauncher( static_cast(gemm_output), final_output, fc2_expert_biases, unpermuted_final_scales, unpermuted_row_to_permuted_row, permuted_row_to_unpermuted_row, token_selected_experts, expert_first_token_offset, num_rows, hidden_size, k, - num_experts_per_node, parallelism_config, enable_alltoall, stream); + num_experts_per_node, parallelism_config, enable_alltoall, enable_pdl, stream); } sync_check_cuda_error(stream); } @@ -3485,7 +3487,7 @@ void CutlassMoeFCRunner::value || std::is_same::value; static constexpr bool fp8_scales_required = std::is_same::value || @@ -3624,14 +3626,14 @@ void CutlassMoeFCRunner:: TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, ScaleBiasType const* bias1, ScaleBiasType const* bias2, - UnfusedGemmOutputType* gemm1_output, UnfusedGemmOutputType* gemm2_output, + UnfusedGemmOutputType* gemm1_output, UnfusedGemmOutputType* gemm2_output, bool enable_pdl, cudaStream_t stream) { // Always nullptr layout_info1.ptr_c = nullptr; @@ -3828,7 +3830,7 @@ CutlassMoeFCRunner:: config.stream = stream; cudaLaunchAttribute attrs[1]; attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL(); + attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; config.numAttrs = 1; config.attrs = attrs; cudaLaunchKernelEx(&config, kernel_instance, expert_first_token_offset, layout_info1, @@ -3855,7 +3857,7 @@ CutlassMoeFCRunner:: 
QuantParams quant_params, ScaleBiasType const* bias1, ScaleBiasType const* bias2, UnfusedGemmOutputType* output1, UnfusedGemmOutputType* output2, int const* num_active_experts_per, int const* active_expert_global_ids, int start_expert, - cudaStream_t stream) { + bool enable_pdl, cudaStream_t stream) { TLLM_CHECK_WITH_INFO(!use_w4afp8, "W4AFP8 is not supported in low latency mode"); // Always nullptr @@ -3890,7 +3892,7 @@ CutlassMoeFCRunner:: config.stream = stream; cudaLaunchAttribute attrs[1]; attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL(); + attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; config.numAttrs = 1; config.attrs = attrs; cudaLaunchKernelEx( @@ -3919,9 +3921,13 @@ CutlassMoeFCRunner:: ScaleBiasType const* fc2_expert_biases, bool min_latency_mode, MoeMinLatencyParams& min_latency_params, bool use_lora, int start_expert, MOEParallelismConfig parallelism_config, - cudaStream_t stream) { + bool enable_pdl, cudaStream_t stream) { auto gemm1_tma_ws_input = tma_ws_grouped_gemm1_input_; auto gemm2_tma_ws_input = tma_ws_grouped_gemm2_input_; + + // Set enable_pdl for both GEMM inputs + gemm1_tma_ws_input.enable_pdl = enable_pdl; + gemm2_tma_ws_input.enable_pdl = enable_pdl; if (!moe_gemm_runner_.isTmaWarpSpecialized(*gemm1_config_) && !moe_gemm_runner_.isTmaWarpSpecialized(*gemm2_config_)) { return std::make_pair(gemm1_tma_ws_input, gemm2_tma_ws_input); @@ -3954,7 +3960,7 @@ CutlassMoeFCRunner:: quant_params, nullptr, nullptr, reinterpret_cast(gemm1_output), reinterpret_cast(fc2_result_), min_latency_params.num_active_experts_per_node, min_latency_params.active_expert_global_ids, - start_expert, stream); + start_expert, enable_pdl, stream); } else { auto gemm1_input = use_prequant_scale_kernel ? 
smoothed_act_ : permuted_data_; @@ -4010,7 +4016,7 @@ CutlassMoeFCRunner:: quant_params.fp8.dequant_fc2, fc1_fp4_act_scale_, fc2_fp4_act_scale_, quant_params, fc1_expert_biases, fc2_expert_biases, reinterpret_cast(gemm1_output), - reinterpret_cast(fc2_result_), stream); + reinterpret_cast(fc2_result_), enable_pdl, stream); } } @@ -4361,7 +4367,7 @@ std::map> GemmProfilerBackend::getProfile return out_map; } -void GemmProfilerBackend::prepareRouting(int num_tokens, char* workspace_ptr_char, +void GemmProfilerBackend::prepareRouting(int num_tokens, char* workspace_ptr_char, bool enable_pdl, cudaStream_t stream) { auto workspaces = getProfilerWorkspaces(num_tokens, mSM >= 90); #define GET_WS_PTR_BASE(type, name) \ @@ -4419,7 +4425,7 @@ void GemmProfilerBackend::prepareRouting(int num_tokens, char* workspace_ptr_cha token_selected_experts, permuted_token_selected_experts, permuted_row_to_unpermuted_row, unpermuted_row_to_permuted_row, expert_first_token_offset, blocked_expert_counts, blocked_expert_counts_cumsum, blocked_row_to_unpermuted_row, num_tokens, - mNumExpertsPerNode, mK, start_expert_id, stream); + mNumExpertsPerNode, mK, start_expert_id, enable_pdl, stream); sync_check_cuda_error(stream); } } @@ -4483,7 +4489,8 @@ void GemmProfilerBackend::prepareQuantParams(int num_tokens, char* workspace_ptr } void GemmProfilerBackend::prepareTmaWsInputs(int num_tokens, char* workspace_ptr_char, - void const* expert_weights, cudaStream_t stream) { + void const* expert_weights, bool enable_pdl, + cudaStream_t stream) { if (mSM < 90) { return; } @@ -4522,12 +4529,14 @@ void GemmProfilerBackend::prepareTmaWsInputs(int num_tokens, char* workspace_ptr TmaWarpSpecializedGroupedGemmInput dummy_tma_ws_input; dummy_tma_ws_input.configureWorkspace(tma_ws_input_workspace, mNumExpertsPerNode, gemm_workspace, workspaces.at("gemm_workspace").first, mScalingType); + dummy_tma_ws_input.enable_pdl = enable_pdl; // Set enable_pdl for dummy input tma_ws_input_workspace += tma_ws_size; size_t num_expanded_tokens = num_tokens * mK; for (int64_t i = 0; i < NUM_ROUTING_SAMPLES; i++) { mTmaInputCache[i].configureWorkspace(tma_ws_input_workspace, mNumExpertsPerNode, gemm_workspace, workspaces.at("gemm_workspace").first, mScalingType); + mTmaInputCache[i].enable_pdl = enable_pdl; // Set enable_pdl for the profiler tma_ws_input_workspace += tma_ws_size; int64_t* expert_first_token_offset = @@ -4567,7 +4576,7 @@ void GemmProfilerBackend::prepareTmaWsInputs(int num_tokens, char* workspace_ptr input, weights_sel, weights_sel, mQuantParams.fp8.dequant_fc1, mQuantParams.fp8.dequant_fc2, fp4_act_scale_flat, fp4_act_scale_flat, mQuantParams, nullptr, nullptr, intermediate, intermediate, num_active_experts_per_node, - active_expert_global_ids, 0, stream); + active_expert_global_ids, 0, enable_pdl, stream); } else { std::tie(gemm1_tma_ws_input, gemm2_tma_ws_input) = mInterface->computeStridesTmaWarpSpecializedDispatch( @@ -4576,7 +4585,7 @@ void GemmProfilerBackend::prepareTmaWsInputs(int num_tokens, char* workspace_ptr mExpertInterSize, mNumExpertsPerNode, input, input, weights_sel, weights_sel, mQuantParams.fp8.dequant_fc1, mQuantParams.fp8.dequant_fc2, fp4_act_scale_flat, fp4_act_scale_flat, mQuantParams, nullptr, nullptr, intermediate, intermediate, - stream); + enable_pdl, stream); } sync_check_cuda_error(stream); } @@ -4584,16 +4593,17 @@ void GemmProfilerBackend::prepareTmaWsInputs(int num_tokens, char* workspace_ptr } void GemmProfilerBackend::prepare(int num_tokens, char* workspace_ptr_char, - void const* expert_weights, 
cudaStream_t stream) { + void const* expert_weights, bool enable_pdl, + cudaStream_t stream) { mAllTacticsSaved = mInterface->getTactics(); mSampleIndex = 0; auto workspace_size = getWorkspaceSize(num_tokens); populateRandomBuffer(workspace_ptr_char, workspace_size, stream); - prepareRouting(num_tokens, workspace_ptr_char, stream); + prepareRouting(num_tokens, workspace_ptr_char, enable_pdl, stream); prepareQuantParams(num_tokens, workspace_ptr_char, stream); - prepareTmaWsInputs(num_tokens, workspace_ptr_char, expert_weights, stream); + prepareTmaWsInputs(num_tokens, workspace_ptr_char, expert_weights, enable_pdl, stream); } size_t GemmProfilerBackend::getWorkspaceSize(int maxM) { @@ -4608,7 +4618,7 @@ size_t GemmProfilerBackend::getWorkspaceSize(int maxM) { void GemmProfilerBackend::runProfiler(int original_num_tokens, Config const& tactic, char* workspace_ptr_char, void const* expert_weights, - cudaStream_t const& stream) { + bool enable_pdl, cudaStream_t const& stream) { int64_t expanded_num_tokens = original_num_tokens * mK; int64_t num_experts_per_node = mNumExpertsPerNode; @@ -4661,72 +4671,29 @@ void GemmProfilerBackend::runProfiler(int original_num_tokens, Config const& tac mInterface->is_profiler = true; if (mGemmToProfile == GemmToProfile::GEMM_1) { - mInterface->gemm1(input, // - output, // - intermediate, // - expert_first_token_offset, // - tma_ws_input_template, // - weights_sel, // - bias, // - expert_first_token_offset + num_experts_per_node, // - mQuantParams.wo.fc1_weight_scales, // - mQuantParams.fp8.dequant_fc1, // + mInterface->gemm1(input, output, intermediate, expert_first_token_offset, tma_ws_input_template, + weights_sel, bias, expert_first_token_offset + num_experts_per_node, + mQuantParams.wo.fc1_weight_scales, mQuantParams.fp8.dequant_fc1, mQuantParams.fp8_mxfp4.fc2.act_global_scale ? 
mQuantParams.fp8_mxfp4.fc2.act_global_scale - : mQuantParams.fp8.quant_fc2, // - fp4_act_scale_flat, // - fp4_act_scale_flat, // - mQuantParams, // - original_num_tokens, // - expanded_num_tokens, // - mExpertHiddenSize, // - mExpertInterSize, // - num_experts_per_node, // - mActivationType, // - alpha_scale_ptr_array, // - !mUseLora, // - /*use_deepseek_fp8_block_scale=*/false, // - stream, // - tactic, // - mMinLatencyMode, // - num_active_experts_per_node, // - active_expert_global_ids); // + : mQuantParams.fp8.quant_fc2, + fp4_act_scale_flat, fp4_act_scale_flat, mQuantParams, original_num_tokens, + expanded_num_tokens, mExpertHiddenSize, mExpertInterSize, + num_experts_per_node, mActivationType, alpha_scale_ptr_array, !mUseLora, + /*use_deepseek_fp8_block_scale=*/false, stream, tactic, mMinLatencyMode, + num_active_experts_per_node, active_expert_global_ids, enable_pdl); } else { TLLM_CHECK(mGemmToProfile == GemmToProfile::GEMM_2); - mInterface->gemm2(input, // - intermediate, // - output, // - expert_first_token_offset, // - tma_ws_input_template, // - weights_sel, // - bias, // - mQuantParams.wo.fc2_weight_scales, // - mQuantParams.fp8.dequant_fc2, // - fp4_act_scale_flat, // - mQuantParams, // - token_topk_unpermuted_scales, // - token_topk_permuted_scales, // - unpermuted_row_to_permuted_row, // - permuted_row_to_unpermuted_row, // - token_selected_experts, // - expert_first_token_offset + mNumExpertsPerNode, // - original_num_tokens, // - expanded_num_tokens, // - mExpertHiddenSize, // - mExpertInterSize, // - num_experts_per_node, // - mK, // - alpha_scale_ptr_array, // - false, // - nullptr, // - /*use_deepseek_fp8_block_scale=*/false, // - stream, // - mParallelismConfig, // - mEnableAlltoall, // - tactic, // - mMinLatencyMode, // - num_active_experts_per_node, // - active_expert_global_ids); // + mInterface->gemm2( + input, intermediate, output, expert_first_token_offset, tma_ws_input_template, weights_sel, + bias, mQuantParams.wo.fc2_weight_scales, mQuantParams.fp8.dequant_fc2, fp4_act_scale_flat, + mQuantParams, token_topk_unpermuted_scales, token_topk_permuted_scales, + unpermuted_row_to_permuted_row, permuted_row_to_unpermuted_row, token_selected_experts, + expert_first_token_offset + mNumExpertsPerNode, original_num_tokens, expanded_num_tokens, + mExpertHiddenSize, mExpertInterSize, num_experts_per_node, mK, alpha_scale_ptr_array, false, + nullptr, + /*use_deepseek_fp8_block_scale=*/false, stream, mParallelismConfig, mEnableAlltoall, tactic, + mMinLatencyMode, num_active_experts_per_node, active_expert_global_ids, enable_pdl); } mInterface->is_profiler = false; diff --git a/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_ops.cu b/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_ops.cu index 8480e09ed..3b269ed4a 100644 --- a/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_ops.cu +++ b/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_ops.cu @@ -204,7 +204,7 @@ class FusedMoeRunner : public torch::CustomClassHolder { torch::optional const& input_sf, int64_t const tp_size, int64_t const tp_rank, int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size, int64_t const cluster_rank, bool const enable_alltoall, bool min_latency_mode, - torch::optional> const& profile_ids) { + torch::optional> const& profile_ids, bool enable_pdl) { std::lock_guard lock(mMutex); // Free the profile workspace to save memory freeProfileWorkspace(); @@ -315,7 +315,7 @@ class FusedMoeRunner : public 
torch::CustomClassHolder { static_cast(workspace_info.workspace.data_ptr()), output.data_ptr(), static_cast(workspace_info.src_to_dest_map), parallelism_config, enable_alltoall, false, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, - stream); + enable_pdl, stream); #else mKernelRunner->runMoe( input.const_data_ptr(), input_sf.has_value() ? input_sf.value().const_data_ptr() : nullptr, @@ -331,7 +331,7 @@ class FusedMoeRunner : public torch::CustomClassHolder { static_cast(experts_per_token), static_cast(workspace_info.workspace), output.data_ptr(), static_cast(workspace_info.src_to_dest_map), parallelism_config, false, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, - stream); + enable_pdl, stream); #endif return output; @@ -346,7 +346,7 @@ class FusedMoeRunner : public torch::CustomClassHolder { torch::optional const& input_sf, int64_t const tp_size, int64_t const tp_rank, int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size, int64_t const cluster_rank, bool const enable_alltoall, bool min_latency_mode, - torch::optional> const& profile_ids) { + torch::optional> const& profile_ids, bool enable_pdl) { std::lock_guard lock(mMutex); // Free the profile workspace to save memory @@ -458,7 +458,7 @@ class FusedMoeRunner : public torch::CustomClassHolder { static_cast(workspace_info.workspace.data_ptr()), output.data_ptr(), static_cast(workspace_info.src_to_dest_map), parallelism_config, enable_alltoall, false, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, - stream); + enable_pdl, stream); #else mKernelRunner->runMoe( input.const_data_ptr(), input_sf.has_value() ? input_sf.value().const_data_ptr() : nullptr, @@ -474,7 +474,7 @@ class FusedMoeRunner : public torch::CustomClassHolder { static_cast(experts_per_token), static_cast(workspace_info.workspace), output.data_ptr(), static_cast(workspace_info.src_to_dest_map), parallelism_config, false, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, - stream); + enable_pdl, stream); #endif return std::make_tuple(output, num_active_experts_per_node, experts_to_token_score, @@ -493,7 +493,8 @@ class FusedMoeRunner : public torch::CustomClassHolder { int64_t const tp_size, int64_t const tp_rank, int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size, int64_t const cluster_rank, bool const enable_alltoall, bool const min_latency_mode, - int64_t const gemm_idx, int64_t const profile_id, bool const do_preparation) { + int64_t const gemm_idx, int64_t const profile_id, bool const do_preparation, + bool enable_pdl) { std::lock_guard lock(mMutex); // TODO: support profiling under fp8 block scaling in the future @@ -558,11 +559,12 @@ class FusedMoeRunner : public torch::CustomClassHolder { TORCH_CHECK(cu_malloc_status == cudaSuccess, "Can't allocate profile workspace for MoE GEMM profile."); - mProfiler->prepare(num_rows, mProfileWorkspace, expert_weights_ptr, stream); + mProfiler->prepare(num_rows, mProfileWorkspace, expert_weights_ptr, enable_pdl, stream); } // Profile specific tactic. Assuming at least one preparation phase has been executed already. 
- mProfiler->runProfiler(num_rows, profile, mProfileWorkspace, expert_weights_ptr, stream); + mProfiler->runProfiler(num_rows, profile, mProfileWorkspace, expert_weights_ptr, enable_pdl, + stream); } private: diff --git a/csrc/nv_internal/cpp/common/envUtils.cpp b/csrc/nv_internal/cpp/common/envUtils.cpp index efcd8f492..e2ee31261 100644 --- a/csrc/nv_internal/cpp/common/envUtils.cpp +++ b/csrc/nv_internal/cpp/common/envUtils.cpp @@ -189,19 +189,6 @@ bool getEnvUseTileSizeKv64ForTrtllmGen() { return useTileSizeKv64; } -bool getEnvEnablePDL() { - static std::once_flag flag; - static bool enablePDL = false; - - std::call_once(flag, [&]() { - if (getSMVersion() >= 90) { - // PDL will be enabled by setting the env variables `TRTLLM_ENABLE_PDL` to `1` - enablePDL = getBoolEnv("TRTLLM_ENABLE_PDL"); - } - }); - return enablePDL; -} - bool getEnvUseUCXKvCache() { static bool const useUCXKVCache = getBoolEnv("TRTLLM_USE_UCX_KVCACHE"); return useUCXKVCache; diff --git a/csrc/nv_internal/cpp/kernels/quantization.cu b/csrc/nv_internal/cpp/kernels/quantization.cu index 4297ca2b6..e9ae62830 100644 --- a/csrc/nv_internal/cpp/kernels/quantization.cu +++ b/csrc/nv_internal/cpp/kernels/quantization.cu @@ -76,7 +76,7 @@ template void invokeQuantization<__nv_bfloat16>(int8_t* dst, __nv_bfloat16 const template void invokeMxFP8Quantization(int b, int m, int n, int padded_n, T const* input, int64_t* output, int32_t* SFOuput, FP4QuantizationSFLayout layout, - int multiProcessorCount, cudaStream_t stream) { + int multiProcessorCount, bool enable_pdl, cudaStream_t stream) { // Fixed SF_VEC_SIZE as 32 static constexpr int SF_VEC_SIZE = 32; @@ -95,7 +95,7 @@ void invokeMxFP8Quantization(int b, int m, int n, int padded_n, T const* input, config.stream = stream; cudaLaunchAttribute attrs[1]; attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL(); + attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; config.numAttrs = 1; config.attrs = attrs; cudaLaunchKernelEx( @@ -168,7 +168,7 @@ INSTANTIATE_INVOKE_PER_TOKEN_QUANTIZATION(__nv_bfloat16, __nv_fp8_e4m3); template void invokeFP4Quantization(int m, int n, T const* input, float const* SFScale, int64_t* output, int32_t* SFOuput, bool useUE8M0, FP4QuantizationSFLayout layout, - int multiProcessorCount, cudaStream_t stream) { + int multiProcessorCount, bool enable_pdl, cudaStream_t stream) { #ifdef ENABLE_FP8 if constexpr (std::is_same_v) { // Grid, Block size. @@ -204,7 +204,7 @@ void invokeFP4Quantization(int m, int n, T const* input, float const* SFScale, i config.stream = stream; cudaLaunchAttribute attrs[1]; attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL(); + attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; config.numAttrs = 1; config.attrs = attrs; cudaLaunchKernelEx(&config, kernel_instance, m, n, input, SFScale, @@ -217,7 +217,7 @@ template void invokeBatchedFP4Quantization(int b, int m, int n, T const* input, float const* SFScale, int64_t* output, int32_t* SFOuput, bool useUE8M0, int multiProcessorCount, FP4QuantizationSFLayout layout, - cudaStream_t stream) { + bool enable_pdl, cudaStream_t stream) { #ifdef ENABLE_FP8 if constexpr (std::is_same_v) { // Grid, Block size. 
@@ -253,7 +253,7 @@ void invokeBatchedFP4Quantization(int b, int m, int n, T const* input, float con config.stream = stream; cudaLaunchAttribute attrs[1]; attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL(); + attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl; config.numAttrs = 1; config.attrs = attrs; cudaLaunchKernelEx(&config, kernel_instance, b, m, n, input, SFScale, @@ -344,47 +344,56 @@ void invokeNVFP4BlockScaleInterleaveReverse(int b, int m, int n, uint8_t const* template void invokeFP4Quantization(int m, int n, half const* input, float const* SFScale, int64_t* output, int32_t* SFOuput, bool useUE8M0, FP4QuantizationSFLayout layout, - int multiProcessorCount, cudaStream_t stream); + int multiProcessorCount, bool enable_pdl, + cudaStream_t stream); template void invokeFP4Quantization(int m, int n, half const* input, float const* SFScale, int64_t* output, int32_t* SFOuput, bool useUE8M0, FP4QuantizationSFLayout layout, - int multiProcessorCount, cudaStream_t stream); -template void invokeBatchedFP4Quantization( - int b, int m, int n, half const* input, float const* SFScale, int64_t* output, int32_t* SFOuput, - bool useUE8M0, int multiProcessorCount, FP4QuantizationSFLayout layout, cudaStream_t stream); -template void invokeBatchedFP4Quantization( - int b, int m, int n, half const* input, float const* SFScale, int64_t* output, int32_t* SFOuput, - bool useUE8M0, int multiProcessorCount, FP4QuantizationSFLayout layout, cudaStream_t stream); + int multiProcessorCount, bool enable_pdl, + cudaStream_t stream); +template void invokeBatchedFP4Quantization(int b, int m, int n, half const* input, + float const* SFScale, int64_t* output, + int32_t* SFOuput, bool useUE8M0, + int multiProcessorCount, + FP4QuantizationSFLayout layout, + bool enable_pdl, cudaStream_t stream); +template void invokeBatchedFP4Quantization(int b, int m, int n, half const* input, + float const* SFScale, int64_t* output, + int32_t* SFOuput, bool useUE8M0, + int multiProcessorCount, + FP4QuantizationSFLayout layout, + bool enable_pdl, cudaStream_t stream); template void invokeMxFP8Quantization(int b, int m, int n, int padded_n, half const* input, int64_t* output, int32_t* SFOuput, FP4QuantizationSFLayout layout, int multiProcessorCount, - cudaStream_t stream); + bool enable_pdl, cudaStream_t stream); #ifdef ENABLE_BF16 template void invokeFP4Quantization<__nv_bfloat16, 16>(int m, int n, __nv_bfloat16 const* input, float const* SFScale, int64_t* output, int32_t* SFOuput, bool useUE8M0, FP4QuantizationSFLayout layout, - int multiProcessorCount, + int multiProcessorCount, bool enable_pdl, cudaStream_t stream); template void invokeFP4Quantization<__nv_bfloat16, 32>(int m, int n, __nv_bfloat16 const* input, float const* SFScale, int64_t* output, int32_t* SFOuput, bool useUE8M0, FP4QuantizationSFLayout layout, - int multiProcessorCount, + int multiProcessorCount, bool enable_pdl, cudaStream_t stream); template void invokeBatchedFP4Quantization<__nv_bfloat16, 16>( int b, int m, int n, __nv_bfloat16 const* input, float const* SFScale, int64_t* output, int32_t* SFOuput, bool useUE8M0, int multiProcessorCount, FP4QuantizationSFLayout layout, - cudaStream_t stream); + bool enable_pdl, cudaStream_t stream); template void invokeBatchedFP4Quantization<__nv_bfloat16, 32>( int b, int m, int n, __nv_bfloat16 const* input, float const* SFScale, int64_t* output, int32_t* SFOuput, bool useUE8M0, int 
multiProcessorCount, FP4QuantizationSFLayout layout, - cudaStream_t stream); + bool enable_pdl, cudaStream_t stream); template void invokeMxFP8Quantization<__nv_bfloat16>(int b, int m, int n, int padded_n, __nv_bfloat16 const* input, int64_t* output, int32_t* SFOuput, FP4QuantizationSFLayout layout, - int multiProcessorCount, cudaStream_t stream); + int multiProcessorCount, bool enable_pdl, + cudaStream_t stream); #endif @@ -393,22 +402,22 @@ template void invokeFP4Quantization<__nv_fp8_e4m3, 16>(int m, int n, __nv_fp8_e4 float const* SFScale, int64_t* output, int32_t* SFOuput, bool useUE8M0, FP4QuantizationSFLayout layout, - int multiProcessorCount, + int multiProcessorCount, bool enable_pdl, cudaStream_t stream); template void invokeFP4Quantization<__nv_fp8_e4m3, 32>(int m, int n, __nv_fp8_e4m3 const* input, float const* SFScale, int64_t* output, int32_t* SFOuput, bool useUE8M0, FP4QuantizationSFLayout layout, - int multiProcessorCount, + int multiProcessorCount, bool enable_pdl, cudaStream_t stream); template void invokeBatchedFP4Quantization<__nv_fp8_e4m3, 16>( int b, int m, int n, __nv_fp8_e4m3 const* input, float const* SFScale, int64_t* output, int32_t* SFOuput, bool useUE8M0, int multiProcessorCount, FP4QuantizationSFLayout layout, - cudaStream_t stream); + bool enable_pdl, cudaStream_t stream); template void invokeBatchedFP4Quantization<__nv_fp8_e4m3, 32>( int b, int m, int n, __nv_fp8_e4m3 const* input, float const* SFScale, int64_t* output, int32_t* SFOuput, bool useUE8M0, int multiProcessorCount, FP4QuantizationSFLayout layout, - cudaStream_t stream); + bool enable_pdl, cudaStream_t stream); #endif //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/csrc/nv_internal/tensorrt_llm/common/envUtils.h b/csrc/nv_internal/tensorrt_llm/common/envUtils.h index 30ff750e9..887162e78 100644 --- a/csrc/nv_internal/tensorrt_llm/common/envUtils.h +++ b/csrc/nv_internal/tensorrt_llm/common/envUtils.h @@ -48,9 +48,6 @@ int getEnvMmhaBlocksPerSequence(); int getEnvMmhaKernelBlockSize(); -// Whether PDL is enabled. -bool getEnvEnablePDL(); - bool getEnvUseUCXKvCache(); bool getEnvUseMPIKvCache(); diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h index 20d7ec4f8..b2363ccfa 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h @@ -226,6 +226,9 @@ struct TmaWarpSpecializedGroupedGemmInput { uint8_t* gemm_workspace = nullptr; size_t gemm_workspace_size = 0; + // Whether to enable PDL (Programmatic Dependent Launch). 
+ bool enable_pdl; + static std::array workspaceBuffers(int num_experts, FpXBlockScalingType scaling_type); static size_t workspaceSize(int num_experts, FpXBlockScalingType scaling_type); diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h index ce39fa52e..43f3eac2f 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h @@ -392,7 +392,7 @@ class CutlassMoeFCRunnerInterface { MOEParallelismConfig parallelism_config, bool const enable_alltoall, bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode, MoeMinLatencyParams& min_latency_params, - cudaStream_t stream) = 0; + bool enable_pdl, cudaStream_t stream) = 0; // Aliases for profiling the gemms virtual void gemm1(void const* const input, void* const output, void* const intermediate_result, @@ -409,7 +409,8 @@ class CutlassMoeFCRunnerInterface { ActivationType fc1_activation_type, float const** alpha_scale_ptr_array, bool bias_is_broadcast, bool use_deepseek_fp8_block_scale, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, - int* num_active_experts_per, int* active_expert_global_ids) = 0; + int* num_active_experts_per, int* active_expert_global_ids, + bool enable_pdl) = 0; virtual void gemm2(void const* const input, void* const gemm_output, void* const final_output, int64_t const* const expert_first_token_offset, @@ -430,7 +431,7 @@ class CutlassMoeFCRunnerInterface { cudaStream_t stream, MOEParallelismConfig parallelism_config, bool const enable_alltoall, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per, - int* active_expert_global_ids) = 0; + int* active_expert_global_ids, bool enable_pdl) = 0; virtual std::pair computeStridesTmaWarpSpecializedDispatch( @@ -442,7 +443,7 @@ class CutlassMoeFCRunnerInterface { float const* alpha_scale_flat2, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, - void const* bias1, void const* bias2, void* gemm1_output, void* gemm2_output, + void const* bias1, void const* bias2, void* gemm1_output, void* gemm2_output, bool enable_pdl, cudaStream_t stream) = 0; virtual std::pair @@ -456,7 +457,7 @@ class CutlassMoeFCRunnerInterface { TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat, QuantParams quant_params, void const* bias1, void const* bias2, void* output1, void* output2, int const* num_active_experts_per, int const* active_expert_global_ids, int start_expert, - cudaStream_t stream) = 0; + bool enable_pdl, cudaStream_t stream) = 0; virtual size_t getGemmWorkspaceSize(int num_experts_per_node) const = 0; @@ -558,7 +559,8 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool const enable_alltoall, bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode, - MoeMinLatencyParams& min_latency_params, cudaStream_t stream) override; + MoeMinLatencyParams& min_latency_params, bool enable_pdl, + cudaStream_t stream) override; // We make these GEMM1 & GEMM2 static because they need to be stateless for the profiler to work static void gemm1( @@ -581,7 +583,7 @@ class CutlassMoeFCRunner : public 
CutlassMoeFCRunnerInterface { int64_t const inter_size, int const num_experts_per_node, ActivationType fc1_activation_type, float const** alpha_scale_ptr_array, bool bias_is_broadcast, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, - int* num_active_experts_per, int* active_expert_global_ids); + int* num_active_experts_per, int* active_expert_global_ids, bool enable_pdl); static void gemm2( MoeGemmRunner& gemm_runner, @@ -601,7 +603,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { float const** alpha_scale_ptr_array, bool use_lora, void* fc2_lora, cudaStream_t stream, MOEParallelismConfig parallelism_config, bool const enable_alltoall, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, - int* num_active_experts_per, int* active_expert_global_ids); + int* num_active_experts_per, int* active_expert_global_ids, bool enable_pdl); // Overrides to allow us to forward on to the internal functions with the pointers using the // correct type @@ -618,7 +620,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { ActivationType fc1_activation_type, float const** alpha_scale_ptr_array, bool bias_is_broadcast, bool use_deepseek_fp8_block_scale, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, - int* num_active_experts_per, int* active_expert_global_ids) override { + int* num_active_experts_per, int* active_expert_global_ids, bool enable_pdl) override { auto* block_scale_gemm_runner = use_deepseek_fp8_block_scale ? getDeepSeekBlockScaleGemmRunner() : nullptr; return Self::gemm1(moe_gemm_runner_, block_scale_gemm_runner, static_cast(input), @@ -629,7 +631,8 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { fc2_fp8_quant, fc1_fp4_act_flat, fc2_fp4_act_flat, quant_params, num_rows, expanded_num_rows, hidden_size, inter_size, num_experts_per_node, fc1_activation_type, alpha_scale_ptr_array, bias_is_broadcast, stream, - config, min_latency_mode, num_active_experts_per, active_expert_global_ids); + config, min_latency_mode, num_active_experts_per, active_expert_global_ids, + enable_pdl); } void gemm2(void const* const input, void* const gemm_output, void* const final_output, @@ -649,7 +652,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { bool use_deepseek_fp8_block_scale, cudaStream_t stream, MOEParallelismConfig parallelism_config, bool const enable_alltoall, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, - int* num_active_experts_per, int* active_expert_global_ids) override { + int* num_active_experts_per, int* active_expert_global_ids, bool enable_pdl) override { auto* block_scale_gemm_runner = use_deepseek_fp8_block_scale ? 
getDeepSeekBlockScaleGemmRunner() : nullptr; return Self::gemm2( @@ -663,7 +666,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { num_valid_tokens_ptr, num_rows, expanded_num_rows, hidden_size, inter_size, num_experts_per_node, experts_per_token, alpha_scale_ptr_array, use_lora, fc2_lora, stream, parallelism_config, enable_alltoall, config, min_latency_mode, num_active_experts_per, - active_expert_global_ids); + active_expert_global_ids, enable_pdl); } virtual size_t getGemmWorkspaceSize(int num_experts_per_node) const override { @@ -680,7 +683,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { float const* alpha_scale_flat2, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, - void const* bias1, void const* bias2, void* gemm1_output, void* gemm2_output, + void const* bias1, void const* bias2, void* gemm1_output, void* gemm2_output, bool enable_pdl, cudaStream_t stream) override { return Self::computeStridesTmaWarpSpecialized( expert_first_token_offset, layout_info1, layout_info2, num_tokens, expanded_num_tokens, @@ -691,7 +694,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { fp4_act_flat1, fp4_act_flat2, quant_params, reinterpret_cast(bias1), reinterpret_cast(bias2), reinterpret_cast(gemm1_output), - reinterpret_cast(gemm2_output), stream); + reinterpret_cast(gemm2_output), enable_pdl, stream); } std::pair @@ -705,7 +708,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc2_fp4_act_flat, QuantParams quant_params, void const* bias1, void const* bias2, void* output1, void* output2, int const* num_active_experts_per, int const* active_expert_global_ids, int start_expert, - cudaStream_t stream) override { + bool enable_pdl, cudaStream_t stream) override { return Self::computeStridesTmaWarpSpecializedLowLatency( layout_info1, layout_info2, num_tokens, gemm1_n, gemm1_k, gemm2_n, gemm2_k, num_experts, reinterpret_cast(input1), reinterpret_cast(input2), @@ -715,7 +718,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { reinterpret_cast(bias2), reinterpret_cast(output1), reinterpret_cast(output2), num_active_experts_per, - active_expert_global_ids, start_expert, stream); + active_expert_global_ids, start_expert, enable_pdl, stream); } private: @@ -731,7 +734,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { ScaleBiasType const* fc2_expert_biases, bool min_latency_mode, MoeMinLatencyParams& min_latency_params, bool use_lora, int start_expert, MOEParallelismConfig parallelism_config, - cudaStream_t stream); + bool enable_pdl, cudaStream_t stream); static std::pair computeStridesTmaWarpSpecialized( @@ -744,7 +747,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat1, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fp4_act_flat2, QuantParams quant_params, ScaleBiasType const* bias1, ScaleBiasType const* bias2, UnfusedGemmOutputType* gemm1_output, - UnfusedGemmOutputType* gemm2_output, cudaStream_t stream); + UnfusedGemmOutputType* gemm2_output, bool enable_pdl, cudaStream_t stream); static std::pair computeStridesTmaWarpSpecializedLowLatency( TmaWarpSpecializedGroupedGemmInput layout_info1, @@ -757,7 +760,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { QuantParams quant_params, ScaleBiasType const* bias1, ScaleBiasType 
const* bias2, UnfusedGemmOutputType* output1, UnfusedGemmOutputType* output2, int const* num_active_experts_per, int const* active_expert_global_ids, int start_expert, - cudaStream_t stream); + bool enable_pdl, cudaStream_t stream); std::map> getWorkspaceDeviceBufferSizes( int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node, int const experts_per_token, ActivationType activation_type, @@ -805,13 +808,16 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { DeepSeekBlockScaleGemmRunner* getDeepSeekBlockScaleGemmRunner() const; - static void BlockScaleFC1( - DeepSeekBlockScaleGemmRunner& gemm_runner, T const* const input, T* const output, - void* const intermediate_result, int64_t const* const expert_first_token_offset, - WeightType const* const fc1_expert_weights, ScaleBiasType const* const fc1_expert_biases, - float const* const fc2_fp8_quant, int64_t const num_rows, int64_t const expanded_num_rows, - int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node, - ActivationType fc1_activation_type, QuantParams& quant_params, cudaStream_t stream); + static void BlockScaleFC1(DeepSeekBlockScaleGemmRunner& gemm_runner, T const* const input, + T* const output, void* const intermediate_result, + int64_t const* const expert_first_token_offset, + WeightType const* const fc1_expert_weights, + ScaleBiasType const* const fc1_expert_biases, + float const* const fc2_fp8_quant, int64_t const num_rows, + int64_t const expanded_num_rows, int64_t const hidden_size, + int64_t const inter_size, int const num_experts_per_node, + ActivationType fc1_activation_type, QuantParams& quant_params, + bool enable_pdl, cudaStream_t stream); static void BlockScaleFC2( DeepSeekBlockScaleGemmRunner& gemm_runner, T const* const input, void* const gemm_output, @@ -823,7 +829,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface { int64_t const* const num_valid_tokens_ptr, int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size, int64_t const num_experts_per_node, int64_t const k, MOEParallelismConfig parallelism_config, - bool const enable_alltoall, QuantParams& quant_params, cudaStream_t stream); + bool const enable_alltoall, QuantParams& quant_params, bool enable_pdl, cudaStream_t stream); T const* applyPrequantScale(void* smoothed_act, void const* permuted_data, void const* prequant_scales, int64_t const* num_valid_tokens_ptr, @@ -919,13 +925,14 @@ struct GemmProfilerBackend { } } - void prepare(int num_tokens, char* workspace, void const* expert_weights, cudaStream_t stream); + void prepare(int num_tokens, char* workspace, void const* expert_weights, bool enable_pdl, + cudaStream_t stream); std::map> getProfilerWorkspaces(int maxM, bool is_tma_ws); size_t getWorkspaceSize(int maxM); void runProfiler(int num_tokens, Config const& tactic, char* workspace_ptr_char, - void const* expert_weights, cudaStream_t const& stream); + void const* expert_weights, bool enable_pdl, cudaStream_t const& stream); CutlassMoeFCRunnerInterface* mInterface; @@ -962,10 +969,10 @@ struct GemmProfilerBackend { TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType mScalingType{}; private: - void prepareRouting(int num_tokens, char* workspace, cudaStream_t stream); + void prepareRouting(int num_tokens, char* workspace, bool enable_pdl, cudaStream_t stream); void prepareQuantParams(int num_tokens, char* workspace, cudaStream_t stream); void prepareTmaWsInputs(int num_tokens, char* 
workspace, void const* expert_weights, - cudaStream_t stream); + bool enable_pdl, cudaStream_t stream); }; // Populates a buffer with random values for use with MOE benchmarking diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_util_kernels.h b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_util_kernels.h index 109076d3f..540665ae5 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_util_kernels.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_util_kernels.h @@ -47,7 +47,7 @@ bool fusedBuildExpertMapsSortFirstToken(int const* token_selected_experts, int64_t* expert_first_token_offset, int64_t const num_tokens, int const num_experts_per_node, int const experts_per_token, int const start_expert, - int const end_expert, cudaStream_t stream); + int const end_expert, bool enable_pdl, cudaStream_t stream); void threeStepBuildExpertMapsSortFirstToken( int const* token_selected_experts, int* permuted_token_selected_experts, @@ -55,20 +55,18 @@ void threeStepBuildExpertMapsSortFirstToken( int64_t* expert_first_token_offset, int* blocked_expert_counts, int* blocked_expert_counts_cumsum, int* blocked_row_to_unpermuted_row, int64_t const num_tokens, int64_t const num_experts_per_node, int64_t const num_experts_per_token, - int const start_expert_id, cudaStream_t stream); + int const start_expert_id, bool enable_pdl, cudaStream_t stream); template -void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input, - ExpandedActivationsType* permuted_output, - float const* unpermuted_scales, float* permuted_scales, - int const* permuted_row_to_unpermuted_row, - int64_t const num_rows, int64_t const hidden_size, int const k, - int const num_experts_per_node, QuantParams const& quant_params, - bool use_per_expert_act_scale, - int64_t* expert_first_token_offset, - TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, - void const* prequant_scales, cudaStream_t stream); +void expandInputRowsKernelLauncher( + InputActivationsType const* unpermuted_input, ExpandedActivationsType* permuted_output, + float const* unpermuted_scales, float* permuted_scales, + int const* permuted_row_to_unpermuted_row, int64_t const num_rows, int64_t const hidden_size, + int const k, int const num_experts_per_node, QuantParams const& quant_params, + bool use_per_expert_act_scale, int64_t* expert_first_token_offset, + TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat, + TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, void const* prequant_scales, + bool enable_pdl, cudaStream_t stream); template void finalizeMoeRoutingKernelLauncher( @@ -77,7 +75,8 @@ void finalizeMoeRoutingKernelLauncher( int const* permuted_row_to_unpermuted_row, int const* token_selected_experts, int64_t const* expert_first_token_offset, int64_t const num_rows, int64_t const cols, int64_t const experts_per_token, int64_t const num_experts_per_node, - MOEParallelismConfig parallelism_config, bool const enable_alltoall, cudaStream_t stream); + MOEParallelismConfig parallelism_config, bool const enable_alltoall, bool enable_pdl, + cudaStream_t stream); } // namespace cutlass_kernels } // namespace tensorrt_llm::kernels diff --git a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl index 
5565aec25..a3e4a8739 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl +++ b/csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_launcher.inl @@ -610,7 +610,7 @@ using SafeBF16 = void; TLLM_CHECK_WITH_INFO(init_status == cutlass::Status::kSuccess, \ "Failed to initialize cutlass TMA WS grouped gemm. Error: " + \ std::string(cutlass::cutlassGetStatusString(init_status))); \ - auto run_status = gemm.run(stream, nullptr, tensorrt_llm::common::getEnvEnablePDL()); \ + auto run_status = gemm.run(stream, nullptr, tma_ws_input.enable_pdl); \ TLLM_CHECK_WITH_INFO(run_status == cutlass::Status::kSuccess, \ "Failed to run cutlass TMA WS grouped gemm. Error: " + \ std::string(cutlass::cutlassGetStatusString(run_status))); \ diff --git a/csrc/nv_internal/tensorrt_llm/kernels/quantization.h b/csrc/nv_internal/tensorrt_llm/kernels/quantization.h index 3f48fa891..881b88a03 100644 --- a/csrc/nv_internal/tensorrt_llm/kernels/quantization.h +++ b/csrc/nv_internal/tensorrt_llm/kernels/quantization.h @@ -76,14 +76,15 @@ void invokePerTokenQuantization(QuantT* dst, T const* src, int64_t const numRows template void invokeFP4Quantization(int m, int n, T const* input, float const* globalScale, int64_t* output, int32_t* SFOuput, bool useUE8M0, FP4QuantizationSFLayout layout, - int multiProcessorCount, cudaStream_t stream = 0); + int multiProcessorCount, bool enable_pdl = false, + cudaStream_t stream = 0); template void invokeBatchedFP4Quantization( int b, int m, int n, T const* input, float const* globalScale, int64_t* output, int32_t* SFOuput, bool useUE8M0, int multiProcessorCount, FP4QuantizationSFLayout layout = FP4QuantizationSFLayout::SWIZZLED_128x4, - cudaStream_t stream = 0); + bool enable_pdl = false, cudaStream_t stream = 0); void invokeNVFP4BlockScaleInterleave(int b, int m, int m_padded, int n, int n_padded, uint8_t const* SFIn, uint8_t* SFOutput, @@ -96,7 +97,8 @@ void invokeNVFP4BlockScaleInterleaveReverse(int b, int m, int n, uint8_t const* template void invokeMxFP8Quantization(int b, int m, int n, int padded_n, T const* input, int64_t* output, int32_t* SFOuput, FP4QuantizationSFLayout layout, - int multiProcessorCount, cudaStream_t stream = 0); + int multiProcessorCount, bool enable_pdl = false, + cudaStream_t stream = 0); } // namespace kernels } // namespace tensorrt_llm diff --git a/csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp b/csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp index c231e2347..55e53d2e2 100644 --- a/csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp +++ b/csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp @@ -41,7 +41,7 @@ namespace torch_ext { std::tuple fp4_quantize(at::Tensor const& self, at::Tensor const& globalScale, int64_t sfVecSize, bool sfUseUE8M0, bool isSfSwizzledLayout, - bool isSf8x4Layout) { + bool isSf8x4Layout, bool enable_pdl) { CHECK_TH_CUDA(self); CHECK_CONTIGUOUS(self); CHECK_INPUT_TYPE(globalScale, c10::ScalarType::Float); @@ -88,7 +88,7 @@ std::tuple fp4_quantize(at::Tensor const& self, m, k, reinterpret_cast(self.data_ptr()), globalScale.data_ptr(), \ reinterpret_cast(valueE2M1.data_ptr()), \ reinterpret_cast(scaleFP8SF.data_ptr()), sfUseUE8M0, layout, mMultiProcessorCount, \ - at::cuda::getCurrentCUDAStream(self.get_device())); + enable_pdl, at::cuda::getCurrentCUDAStream(self.get_device())); if (sfUseUE8M0) { if (self.scalar_type() == at::ScalarType::Half) { diff --git a/csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.h 
b/csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.h index a264f5341..b664219dc 100644 --- a/csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.h +++ b/csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.h @@ -25,5 +25,5 @@ namespace torch_ext { std::tuple fp4_quantize(at::Tensor const& self, at::Tensor const& globalScale, int64_t sfVecSize, bool sfUseUE8M0, bool isSfSwizzledLayout, - bool isSf8x4Layout); + bool isSf8x4Layout, bool enable_pdl); } // namespace torch_ext diff --git a/csrc/nv_internal/tensorrt_llm/thop/fp8Quantize.cpp b/csrc/nv_internal/tensorrt_llm/thop/fp8Quantize.cpp index 53ea03b17..0d770d9dc 100644 --- a/csrc/nv_internal/tensorrt_llm/thop/fp8Quantize.cpp +++ b/csrc/nv_internal/tensorrt_llm/thop/fp8Quantize.cpp @@ -32,7 +32,7 @@ namespace torch_ext { // alignment: sfVecSize // returns std::tuple mxfp8_quantize(at::Tensor input, bool isSfSwizzledLayout, - int64_t alignment) { + int64_t alignment, bool enable_pdl) { CHECK_TH_CUDA(input); CHECK_CONTIGUOUS(input); @@ -68,11 +68,11 @@ std::tuple mxfp8_quantize(at::Tensor input, bool isSfSwi auto const layout = isSfSwizzledLayout ? tensorrt_llm::FP4QuantizationSFLayout::SWIZZLED_128x4 : tensorrt_llm::FP4QuantizationSFLayout::LINEAR; -#define LAUNCH_MXFP8_QUANTIZE_KERNEL(T) \ - tensorrt_llm::kernels::invokeMxFP8Quantization( \ - 1, m, k, padded_k, reinterpret_cast(input.data_ptr()), \ - reinterpret_cast(valueFP8.data_ptr()), \ - reinterpret_cast(scaleFP8SF.data_ptr()), layout, mMultiProcessorCount, \ +#define LAUNCH_MXFP8_QUANTIZE_KERNEL(T) \ + tensorrt_llm::kernels::invokeMxFP8Quantization( \ + 1, m, k, padded_k, reinterpret_cast(input.data_ptr()), \ + reinterpret_cast(valueFP8.data_ptr()), \ + reinterpret_cast(scaleFP8SF.data_ptr()), layout, mMultiProcessorCount, enable_pdl, \ at::cuda::getCurrentCUDAStream(input.get_device())); if (input.scalar_type() == at::ScalarType::Half) { diff --git a/csrc/nv_internal/tensorrt_llm/thop/fp8Quantize.h b/csrc/nv_internal/tensorrt_llm/thop/fp8Quantize.h index e0f0c42c5..96447720e 100644 --- a/csrc/nv_internal/tensorrt_llm/thop/fp8Quantize.h +++ b/csrc/nv_internal/tensorrt_llm/thop/fp8Quantize.h @@ -64,9 +64,8 @@ inline int computeSFIndex(int rowIdx, int colIdx, int totalRow, int totalColumn, // linear layout. See FP4QuantizationSFLayout enum for more details about the two layouts. // alignment: sfVecSize // returns fp8_quantized and block_scale_factors. 
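
The quantization entry points above now take an explicit enable_pdl argument instead of reading an environment variable inside the launcher. A minimal launch-side sketch of where such a flag typically lands, assuming a hypothetical copyKernel/launchWithPdl pair rather than the launchers touched by this patch:

// Illustrative sketch (not part of this patch): plumbing a boolean enable_pdl flag
// into a kernel launch. copyKernel and launchWithPdl are hypothetical placeholders.
#include <cuda_runtime.h>

__global__ void copyKernel(float const* in, float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = in[i];
}

cudaError_t launchWithPdl(float const* in, float* out, int n, bool enable_pdl,
                          cudaStream_t stream) {
  cudaLaunchConfig_t config = {};
  config.gridDim = dim3((n + 255) / 256);
  config.blockDim = dim3(256);
  config.stream = stream;

  cudaLaunchAttribute attrs[1];
  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  // The caller now decides whether PDL is used; no environment-variable lookup here.
  attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl ? 1 : 0;
  config.attrs = attrs;
  config.numAttrs = 1;

  return cudaLaunchKernelEx(&config, copyKernel, in, out, n);
}
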
-std::tuple mxfp8_quantize(at::Tensor input, - bool is_sf_swizzled_layout = true, - int64_t alignment = 32); +std::tuple mxfp8_quantize(at::Tensor input, bool is_sf_swizzled_layout, + int64_t alignment, bool enable_pdl); // x_fp32: [M, K], fp32_quantized (on the host) // isSfSwizzledLayout: bool, if true, the scale factors are stored in swizzled layout, otherwise in diff --git a/csrc/trtllm_batched_gemm_runner.cu b/csrc/trtllm_batched_gemm_runner.cu index f8f32c36a..1a31ea545 100644 --- a/csrc/trtllm_batched_gemm_runner.cu +++ b/csrc/trtllm_batched_gemm_runner.cu @@ -158,7 +158,7 @@ void TrtllmGenBatchedGemmRunner::run( float const* ptrClampLimit, void* c, void* outSfC, int32_t const* routeMap, int32_t const* totalNumPaddedTokens, int32_t const* ctaIdxXyToBatchIdx, int32_t const* ctaIdxXyToMnLimit, int32_t const* numNonExitingCtas, void* workspace, - CUstream stream, int device, int32_t configIndex) { + CUstream stream, int device, int32_t configIndex, bool enable_pdl) { auto bmm = BatchedGemmInterface(); BatchedGemmData gemmData; @@ -237,9 +237,8 @@ void TrtllmGenBatchedGemmRunner::run( // FIXME once we start using all-reduce in the epilogue of the bmm this can be moved elsewhere bmm.runInitBeforeWorldSync(config, gemmData, static_cast(stream)); - auto const err = - bmm.run(config, workspace, gemmData, static_cast(stream), multiProcessorCount, - tensorrt_llm::common::getEnvEnablePDL(), globalTrtllmGenBatchedGemmModuleCache); + auto const err = bmm.run(config, workspace, gemmData, static_cast(stream), + multiProcessorCount, enable_pdl, globalTrtllmGenBatchedGemmModuleCache); TORCH_CHECK(err == 0, "Error occurred when running GEMM!" @@ -251,7 +250,7 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vector const& batchedTokens, void const* a, void const* sfA, void const* b, void const* sfB, void* c, void* outSfC, void* workspace, CUstream stream, int device, - int32_t configIndex) { + int32_t configIndex, bool enable_pdl) { // Dispatch with block scaling factors and with static batching. run(m, n, k, batchedTokens, /* numTokens */ 0, batchedTokens.size(), /* maxNumCtasInBatchDim */ 0, a, sfA, b, sfB, @@ -260,7 +259,7 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, /* ptrBeta */ nullptr, /* ptrClampLimit */ nullptr, c, outSfC, /* routeMap */ nullptr, /* totalNumPaddedTokens */ nullptr, /* ctaIdxXyToBatchIdx */ nullptr, /* ctaIdxXyToMnLimit */ nullptr, - /* numNonExitingCtas */ nullptr, workspace, stream, device, configIndex); + /* numNonExitingCtas */ nullptr, workspace, stream, device, configIndex, enable_pdl); } void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, @@ -269,7 +268,7 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, float const* ptrBias, float const* ptrAlpha, float const* ptrBeta, float const* ptrClampLimit, void* c, void* outSfC, void* workspace, CUstream stream, int device, - int32_t configIndex) { + int32_t configIndex, bool enable_pdl) { // Dispatch with block scaling factors and with static batching. 
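
Because getEnvEnablePDL() is deleted later in this patch, callers of TrtllmGenBatchedGemmRunner::run (and of the other launchers above) must now choose the flag themselves. A hedged sketch of deriving a default from the device's compute capability, mirroring the Python-side device_support_pdl helper; the name deviceSupportsPdl is hypothetical:

// Illustrative sketch (not part of this patch): compute a default for enable_pdl
// from the device's compute capability. deviceSupportsPdl is a hypothetical helper.
#include <cuda_runtime.h>

bool deviceSupportsPdl(int device) {
  int major = 0;
  cudaError_t err = cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device);
  // PDL is an SM90-and-newer feature; fall back to disabled on error.
  return err == cudaSuccess && major >= 9;
}

// A caller would then pass the result through, e.g.:
//   bool enable_pdl = deviceSupportsPdl(device);
//   runner.run(m, n, k, batchedTokens, a, b, scaleC, scaleGateC, c,
//              workspace, stream, device, configIndex, enable_pdl);
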
run(m, n, k, batchedTokens, /* numTokens */ 0, batchedTokens.size(), /* maxNumCtasInBatchDim */ 0, a, sfA, b, sfB, @@ -278,14 +277,14 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, outSfC, /* routeMap */ nullptr, /* totalNumPaddedTokens */ nullptr, /* ctaIdxXyToBatchIdx */ nullptr, /* ctaIdxXyToMnLimit */ nullptr, - /* numNonExitingCtas */ nullptr, workspace, stream, device, configIndex); + /* numNonExitingCtas */ nullptr, workspace, stream, device, configIndex, enable_pdl); } void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, std::vector const& batchedTokens, void const* a, void const* b, float const* scaleC, float const* scaleGateC, void* c, void* workspace, CUstream stream, int device, - int32_t configIndex) { + int32_t configIndex, bool enable_pdl) { // Dispatch with block scaling factors and with static batching. run(m, n, k, batchedTokens, /* numTokens */ 0, batchedTokens.size(), /* maxNumCtasInBatchDim */ 0, a, @@ -295,7 +294,7 @@ void TrtllmGenBatchedGemmRunner::run(int32_t m, int32_t n, int32_t k, /* outSfC */ nullptr, /* routeMap */ nullptr, /* totalNumPaddedTokens */ nullptr, /* ctaIdxXyToBatchIdx */ nullptr, /* ctaIdxXyToMnLimit */ nullptr, - /* numNonExitingCtas */ nullptr, workspace, stream, device, configIndex); + /* numNonExitingCtas */ nullptr, workspace, stream, device, configIndex, enable_pdl); } std::vector TrtllmGenBatchedGemmRunner::getValidConfigIndices( diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index b866c464a..baf727def 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -50,7 +50,7 @@ at::Tensor trtllm_fp8_per_tensor_scale_moe_launcher( int64_t const intermediate_size, int64_t const local_expert_offset, int64_t const local_num_experts, double const routed_scaling_factor, bool const use_routing_scales_on_input, int64_t const tile_tokens_dim, - int64_t const routing_method_type) { + int64_t const routing_method_type, bool enable_pdl) { auto device = hidden_states.device(); static const std::tuple device_props = [&device] { @@ -282,7 +282,8 @@ at::Tensor trtllm_fp8_per_tensor_scale_moe_launcher( workspace.bmm1_workspace = workspace_fc1.data_ptr(); workspace.bmm2_workspace = workspace_fc2.data_ptr(); auto const& moe_stream = at::cuda::getCurrentCUDAStream(hidden_states.get_device()); - moe_runner.run(args, workspace, hidden_states.get_device(), moe_stream, moeConfigIndex); + moe_runner.run(args, workspace, hidden_states.get_device(), moe_stream, moeConfigIndex, + enable_pdl); return output; } @@ -293,7 +294,7 @@ at::Tensor trtllm_fp8_per_tensor_scale_moe( at::Tensor output2_scales_scalar, int64_t num_experts, int64_t top_k, int64_t n_group, int64_t topk_group, int64_t intermediate_size, int64_t local_expert_offset, int64_t local_num_experts, double routed_scaling_factor, bool use_routing_scales_on_input, - int64_t tile_tokens_dim, int64_t routing_method_type) { + int64_t tile_tokens_dim, int64_t routing_method_type, bool enable_pdl) { auto dtype = hidden_states.dtype(); if (dtype == at::ScalarType::Half || dtype == at::ScalarType::BFloat16 || dtype == at::ScalarType::Float8_e4m3fn) { @@ -301,7 +302,8 @@ at::Tensor trtllm_fp8_per_tensor_scale_moe( routing_logits, routing_bias, hidden_states, gemm1_weights, output1_scales_scalar, output1_scales_gate_scalar, gemm2_weights, output2_scales_scalar, num_experts, top_k, n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts, - routed_scaling_factor, 
use_routing_scales_on_input, tile_tokens_dim, routing_method_type); + routed_scaling_factor, use_routing_scales_on_input, tile_tokens_dim, routing_method_type, + enable_pdl); } else { TORCH_CHECK(false, "Unsupported input type: ", dtype); } @@ -316,7 +318,8 @@ at::Tensor trtllm_fp8_block_scale_moe_launcher( int64_t const intermediate_size, int64_t const local_expert_offset, int64_t const local_num_experts, double const routed_scaling_factor, int64_t const tile_tokens_dim, int64_t const routing_method_type, - tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner& moe_runner, int64_t moeConfigIndex) { + tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner& moe_runner, int64_t moeConfigIndex, + bool enable_pdl) { auto device = hidden_states.device(); static const std::tuple device_props = [&device] { @@ -578,7 +581,8 @@ at::Tensor trtllm_fp8_block_scale_moe_launcher( workspace.bmm2_workspace = workspace_fc2.data_ptr(); auto const& moe_stream = at::cuda::getCurrentCUDAStream(hidden_states.get_device()); - moe_runner.run(args, workspace, hidden_states.get_device(), moe_stream, moeConfigIndex); + moe_runner.run(args, workspace, hidden_states.get_device(), moe_stream, moeConfigIndex, + enable_pdl); return output; } @@ -590,7 +594,7 @@ at::Tensor trtllm_fp8_block_scale_moe( int64_t top_k, int64_t n_group, int64_t topk_group, int64_t intermediate_size, int64_t local_expert_offset, int64_t local_num_experts, double routed_scaling_factor, int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, - int64_t weight_layout) { + int64_t weight_layout, bool enable_pdl) { auto dtype = hidden_states.dtype(); if (dtype == at::ScalarType::Half || dtype == at::ScalarType::BFloat16 || dtype == at::ScalarType::Float8_e4m3fn) { @@ -618,7 +622,8 @@ at::Tensor trtllm_fp8_block_scale_moe( routing_logits, routing_bias, hidden_states, hidden_states_scale, gemm1_weights, gemm1_weights_scale, gemm2_weights, gemm2_weights_scale, num_experts, top_k, n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts, - routed_scaling_factor, tile_tokens_dim, routing_method_type, *mRunner, moeConfigIndex); + routed_scaling_factor, tile_tokens_dim, routing_method_type, *mRunner, moeConfigIndex, + enable_pdl); } else { TORCH_CHECK(false, "Unsupported input type: ", dtype); } @@ -644,7 +649,7 @@ std::vector trtllm_fp4_block_scale_moe_launcher( std::optional const routed_scaling_factor, int64_t const tile_tokens_dim, int64_t const routing_method_type, bool const do_finalize, tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner& moe_runner, btg::Dtype dtype_act, - btg::Dtype dtype_weights, int64_t const moeConfigIndex, at::Tensor& output) { + btg::Dtype dtype_weights, int64_t const moeConfigIndex, bool enable_pdl, at::Tensor& output) { auto device = hidden_states.device(); static const std::tuple device_props = [&device] { @@ -1028,7 +1033,8 @@ std::vector trtllm_fp4_block_scale_moe_launcher( workspace.bmm1_workspace = workspace_fc1.data_ptr(); workspace.bmm2_workspace = workspace_fc2.data_ptr(); auto const& moe_stream = at::cuda::getCurrentCUDAStream(hidden_states.get_device()); - moe_runner.run(args, workspace, hidden_states.get_device(), moe_stream, moeConfigIndex); + moe_runner.run(args, workspace, hidden_states.get_device(), moe_stream, moeConfigIndex, + enable_pdl); if (!do_finalize) { return {gemm2_output, expert_weights, expanded_idx_to_permuted_idx}; @@ -1051,7 +1057,7 @@ std::vector trtllm_fp4_block_scale_moe( std::optional n_group, std::optional topk_group, int64_t intermediate_size, int64_t 
local_expert_offset, int64_t local_num_experts, std::optional routed_scaling_factor, int64_t tile_tokens_dim, - int64_t routing_method_type, bool do_finalize, at::Tensor& output) { + int64_t routing_method_type, bool do_finalize, bool enable_pdl, at::Tensor& output) { using RunnerType = tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner; int const num_tokens = hidden_states.sizes()[0]; @@ -1114,7 +1120,7 @@ std::vector trtllm_fp4_block_scale_moe( output1_scales_gate_scalar, output2_scales_scalar, num_experts, top_k, n_group, topk_group, intermediate_size, local_expert_offset, local_num_experts, routed_scaling_factor, tile_tokens_dim, routing_method_type, do_finalize, *mRunner, mDtypeAct, mDtypeWeights, - moeConfigIndex, output); + moeConfigIndex, enable_pdl, output); } namespace trtllm_cubin_loader { diff --git a/csrc/trtllm_fused_moe_runner.cu b/csrc/trtllm_fused_moe_runner.cu index fb5ae89ab..edab492d2 100644 --- a/csrc/trtllm_fused_moe_runner.cu +++ b/csrc/trtllm_fused_moe_runner.cu @@ -215,7 +215,7 @@ void Runner::run(void* hiddenState, void* hiddenStateScale, void* weights, void* int32_t* permutedIdxToTokenIdx, int32_t* ptrNumNonExitingCtas, int32_t* ptrTotalNumPaddedTokens, int32_t* ptrCtaIdxXyToBatchIdx, int32_t* ptrCtaIdxXyToMnLimit, void* bmm1Workspace, bool useRoutingScalesOnInput, - int device, cudaStream_t stream, int32_t configIndex) { + int device, cudaStream_t stream, int32_t configIndex, bool enable_pdl) { auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); mRunner.run(numTokens, 2 * intermediateSize, hiddenSize, {}, numTokens, numExperts, @@ -223,7 +223,7 @@ void Runner::run(void* hiddenState, void* hiddenStateScale, void* weights, void* expertWeights, /* perTokensSfB */ nullptr, outputScalesScalar, outputScalesGateScalar, ptrBias, ptrAlpha, ptrBeta, ptrClampLimit, output, outputScale, permutedIdxToTokenIdx, ptrTotalNumPaddedTokens, ptrCtaIdxXyToBatchIdx, ptrCtaIdxXyToMnLimit, - ptrNumNonExitingCtas, bmm1Workspace, stream, device, configIndex); + ptrNumNonExitingCtas, bmm1Workspace, stream, device, configIndex, enable_pdl); } size_t Runner::getWorkspaceSizeInBytes(int32_t topK, int32_t hiddenSize, int32_t intermediateSize, @@ -300,7 +300,7 @@ void Runner::run(void* permutedHiddenState, void* permutedHiddenStateScale, void int32_t numExperts, int32_t numTokens, int32_t* ptrNumNonExitingCtas, int32_t* ptrTotalNumPaddedTokens, int32_t* ptrCtaIdxXyToBatchIdx, int32_t* ptrCtaIdxXyToMnLimit, void* bmm2Workspace, int device, - cudaStream_t stream, int32_t configIndex) { + cudaStream_t stream, int32_t configIndex, bool enable_pdl) { auto maxNumCtasInBatchDim = Routing::getMaxNumCtasInBatchDim(numTokens, topK, numExperts, mTileTokensDim); mRunner.run( @@ -310,7 +310,8 @@ void Runner::run(void* permutedHiddenState, void* permutedHiddenStateScale, void /* perTokensSfB */ nullptr, outputScalesScalar, /* outputScalesGateScalar */ nullptr, ptrBias, /* ptrAlpha */ nullptr, /* ptrBeta */ nullptr, /* clampLimit */ nullptr, output, outputScale, /* permutedIdxToTokenIdx */ nullptr, ptrTotalNumPaddedTokens, ptrCtaIdxXyToBatchIdx, - ptrCtaIdxXyToMnLimit, ptrNumNonExitingCtas, bmm2Workspace, stream, device, configIndex); + ptrCtaIdxXyToMnLimit, ptrNumNonExitingCtas, bmm2Workspace, stream, device, configIndex, + enable_pdl); } size_t Runner::getWorkspaceSizeInBytes(int32_t topK, int32_t hiddenSize, int32_t intermediateSize, @@ -483,7 +484,7 @@ int64_t Runner::getDefaultValidConfigIndex(int32_t topK, int32_t hiddenSize, } void 
Runner::run(MoERunnerArgs const& args, MoEWorkspace const& workspace, int device, - cudaStream_t stream, int64_t configIndex) { + cudaStream_t stream, int64_t configIndex, bool enable_pdl) { // Setup all operation data moe::dev::activation::Data activationData; moe::dev::finalize::Data finalizeData; @@ -504,7 +505,7 @@ void Runner::run(MoERunnerArgs const& args, MoEWorkspace const& workspace, int d workspace.permuted_idx_to_token_idx, workspace.num_non_exiting_ctas, workspace.total_num_padded_tokens, workspace.cta_idx_xy_to_batch_idx, workspace.cta_idx_xy_to_mn_limit, workspace.bmm1_workspace, - args.mUseRoutingScalesOnInput, device, stream, config.gemm1Config); + args.mUseRoutingScalesOnInput, device, stream, config.gemm1Config, enable_pdl); // We do not fuse activation with FC1 for DeepSeek FP8 due to the weights shuffling constraint. void* gemm2_input = workspace.gemm1_output; @@ -524,7 +525,7 @@ void Runner::run(MoERunnerArgs const& args, MoEWorkspace const& workspace, int d args.local_num_experts, args.num_tokens, workspace.num_non_exiting_ctas, workspace.total_num_padded_tokens, workspace.cta_idx_xy_to_batch_idx, workspace.cta_idx_xy_to_mn_limit, workspace.bmm2_workspace, device, stream, - config.gemm2Config); + config.gemm2Config, enable_pdl); // Run finalize if (args.do_finalize) { diff --git a/flashinfer/fp4_quantization.py b/flashinfer/fp4_quantization.py index a19a8db57..0cdd5c514 100644 --- a/flashinfer/fp4_quantization.py +++ b/flashinfer/fp4_quantization.py @@ -25,6 +25,7 @@ from .jit import env as jit_env from .jit import gen_jit_spec, sm100a_nvcc_flags from .utils import ( + device_support_pdl, get_shuffle_matrix_a_row_indices, get_shuffle_matrix_sf_a_row_indices, register_custom_op, @@ -105,6 +106,7 @@ def fp4_quantize_sm100( sf_use_ue8m0: bool = False, is_sf_swizzled_layout: bool = True, is_sf_8x4_layout: bool = False, + enable_pdl: Optional[bool] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """Quantize input tensor to FP4 format. @@ -115,12 +117,16 @@ def fp4_quantize_sm100( sf_use_ue8m0 (bool, optional): Whether to use UE8M0 format for scale factors. Defaults to False. is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True. is_sf_8x4_layout (bool, optional): Whether to use 8x4 layout or 128x4 layout for scale factors. Defaults to False. + enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch). + If None, automatically detects based on device capability. Defaults to None. Returns: Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - Quantized tensor of shape [M, K/2] with dtype FLOAT4_E2M1X2 - Scale factors tensor with shape determined by layout and sf_vec_size """ + if enable_pdl is None: + enable_pdl = device_support_pdl(input.device) return module.fp4_quantize( input, global_scale, @@ -128,6 +134,7 @@ def fp4_quantize_sm100( sf_use_ue8m0, is_sf_swizzled_layout, is_sf_8x4_layout, + enable_pdl, ) @register_fake_op("flashinfer::fp4_quantize_sm100") @@ -236,6 +243,7 @@ def fp4_quantize( sf_use_ue8m0: bool = False, is_sf_swizzled_layout: bool = True, is_sf_8x4_layout: bool = False, + enable_pdl: Optional[bool] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """Quantize input tensor to FP4 format. @@ -249,6 +257,8 @@ def fp4_quantize( sf_use_ue8m0 (bool, optional): Whether to use UE8M0 format for scale factors. Defaults to False. is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True. 
is_sf_8x4_layout (bool, optional): Whether to use 8x4 layout or 128x4 layout for scale factors. Defaults to False. + enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch). + If None, automatically detects based on device capability. Defaults to None. Returns: Tuple[torch.Tensor, torch.Tensor]: A tuple containing: @@ -270,6 +280,8 @@ def fp4_quantize( input = input.transpose(-2, -1) assert input.shape[-1] % sf_vec_size == 0 + if enable_pdl is None: + enable_pdl = device_support_pdl(input.device) x_q, sf = get_fp4_quantization_sm100_module().fp4_quantize_sm100( input, global_scale, @@ -277,6 +289,7 @@ def fp4_quantize( sf_use_ue8m0, is_sf_swizzled_layout, is_sf_8x4_layout, + enable_pdl, ) sf = sf.reshape((-1, input.shape[-1] // sf_vec_size)) if is_column_major: @@ -390,7 +403,12 @@ class SfLayout(Enum): def nvfp4_quantize( - a, a_global_sf, sfLayout=SfLayout.layout_128x4, do_shuffle=False, sf_vec_size=16 + a, + a_global_sf, + sfLayout=SfLayout.layout_128x4, + do_shuffle=False, + sf_vec_size=16, + enable_pdl=None, ): """ Quantize input tensor to NVFP4 format. @@ -401,6 +419,8 @@ def nvfp4_quantize( sfLayout (SfLayout, optional): Scale factor layout. Defaults to SfLayout.layout_128x4. do_shuffle (bool, optional): Whether to shuffle the scale factors. Defaults to False. Only TRTLLM backend needs to shuffle the tensor B scale factors. sf_vec_size (int, optional): Scale factor vector size. Defaults to 16. + enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch). + If None, automatically detects based on device capability. Defaults to None. Returns: Tuple[torch.Tensor, torch.Tensor]: A tuple containing: @@ -417,6 +437,7 @@ def nvfp4_quantize( sf_use_ue8m0=False, is_sf_swizzled_layout=False, is_sf_8x4_layout=False, + enable_pdl=enable_pdl, ) epilogue_tile_m = 128 @@ -434,6 +455,7 @@ def nvfp4_quantize( sf_use_ue8m0=False, is_sf_swizzled_layout=True, is_sf_8x4_layout=sfLayout == SfLayout.layout_8x4, + enable_pdl=enable_pdl, ) return a_fp4, a_sf diff --git a/flashinfer/fp8_quantization.py b/flashinfer/fp8_quantization.py index 65cfda5b2..7aba9e3e1 100644 --- a/flashinfer/fp8_quantization.py +++ b/flashinfer/fp8_quantization.py @@ -1,13 +1,13 @@ import functools from types import SimpleNamespace -from typing import Tuple +from typing import Optional, Tuple import torch from .jit import JitSpec from .jit import env as jit_env from .jit import gen_jit_spec, sm100a_nvcc_flags -from .utils import register_custom_op, register_fake_op +from .utils import device_support_pdl, register_custom_op, register_fake_op def gen_mxfp8_quantization_sm100_module() -> JitSpec: @@ -50,6 +50,7 @@ def mxfp8_quantize_sm100( input: torch.Tensor, is_sf_swizzled_layout: bool = True, alignment: int = 32, + enable_pdl: Optional[bool] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """Quantize input tensor to MxFP8 format. @@ -57,6 +58,8 @@ def mxfp8_quantize_sm100( input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized. is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True. alignment (int, optional): sfVecSize. Defaults to 32. Note that alignment is not used in the host kernel. + enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch). + If None, automatically detects based on device capability. Defaults to None. 
Returns: Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - Quantized tensor of shape [M, K] with dtype FLOAT8_E4M3 @@ -68,10 +71,13 @@ def mxfp8_quantize_sm100( is_sf_swizzled_layout, ) else: + if enable_pdl is None: + enable_pdl = device_support_pdl(input.device) return module.mxfp8_quantize( input, is_sf_swizzled_layout, alignment, + enable_pdl, ) @register_fake_op("flashinfer::mxfp8_quantize_sm100") @@ -129,6 +135,7 @@ def mxfp8_quantize( input: torch.Tensor, is_sf_swizzled_layout: bool = True, alignment: int = 32, + enable_pdl: Optional[bool] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """Quantize input tensor to MxFP8 format. @@ -139,6 +146,8 @@ def mxfp8_quantize( input (torch.Tensor): Input tensor of shape [M, K] with dtype fp16/bf16/fp8_quantized. is_sf_swizzled_layout (bool, optional): Whether to use swizzled layout for scale factors. Defaults to True. alignment (int, optional): sfVecSize. Defaults to 32. + enable_pdl (Optional[bool], optional): Whether to enable PDL (Programmatic Dependent Launch). + If None, automatically detects based on device capability. Defaults to None. Returns: Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - Quantized tensor of shape [M, K] with dtype FLOAT8_E4M3 @@ -147,10 +156,13 @@ def mxfp8_quantize( sf_vec_size = 32 assert input.shape[-1] % sf_vec_size == 0 + if enable_pdl is None: + enable_pdl = device_support_pdl(input.device) x_q, sf = get_mxfp8_quantization_sm100_module().mxfp8_quantize_sm100( input, is_sf_swizzled_layout, alignment, + enable_pdl, ) return x_q, sf diff --git a/flashinfer/fused_moe/core.py b/flashinfer/fused_moe/core.py index 34d9914e5..2401cadf4 100644 --- a/flashinfer/fused_moe/core.py +++ b/flashinfer/fused_moe/core.py @@ -36,6 +36,7 @@ from ..jit.cutlass_gemm.generate_kernels import generate_gemm_operations from ..utils import ( _check_shape_dtype_device, + device_support_pdl, get_shuffle_matrix_a_row_indices, get_shuffle_matrix_sf_a_row_indices, register_custom_op, @@ -359,6 +360,7 @@ def __init__( use_w4a8_group_scaling: bool, use_mxfp8_act_scaling: bool, min_latency_mode: bool, + enable_pdl: bool, ): self.x_dtype = x_dtype self.weight_dtype = weight_dtype @@ -375,6 +377,7 @@ def __init__( self.use_w4a8_group_scaling = use_w4a8_group_scaling self.use_mxfp8_act_scaling = use_mxfp8_act_scaling self.min_latency_mode = min_latency_mode + self.enable_pdl = enable_pdl instance_key = ( x_dtype, weight_dtype, @@ -437,6 +440,7 @@ def forward( gemm_idx, tactic, do_preparation, + self.enable_pdl, ) @classmethod @@ -481,7 +485,10 @@ def cutlass_fused_moe_sm100( use_mxfp8_act_scaling: bool = False, min_latency_mode: bool = False, tune_max_num_tokens: int = 8192, + enable_pdl: Optional[bool] = None, ) -> List[torch.Tensor]: + if enable_pdl is None: + enable_pdl = device_support_pdl(input.device) tuner = AutoTuner.get() MoERunner.refine_tuning_config(tune_max_num_tokens) @@ -502,6 +509,7 @@ def cutlass_fused_moe_sm100( use_w4a8_group_scaling=use_w4a8_group_scaling, use_mxfp8_act_scaling=use_mxfp8_act_scaling, min_latency_mode=min_latency_mode, + enable_pdl=enable_pdl, ) _, gemm_tactic_1 = tuner.choose_one( @@ -557,6 +565,7 @@ def cutlass_fused_moe_sm100( enable_alltoall, min_latency_mode, [gemm_tactic_1, gemm_tactic_2], + enable_pdl, ) return result if min_latency_mode else [result] @@ -635,6 +644,7 @@ def cutlass_fused_moe( use_mxfp8_act_scaling: bool = False, min_latency_mode: bool = False, tune_max_num_tokens: int = 8192, + enable_pdl: Optional[bool] = None, ) -> torch.Tensor: """Compute a Mixture of Experts (MoE) 
layer using CUTLASS backend. @@ -763,6 +773,8 @@ def cutlass_fused_moe( raise NotImplementedError("min latency mode not yet implemented for Blackwell.") if use_mxfp8_act_scaling: raise NotImplementedError("mxfp8 not yet implemented for Blackwell.") + if enable_pdl is None: + enable_pdl = device_support_pdl(input.device) num_rows = input.shape[0] if min_latency_mode: @@ -801,6 +813,7 @@ def cutlass_fused_moe( use_mxfp8_act_scaling=use_mxfp8_act_scaling, min_latency_mode=min_latency_mode, tune_max_num_tokens=tune_max_num_tokens, + enable_pdl=enable_pdl, ) @@ -882,7 +895,10 @@ def trtllm_fp8_per_tensor_scale_moe_op( use_routing_scales_on_input: bool, tile_tokens_dim: int = 8, routing_method_type: int = 0, + enable_pdl: Optional[bool] = None, ) -> torch.Tensor: + if enable_pdl is None: + enable_pdl = device_support_pdl(hidden_states.device) # Call the C++ function output = moe_op.trtllm_fp8_per_tensor_scale_moe( routing_logits, @@ -904,6 +920,7 @@ def trtllm_fp8_per_tensor_scale_moe_op( use_routing_scales_on_input, tile_tokens_dim, routing_method_type, + enable_pdl, ) return output @@ -959,7 +976,10 @@ def trtllm_fp8_block_scale_moe_op( routing_method_type: int, use_shuffled_weight: bool = False, weight_layout: int = 0, + enable_pdl: Optional[bool] = None, ) -> torch.Tensor: + if enable_pdl is None: + enable_pdl = device_support_pdl(hidden_states.device) # Call the C++ function for block scale MoE output = moe_op.trtllm_fp8_block_scale_moe( routing_logits, @@ -982,6 +1002,7 @@ def trtllm_fp8_block_scale_moe_op( routing_method_type, use_shuffled_weight, weight_layout, + enable_pdl, ) return output @@ -1048,6 +1069,7 @@ def trtllm_fp4_block_scale_moe_op( tile_tokens_dim: int, routing_method_type: int, do_finalize: bool, + enable_pdl: Optional[bool] = None, output: Optional[torch.Tensor] = None, ) -> List[torch.Tensor]: if routing_logits is None: @@ -1072,6 +1094,8 @@ def trtllm_fp4_block_scale_moe_op( expert_weights = torch.empty( num_tokens, top_k, dtype=routing_dtype, device=hidden_states.device ) + if enable_pdl is None: + enable_pdl = device_support_pdl(hidden_states.device) if output is None: output = torch.empty( num_tokens, @@ -1111,6 +1135,7 @@ def trtllm_fp4_block_scale_moe_op( tile_tokens_dim, routing_method_type, do_finalize, + enable_pdl, output, ) @@ -1180,6 +1205,7 @@ def trtllm_fp8_per_tensor_scale_moe( use_routing_scales_on_input: bool, tile_tokens_dim: int = 8, routing_method_type: int = 0, + enable_pdl: Optional[bool] = None, ) -> torch.Tensor: """FP8 per tensor scale MoE operation. @@ -1227,6 +1253,7 @@ def trtllm_fp8_per_tensor_scale_moe( use_routing_scales_on_input, tile_tokens_dim, routing_method_type, + enable_pdl, ) @@ -1251,6 +1278,7 @@ def trtllm_fp8_block_scale_moe( routing_method_type: int = 0, use_shuffled_weight: bool = False, weight_layout: int = 0, + enable_pdl: Optional[bool] = None, ) -> torch.Tensor: """FP8 block scale MoE operation. @@ -1298,6 +1326,7 @@ def trtllm_fp8_block_scale_moe( routing_method_type, use_shuffled_weight, weight_layout, + enable_pdl, ) @@ -1441,6 +1470,7 @@ def trtllm_fp4_block_scale_routed_moe( tile_tokens_dim: int = 8, routing_method_type: int = 0, do_finalize: bool = True, + enable_pdl: Optional[bool] = None, output: Optional[torch.Tensor] = None, ) -> List[torch.Tensor]: """FP4 block scale MoE operation. 
@@ -1523,5 +1553,6 @@ def trtllm_fp4_block_scale_routed_moe( tile_tokens_dim, routing_method_type, do_finalize, + enable_pdl, output, ) diff --git a/include/flashinfer/trtllm/batched_gemm/KernelRunner.h b/include/flashinfer/trtllm/batched_gemm/KernelRunner.h index 4c7d296b9..d8000a8f4 100644 --- a/include/flashinfer/trtllm/batched_gemm/KernelRunner.h +++ b/include/flashinfer/trtllm/batched_gemm/KernelRunner.h @@ -78,22 +78,23 @@ class TrtllmGenBatchedGemmRunner { float const* clampLimit, void* c, void* outSfC, int32_t const* routeMap, int32_t const* totalNumPaddedTokens, int32_t const* ctaIdxXyToBatchIdx, int32_t const* ctaIdxXyToMnLimit, int32_t const* numNonExitingCtas, void* workspace, - CUstream stream, int device, int32_t configIndex); + CUstream stream, int device, int32_t configIndex, bool enable_pdl); // NVFP4 per-block scaling GEMM void run(int32_t m, int32_t n, int32_t k, std::vector const& batchedTokens, void const* a, void const* sfA, void const* b, void const* sfB, void* c, void* outSfC, - void* workspace, CUstream stream, int device, int32_t configIndex); + void* workspace, CUstream stream, int device, int32_t configIndex, bool enable_pdl); void run(int32_t m, int32_t n, int32_t k, std::vector const& batchedTokens, void const* a, void const* sfA, void const* b, void const* sfB, float const* bias, float const* swiGluAlpha, float const* swiGluBeta, float const* clampLimit, void* c, - void* outSfC, void* workspace, CUstream stream, int device, int32_t configIndex); + void* outSfC, void* workspace, CUstream stream, int device, int32_t configIndex, + bool enable_pdl); // FP8 per-tensor scaling GEMM void run(int32_t m, int32_t n, int32_t k, std::vector const& batchedTokens, void const* a, void const* b, float const* scaleC, float const* scaleGateC, void* c, - void* workspace, CUstream stream, int device, int32_t configIndex); + void* workspace, CUstream stream, int device, int32_t configIndex, bool enable_pdl); // Get the list of configs that passed the validation based on the constructor options [[nodiscard]] std::vector getPassingConfigIndices() const { diff --git a/include/flashinfer/trtllm/common.h b/include/flashinfer/trtllm/common.h index 63abd8391..da69d1014 100644 --- a/include/flashinfer/trtllm/common.h +++ b/include/flashinfer/trtllm/common.h @@ -196,19 +196,6 @@ inline bool getEnvUseTileSizeKv64ForTrtllmGen() { static bool const useTileSizeKv64 = getBoolEnv("TRTLLM_GEN_ENABLE_TILE_SIZE_KV64"); return useTileSizeKv64; } - -inline bool getEnvEnablePDL() { - static std::once_flag flag; - static bool enablePDL = false; - - std::call_once(flag, [&]() { - if (getSMVersion() >= 90) { - // PDL will be enabled by setting the env variables `TRTLLM_ENABLE_PDL` to `1` - enablePDL = getBoolEnv("TRTLLM_ENABLE_PDL"); - } - }); - return enablePDL; -} template inline __device__ __host__ T divUp(T m, T n) { return (m + n - 1) / n; diff --git a/include/flashinfer/trtllm/fmha/fmhaKernels.cuh b/include/flashinfer/trtllm/fmha/fmhaKernels.cuh index c0e34f6b4..4cfd65880 100644 --- a/include/flashinfer/trtllm/fmha/fmhaKernels.cuh +++ b/include/flashinfer/trtllm/fmha/fmhaKernels.cuh @@ -214,7 +214,7 @@ class TllmGenFmhaKernel { clusterDimX > 1 ? 
          CU_CLUSTER_SCHEDULING_POLICY_SPREAD : CU_CLUSTER_SCHEDULING_POLICY_DEFAULT;
     launch_attribute[2].id = CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION;
-    launch_attribute[2].value.programmaticStreamSerializationAllowed = getEnvEnablePDL();
+    launch_attribute[2].value.programmaticStreamSerializationAllowed = params.enable_pdl;
     launch_config.attrs = launch_attribute;
     launch_config.numAttrs = 3;
diff --git a/include/flashinfer/trtllm/fmha/fmhaRunnerParams.h b/include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
index 738b2a3f7..ecbeb48f1 100644
--- a/include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
+++ b/include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
@@ -275,6 +275,8 @@ struct TllmGenFmhaRunnerParams {
   float mScaleSfO;
   // The cuda stream.
   cudaStream_t stream;
+  // Whether to enable PDL (Programmatic Dependent Launch).
+  bool enable_pdl;

   // set the attention mask type
   TllmGenFmhaRunnerParams& setAttentionMaskType(std::int8_t maskType) {
diff --git a/include/flashinfer/trtllm/fused_moe/runner.h b/include/flashinfer/trtllm/fused_moe/runner.h
index 972b8d0bb..4c90b7a07 100644
--- a/include/flashinfer/trtllm/fused_moe/runner.h
+++ b/include/flashinfer/trtllm/fused_moe/runner.h
@@ -146,7 +146,7 @@ class Runner {
            int32_t* permutedIdxToTokenIdx, int32_t* ptrNumNonExitingCtas,
            int32_t* ptrTotalNumPaddedTokens, int32_t* ptrCtaIdxXyToBatchIdx,
            int32_t* ptrCtaIdxXyToMnLimit, void* bmm1Workspace, bool useRoutingScalesOnInput,
-           int device, cudaStream_t stream, int32_t configIndex);
+           int device, cudaStream_t stream, int32_t configIndex, bool enable_pdl);

 private:
  batchedGemm::trtllm::gen::Dtype mDtypeAct;
@@ -184,7 +184,7 @@ class Runner {
            int32_t numExperts, int32_t numTokens, int32_t* ptrNumNonExitingCtas,
            int32_t* ptrTotalNumPaddedTokens, int32_t* ptrCtaIdxXyToBatchIdx,
            int32_t* ptrCtaIdxXyToMnLimit, void* bmm2Workspace, int device, cudaStream_t stream,
-           int32_t configIndex);
+           int32_t configIndex, bool enable_pdl);

 private:
  batchedGemm::trtllm::gen::Dtype mDtypeAct;
@@ -319,7 +319,7 @@ class Runner {
       batchedGemm::gemm::MatrixLayout weight_layout = batchedGemm::gemm::MatrixLayout::MajorK);

  void run(MoERunnerArgs const& args, MoEWorkspace const& workspace, int device,
-          cudaStream_t stream, int64_t configIndex);
+          cudaStream_t stream, int64_t configIndex, bool enable_pdl);

  [[nodiscard]] std::tuple getWorkspaceSizeInBytes(MoERunnerArgs const& args,
                                                   int64_t configIndex) const;
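
For context on what the flag actually enables: on SM90+ devices, a launch with the programmatic-stream-serialization attribute lets a dependent kernel begin before its producer fully finishes. A generic device-side sketch, not code from any file in this patch:

// Illustrative sketch (not part of this patch): the device-side intrinsics that
// Programmatic Dependent Launch relies on. Kernel names are hypothetical.
// Requires compilation for compute capability 9.0+ (e.g. -arch=sm_90).
#include <cuda_runtime.h>

__global__ void producerKernel(float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = static_cast<float>(i);
  // Signal that the dependent grid may be launched; any remaining work here
  // can overlap with the consumer's prologue.
  cudaTriggerProgrammaticLaunchCompletion();
}

__global__ void consumerKernel(float const* in, float* out, int n) {
  // Independent prologue work could run here before the producer finishes.
  // Block until the producer grid has completed so its writes are visible.
  cudaGridDependencySynchronize();
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = in[i] * 2.0f;
}
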