Skip to content

Commit 3628a54

Browse files
yongww and wcyx-6 authored
Remove getEnvEnablePDL in favor of enable_pdl parameter (#1446)
<!-- .github/pull_request_template.md --> ## 📌 Description <!-- What does this PR do? Briefly describe the changes and why they’re needed. --> ## 🔍 Related Issues <!-- Link any related issues here --> ## 🚀 Pull Request Checklist Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete. ### ✅ Pre-commit Checks - [ ] I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method). - [ ] I have installed the hooks with `pre-commit install`. - [ ] I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues. > If you are unsure about how to set up `pre-commit`, see [the pre-commit documentation](https://pre-commit.com/). ## 🧪 Tests - [ ] Tests have been added or updated as needed. - [ ] All tests are passing (`unittest`, etc.). ## Reviewer Notes <!-- Optional: anything you'd like reviewers to focus on, concerns, etc. --> --------- Co-authored-by: Yaxing Cai <[email protected]>
1 parent 0305341 commit 3628a54

25 files changed

+356
-323
lines changed

csrc/fused_moe/cutlass_backend/cutlass_fused_moe_kernels.cuh

Lines changed: 127 additions & 160 deletions
Large diffs are not rendered by default.

csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_ops.cu

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ class FusedMoeRunner : public torch::CustomClassHolder {
204204
torch::optional<at::Tensor> const& input_sf, int64_t const tp_size, int64_t const tp_rank,
205205
int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size,
206206
int64_t const cluster_rank, bool const enable_alltoall, bool min_latency_mode,
207-
torch::optional<c10::ArrayRef<int64_t>> const& profile_ids) {
207+
torch::optional<c10::ArrayRef<int64_t>> const& profile_ids, bool enable_pdl) {
208208
std::lock_guard<std::mutex> lock(mMutex);
209209
// Free the profile workspace to save memory
210210
freeProfileWorkspace();
@@ -315,7 +315,7 @@ class FusedMoeRunner : public torch::CustomClassHolder {
315315
static_cast<char*>(workspace_info.workspace.data_ptr()), output.data_ptr(),
316316
static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, enable_alltoall,
317317
false, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params,
318-
stream);
318+
enable_pdl, stream);
319319
#else
320320
mKernelRunner->runMoe(
321321
input.const_data_ptr(), input_sf.has_value() ? input_sf.value().const_data_ptr() : nullptr,
@@ -331,7 +331,7 @@ class FusedMoeRunner : public torch::CustomClassHolder {
331331
static_cast<int>(experts_per_token), static_cast<char*>(workspace_info.workspace),
332332
output.data_ptr(), static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config,
333333
false, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params,
334-
stream);
334+
enable_pdl, stream);
335335
#endif
336336

337337
return output;
@@ -346,7 +346,7 @@ class FusedMoeRunner : public torch::CustomClassHolder {
346346
torch::optional<at::Tensor> const& input_sf, int64_t const tp_size, int64_t const tp_rank,
347347
int64_t const ep_size, int64_t const ep_rank, int64_t const cluster_size,
348348
int64_t const cluster_rank, bool const enable_alltoall, bool min_latency_mode,
349-
torch::optional<c10::ArrayRef<int64_t>> const& profile_ids) {
349+
torch::optional<c10::ArrayRef<int64_t>> const& profile_ids, bool enable_pdl) {
350350
std::lock_guard<std::mutex> lock(mMutex);
351351

352352
// Free the profile workspace to save memory
@@ -458,7 +458,7 @@ class FusedMoeRunner : public torch::CustomClassHolder {
458458
static_cast<char*>(workspace_info.workspace.data_ptr()), output.data_ptr(),
459459
static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config, enable_alltoall,
460460
false, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params,
461-
stream);
461+
enable_pdl, stream);
462462
#else
463463
mKernelRunner->runMoe(
464464
input.const_data_ptr(), input_sf.has_value() ? input_sf.value().const_data_ptr() : nullptr,
@@ -474,7 +474,7 @@ class FusedMoeRunner : public torch::CustomClassHolder {
474474
static_cast<int>(experts_per_token), static_cast<char*>(workspace_info.workspace),
475475
output.data_ptr(), static_cast<int*>(workspace_info.src_to_dest_map), parallelism_config,
476476
false, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params,
477-
stream);
477+
enable_pdl, stream);
478478
#endif
479479

480480
return std::make_tuple(output, num_active_experts_per_node, experts_to_token_score,
@@ -493,7 +493,8 @@ class FusedMoeRunner : public torch::CustomClassHolder {
493493
int64_t const tp_size, int64_t const tp_rank, int64_t const ep_size,
494494
int64_t const ep_rank, int64_t const cluster_size, int64_t const cluster_rank,
495495
bool const enable_alltoall, bool const min_latency_mode,
496-
int64_t const gemm_idx, int64_t const profile_id, bool const do_preparation) {
496+
int64_t const gemm_idx, int64_t const profile_id, bool const do_preparation,
497+
bool enable_pdl) {
497498
std::lock_guard<std::mutex> lock(mMutex);
498499

499500
// TODO: support profiling under fp8 block scaling in the future
@@ -558,11 +559,12 @@ class FusedMoeRunner : public torch::CustomClassHolder {
558559
TORCH_CHECK(cu_malloc_status == cudaSuccess,
559560
"Can't allocate profile workspace for MoE GEMM profile.");
560561

561-
mProfiler->prepare(num_rows, mProfileWorkspace, expert_weights_ptr, stream);
562+
mProfiler->prepare(num_rows, mProfileWorkspace, expert_weights_ptr, enable_pdl, stream);
562563
}
563564

564565
// Profile specific tactic. Assuming at least one preparation phase has been executed already.
565-
mProfiler->runProfiler(num_rows, profile, mProfileWorkspace, expert_weights_ptr, stream);
566+
mProfiler->runProfiler(num_rows, profile, mProfileWorkspace, expert_weights_ptr, enable_pdl,
567+
stream);
566568
}
567569

568570
private:

csrc/nv_internal/cpp/common/envUtils.cpp

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -189,19 +189,6 @@ bool getEnvUseTileSizeKv64ForTrtllmGen() {
189189
return useTileSizeKv64;
190190
}
191191

192-
bool getEnvEnablePDL() {
193-
static std::once_flag flag;
194-
static bool enablePDL = false;
195-
196-
std::call_once(flag, [&]() {
197-
if (getSMVersion() >= 90) {
198-
// PDL will be enabled by setting the env variables `TRTLLM_ENABLE_PDL` to `1`
199-
enablePDL = getBoolEnv("TRTLLM_ENABLE_PDL");
200-
}
201-
});
202-
return enablePDL;
203-
}
204-
205192
bool getEnvUseUCXKvCache() {
206193
static bool const useUCXKVCache = getBoolEnv("TRTLLM_USE_UCX_KVCACHE");
207194
return useUCXKVCache;

csrc/nv_internal/cpp/kernels/quantization.cu

Lines changed: 33 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ template void invokeQuantization<__nv_bfloat16>(int8_t* dst, __nv_bfloat16 const
7676
template <typename T>
7777
void invokeMxFP8Quantization(int b, int m, int n, int padded_n, T const* input, int64_t* output,
7878
int32_t* SFOuput, FP4QuantizationSFLayout layout,
79-
int multiProcessorCount, cudaStream_t stream) {
79+
int multiProcessorCount, bool enable_pdl, cudaStream_t stream) {
8080
// Fixed SF_VEC_SIZE as 32
8181
static constexpr int SF_VEC_SIZE = 32;
8282

@@ -95,7 +95,7 @@ void invokeMxFP8Quantization(int b, int m, int n, int padded_n, T const* input,
9595
config.stream = stream;
9696
cudaLaunchAttribute attrs[1];
9797
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
98-
attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
98+
attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
9999
config.numAttrs = 1;
100100
config.attrs = attrs;
101101
cudaLaunchKernelEx(
@@ -168,7 +168,7 @@ INSTANTIATE_INVOKE_PER_TOKEN_QUANTIZATION(__nv_bfloat16, __nv_fp8_e4m3);
168168
template <typename T, int SF_VEC_SIZE>
169169
void invokeFP4Quantization(int m, int n, T const* input, float const* SFScale, int64_t* output,
170170
int32_t* SFOuput, bool useUE8M0, FP4QuantizationSFLayout layout,
171-
int multiProcessorCount, cudaStream_t stream) {
171+
int multiProcessorCount, bool enable_pdl, cudaStream_t stream) {
172172
#ifdef ENABLE_FP8
173173
if constexpr (std::is_same_v<T, __nv_fp8_e4m3>) {
174174
// Grid, Block size.
@@ -204,7 +204,7 @@ void invokeFP4Quantization(int m, int n, T const* input, float const* SFScale, i
204204
config.stream = stream;
205205
cudaLaunchAttribute attrs[1];
206206
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
207-
attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
207+
attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
208208
config.numAttrs = 1;
209209
config.attrs = attrs;
210210
cudaLaunchKernelEx(&config, kernel_instance, m, n, input, SFScale,
@@ -217,7 +217,7 @@ template <typename T, int SF_VEC_SIZE>
217217
void invokeBatchedFP4Quantization(int b, int m, int n, T const* input, float const* SFScale,
218218
int64_t* output, int32_t* SFOuput, bool useUE8M0,
219219
int multiProcessorCount, FP4QuantizationSFLayout layout,
220-
cudaStream_t stream) {
220+
bool enable_pdl, cudaStream_t stream) {
221221
#ifdef ENABLE_FP8
222222
if constexpr (std::is_same_v<T, __nv_fp8_e4m3>) {
223223
// Grid, Block size.
@@ -253,7 +253,7 @@ void invokeBatchedFP4Quantization(int b, int m, int n, T const* input, float con
253253
config.stream = stream;
254254
cudaLaunchAttribute attrs[1];
255255
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
256-
attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
256+
attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;
257257
config.numAttrs = 1;
258258
config.attrs = attrs;
259259
cudaLaunchKernelEx(&config, kernel_instance, b, m, n, input, SFScale,
@@ -344,47 +344,56 @@ void invokeNVFP4BlockScaleInterleaveReverse(int b, int m, int n, uint8_t const*
344344
template void invokeFP4Quantization<half, 16>(int m, int n, half const* input, float const* SFScale,
345345
int64_t* output, int32_t* SFOuput, bool useUE8M0,
346346
FP4QuantizationSFLayout layout,
347-
int multiProcessorCount, cudaStream_t stream);
347+
int multiProcessorCount, bool enable_pdl,
348+
cudaStream_t stream);
348349
template void invokeFP4Quantization<half, 32>(int m, int n, half const* input, float const* SFScale,
349350
int64_t* output, int32_t* SFOuput, bool useUE8M0,
350351
FP4QuantizationSFLayout layout,
351-
int multiProcessorCount, cudaStream_t stream);
352-
template void invokeBatchedFP4Quantization<half, 16>(
353-
int b, int m, int n, half const* input, float const* SFScale, int64_t* output, int32_t* SFOuput,
354-
bool useUE8M0, int multiProcessorCount, FP4QuantizationSFLayout layout, cudaStream_t stream);
355-
template void invokeBatchedFP4Quantization<half, 32>(
356-
int b, int m, int n, half const* input, float const* SFScale, int64_t* output, int32_t* SFOuput,
357-
bool useUE8M0, int multiProcessorCount, FP4QuantizationSFLayout layout, cudaStream_t stream);
352+
int multiProcessorCount, bool enable_pdl,
353+
cudaStream_t stream);
354+
template void invokeBatchedFP4Quantization<half, 16>(int b, int m, int n, half const* input,
355+
float const* SFScale, int64_t* output,
356+
int32_t* SFOuput, bool useUE8M0,
357+
int multiProcessorCount,
358+
FP4QuantizationSFLayout layout,
359+
bool enable_pdl, cudaStream_t stream);
360+
template void invokeBatchedFP4Quantization<half, 32>(int b, int m, int n, half const* input,
361+
float const* SFScale, int64_t* output,
362+
int32_t* SFOuput, bool useUE8M0,
363+
int multiProcessorCount,
364+
FP4QuantizationSFLayout layout,
365+
bool enable_pdl, cudaStream_t stream);
358366
template void invokeMxFP8Quantization<half>(int b, int m, int n, int padded_n, half const* input,
359367
int64_t* output, int32_t* SFOuput,
360368
FP4QuantizationSFLayout layout, int multiProcessorCount,
361-
cudaStream_t stream);
369+
bool enable_pdl, cudaStream_t stream);
362370
#ifdef ENABLE_BF16
363371
template void invokeFP4Quantization<__nv_bfloat16, 16>(int m, int n, __nv_bfloat16 const* input,
364372
float const* SFScale, int64_t* output,
365373
int32_t* SFOuput, bool useUE8M0,
366374
FP4QuantizationSFLayout layout,
367-
int multiProcessorCount,
375+
int multiProcessorCount, bool enable_pdl,
368376
cudaStream_t stream);
369377
template void invokeFP4Quantization<__nv_bfloat16, 32>(int m, int n, __nv_bfloat16 const* input,
370378
float const* SFScale, int64_t* output,
371379
int32_t* SFOuput, bool useUE8M0,
372380
FP4QuantizationSFLayout layout,
373-
int multiProcessorCount,
381+
int multiProcessorCount, bool enable_pdl,
374382
cudaStream_t stream);
375383
template void invokeBatchedFP4Quantization<__nv_bfloat16, 16>(
376384
int b, int m, int n, __nv_bfloat16 const* input, float const* SFScale, int64_t* output,
377385
int32_t* SFOuput, bool useUE8M0, int multiProcessorCount, FP4QuantizationSFLayout layout,
378-
cudaStream_t stream);
386+
bool enable_pdl, cudaStream_t stream);
379387
template void invokeBatchedFP4Quantization<__nv_bfloat16, 32>(
380388
int b, int m, int n, __nv_bfloat16 const* input, float const* SFScale, int64_t* output,
381389
int32_t* SFOuput, bool useUE8M0, int multiProcessorCount, FP4QuantizationSFLayout layout,
382-
cudaStream_t stream);
390+
bool enable_pdl, cudaStream_t stream);
383391
template void invokeMxFP8Quantization<__nv_bfloat16>(int b, int m, int n, int padded_n,
384392
__nv_bfloat16 const* input, int64_t* output,
385393
int32_t* SFOuput,
386394
FP4QuantizationSFLayout layout,
387-
int multiProcessorCount, cudaStream_t stream);
395+
int multiProcessorCount, bool enable_pdl,
396+
cudaStream_t stream);
388397

389398
#endif
390399

@@ -393,22 +402,22 @@ template void invokeFP4Quantization<__nv_fp8_e4m3, 16>(int m, int n, __nv_fp8_e4
393402
float const* SFScale, int64_t* output,
394403
int32_t* SFOuput, bool useUE8M0,
395404
FP4QuantizationSFLayout layout,
396-
int multiProcessorCount,
405+
int multiProcessorCount, bool enable_pdl,
397406
cudaStream_t stream);
398407
template void invokeFP4Quantization<__nv_fp8_e4m3, 32>(int m, int n, __nv_fp8_e4m3 const* input,
399408
float const* SFScale, int64_t* output,
400409
int32_t* SFOuput, bool useUE8M0,
401410
FP4QuantizationSFLayout layout,
402-
int multiProcessorCount,
411+
int multiProcessorCount, bool enable_pdl,
403412
cudaStream_t stream);
404413
template void invokeBatchedFP4Quantization<__nv_fp8_e4m3, 16>(
405414
int b, int m, int n, __nv_fp8_e4m3 const* input, float const* SFScale, int64_t* output,
406415
int32_t* SFOuput, bool useUE8M0, int multiProcessorCount, FP4QuantizationSFLayout layout,
407-
cudaStream_t stream);
416+
bool enable_pdl, cudaStream_t stream);
408417
template void invokeBatchedFP4Quantization<__nv_fp8_e4m3, 32>(
409418
int b, int m, int n, __nv_fp8_e4m3 const* input, float const* SFScale, int64_t* output,
410419
int32_t* SFOuput, bool useUE8M0, int multiProcessorCount, FP4QuantizationSFLayout layout,
411-
cudaStream_t stream);
420+
bool enable_pdl, cudaStream_t stream);
412421
#endif
413422

414423
////////////////////////////////////////////////////////////////////////////////////////////////////

csrc/nv_internal/tensorrt_llm/common/envUtils.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,6 @@ int getEnvMmhaBlocksPerSequence();
4848

4949
int getEnvMmhaKernelBlockSize();
5050

51-
// Whether PDL is enabled.
52-
bool getEnvEnablePDL();
53-
5451
bool getEnvUseUCXKvCache();
5552

5653
bool getEnvUseMPIKvCache();

csrc/nv_internal/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,9 @@ struct TmaWarpSpecializedGroupedGemmInput {
226226
uint8_t* gemm_workspace = nullptr;
227227
size_t gemm_workspace_size = 0;
228228

229+
// Whether to enable PDL (Programmatic Dependent Launch).
230+
bool enable_pdl;
231+
229232
static std::array<size_t, 17> workspaceBuffers(int num_experts, FpXBlockScalingType scaling_type);
230233

231234
static size_t workspaceSize(int num_experts, FpXBlockScalingType scaling_type);

0 commit comments

Comments (0)