
Commit 8462cf6

[TRTLLM-9578][feat] make PDL enabled by default (#9695)
Signed-off-by: Zhenhuan Chen <[email protected]>
1 parent 97b38ac commit 8462cf6
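
Background for this change: PDL (Programmatic Dependent Launch) is a CUDA feature (compute capability 9.0 and newer, exposed since CUDA 11.8) that lets a dependent grid begin launching before the grid it depends on has fully drained. The commit does two things: it flips the `TRTLLM_ENABLE_PDL` default to on, and it replaces raw PTX with the equivalent CUDA device intrinsics. `cudaGridDependencySynchronize()` stands in for `asm volatile("griddepcontrol.wait;")` and blocks until the preceding grid's memory writes are visible; `cudaTriggerProgrammaticLaunchCompletion()` stands in for `asm volatile("griddepcontrol.launch_dependents;")` and signals that dependent grids may start. Below is a minimal sketch of the host/device pairing, with hypothetical kernel names, assuming compilation for sm_90; it is not code from this commit:

```cuda
#include <cuda_runtime.h>

// Hypothetical producer/consumer pair illustrating the PDL intrinsics.
__global__ void producerKernel(float* buf, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
        buf[i] = static_cast<float>(i);
    // Compute is done; allow dependent grids to start while stores drain.
    cudaTriggerProgrammaticLaunchCompletion();
}

__global__ void consumerKernel(float const* buf, float* out, int n)
{
    // Block until the producer grid's writes are guaranteed visible.
    cudaGridDependencySynchronize();
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
        out[i] = 2.0f * buf[i];
}

// The overlap only happens if the dependent launch opts in via a launch
// attribute; otherwise the stream serializes the two grids as usual.
void launchWithPDL(float* buf, float* out, int n, cudaStream_t stream)
{
    int const block = 256;
    int const grid = (n + block - 1) / block;
    producerKernel<<<grid, block, 0, stream>>>(buf, n);

    cudaLaunchConfig_t cfg{};
    cfg.gridDim = grid;
    cfg.blockDim = block;
    cfg.stream = stream;
    cudaLaunchAttribute attr[1];
    attr[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attr[0].val.programmaticStreamSerializationAllowed = 1;
    cfg.attrs = attr;
    cfg.numAttrs = 1;
    cudaLaunchKernelEx(&cfg, consumerKernel,
        static_cast<float const*>(buf), out, n);
}
```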

File tree

18 files changed: +68 -57 lines changed


cpp/tensorrt_llm/common/cudaFp8Utils.cu

Lines changed: 2 additions & 2 deletions

@@ -43,7 +43,7 @@ template <QuantizeMode QUANTIZE_MODE, bool QUANTIZE, typename T_OUT, typename T_
 __global__ void scaleMatrix(T_OUT* output, T_S const* input_scale, T_IN const* input, int64_t numel, int64_t lda)
 {
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.wait;");
+    cudaGridDependencySynchronize();
 #endif

     for (int64_t i = threadIdx.x + blockIdx.x * blockDim.x; i < numel; i += blockDim.x * gridDim.x)

@@ -63,7 +63,7 @@ __global__ void scaleMatrix(T_OUT* output, T_S const* input_scale, T_IN const* i
         }
     }
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.launch_dependents;");
+    cudaTriggerProgrammaticLaunchCompletion();
 #endif
 }

cpp/tensorrt_llm/common/envUtils.cpp

Lines changed: 13 additions & 2 deletions

@@ -249,15 +249,26 @@ bool getEnvUseTileSizeKv64ForTrtllmGen()
 bool getEnvEnablePDL()
 {
     static std::once_flag flag;
-    static bool enablePDL = false;
+    static bool enablePDL = true;

     std::call_once(flag,
         [&]()
         {
             if (getSMVersion() >= 90)
             {
                 // PDL will be enabled by setting the env variables `TRTLLM_ENABLE_PDL` to `1`
-                enablePDL = getBoolEnv("TRTLLM_ENABLE_PDL");
+                char const* env = std::getenv("TRTLLM_ENABLE_PDL");
+                if (env)
+                {
+                    if (env[0] == '1' && env[1] == '\0')
+                    {
+                        enablePDL = true;
+                    }
+                    else if (env[0] == '0' && env[1] == '\0')
+                    {
+                        enablePDL = false;
+                    }
+                };
             }
         });
     return enablePDL;
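
The parse above makes the variable tri-state: unset keeps the new default (enabled on SM 90+), the exact string `0` disables PDL, the exact string `1` enables it, and any other value falls through to the default. A standalone sketch of the same logic for illustration (hypothetical function name, not TensorRT-LLM code; it omits the SM-version gate and the `std::call_once` caching):

```cpp
#include <cstdio>
#include <cstdlib>

// Mirrors the tri-state parse in getEnvEnablePDL() above.
static bool resolveEnablePDL()
{
    bool enablePDL = true; // new default
    if (char const* env = std::getenv("TRTLLM_ENABLE_PDL"))
    {
        if (env[0] == '1' && env[1] == '\0')
            enablePDL = true;  // explicit opt-in (already the default)
        else if (env[0] == '0' && env[1] == '\0')
            enablePDL = false; // explicit opt-out
        // Any other value ("true", "01", "") leaves the default in place.
    }
    return enablePDL;
}

int main()
{
    std::printf("PDL enabled: %s\n", resolveEnablePDL() ? "yes" : "no");
    return 0;
}
```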

cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/grid_dependency_control.h

Lines changed: 2 additions & 2 deletions

@@ -46,7 +46,7 @@ CUTLASS_DEVICE
 void launch_dependent_grids()
 {
 #if (defined(CUTLASS_GDC_ENABLED))
-    asm volatile("griddepcontrol.launch_dependents;");
+    cudaTriggerProgrammaticLaunchCompletion();
 #endif
 }

@@ -57,7 +57,7 @@ CUTLASS_DEVICE
 void wait_on_dependent_grids()
 {
 #if (defined(CUTLASS_GDC_ENABLED))
-    asm volatile("griddepcontrol.wait;");
+    cudaGridDependencySynchronize();
 #endif
 }
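
These CUTLASS-extension wrappers hide both the architecture gate and the intrinsics behind `CUTLASS_GDC_ENABLED`: when that macro is not defined, both functions compile to empty bodies, so call sites need no guards of their own. A hedged sketch of a kernel bracketed by the two helpers (the kernel, and the assumption that the helpers are visible unqualified, are illustrative only):

```cuda
// Hypothetical kernel using the wrappers from grid_dependency_control.h.
__global__ void gdcBracketedCopy(float const* in, float* out, int n)
{
    // Wait for the grids this kernel depends on; empty without CUTLASS_GDC_ENABLED.
    wait_on_dependent_grids();

    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
        out[i] = in[i];

    // Signal that grids depending on this one may begin launching.
    launch_dependent_grids();
}
```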

cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu

Lines changed: 22 additions & 22 deletions

@@ -164,7 +164,7 @@ __global__ void buildMinLatencyActiveExpertMapsKernel(int* num_active_experts_pe
     int const cluster_size, int const num_experts_smem)
 {
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.wait;");
+    cudaGridDependencySynchronize();
 #endif
     // Use one block to process the min latency case
     int tid = threadIdx.x;

@@ -274,7 +274,7 @@ __global__ void buildMinLatencyActiveExpertMapsKernel(int* num_active_experts_pe
         }
     }
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.launch_dependents;");
+    cudaTriggerProgrammaticLaunchCompletion();
 #endif
 }

@@ -333,7 +333,7 @@ __global__ void fusedBuildExpertMapsSortFirstTokenKernel(int const* const token_

     // Wait PDL before reading token_selected_experts
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.wait;");
+    cudaGridDependencySynchronize();
 #endif

     // build expert map

@@ -374,7 +374,7 @@ __global__ void fusedBuildExpertMapsSortFirstTokenKernel(int const* const token_

     // We are done with compute, launch the dependent kernels while the stores are in flight
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.launch_dependents;");
+    cudaTriggerProgrammaticLaunchCompletion();
 #endif

     // write to shared memory and global memory

@@ -579,7 +579,7 @@ __global__ void blockExpertPrefixSumKernel(int const* token_selected_experts, in
     int const token_id = block_id * kNumTokensPerBlock + threadIdx.x;

 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.wait;");
+    cudaGridDependencySynchronize();
 #endif

     int expanded_token_id = -1;

@@ -612,7 +612,7 @@ __global__ void blockExpertPrefixSumKernel(int const* token_selected_experts, in
     }

 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.launch_dependents;");
+    cudaTriggerProgrammaticLaunchCompletion();
 #endif
 }

@@ -672,7 +672,7 @@ __global__ void globalExpertPrefixSumLargeKernel(int const* blocked_expert_count
     int cnt = 0;

 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.wait;");
+    cudaGridDependencySynchronize();
 #endif

     // Note: Because of limited registers, cannot store thread-level prefix sum or enable #pragma unroll

@@ -706,7 +706,7 @@ __global__ void globalExpertPrefixSumLargeKernel(int const* blocked_expert_count
     }

 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.launch_dependents;");
+    cudaTriggerProgrammaticLaunchCompletion();
 #endif
 }

@@ -718,7 +718,7 @@ __global__ void globalExpertPrefixSumKernel(int const* blocked_expert_counts, in
     __shared__ typename BlockScan::TempStorage temp_storage;

 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.wait;");
+    cudaGridDependencySynchronize();
 #endif

     int const cnt = threadIdx.x < num_experts_per_node * num_blocks_per_seq ? blocked_expert_counts[threadIdx.x] : 0;

@@ -739,7 +739,7 @@ __global__ void globalExpertPrefixSumKernel(int const* blocked_expert_counts, in
     }

 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.launch_dependents;");
+    cudaTriggerProgrammaticLaunchCompletion();
 #endif
 }

@@ -810,7 +810,7 @@ __global__ void mergeExpertPrefixSumKernel(int const* blocked_expert_counts, int
     int const token_id = block_id * blockDim.x + threadIdx.x;

 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.wait;");
+    cudaGridDependencySynchronize();
 #endif

     int const cnt = blocked_expert_counts[target_expert_id * num_blocks_per_seq + block_id];

@@ -825,7 +825,7 @@ __global__ void mergeExpertPrefixSumKernel(int const* blocked_expert_counts, int
     }

 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.launch_dependents;");
+    cudaTriggerProgrammaticLaunchCompletion();
 #endif
 }

@@ -1259,7 +1259,7 @@ __global__ void computeStridesTmaWarpSpecializedKernel(int64_t const* expert_fir
     }

 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.wait;");
+    cudaGridDependencySynchronize();
 #endif

     // Both gemms use the same token offset

@@ -1334,7 +1334,7 @@ __global__ void computeStridesTmaWarpSpecializedKernel(int64_t const* expert_fir
         bias2, gemm2_output, router_scales, permuted_row_to_unpermuted_row, expert);

 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.launch_dependents;");
+    cudaTriggerProgrammaticLaunchCompletion();
 #endif
 }

@@ -1395,7 +1395,7 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp
         "Only NVFP4, MXFP8 and WINT4_AFP8 supports outputting a different format as part of the expansion");

 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.wait;");
+    cudaGridDependencySynchronize();
 #endif

     constexpr int VecSize = is_nvfp4 ? TmaWarpSpecializedGroupedGemmInput::NVFP4BlockScaleVectorSize

@@ -1525,7 +1525,7 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp
     }

 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.launch_dependents;");
+    cudaTriggerProgrammaticLaunchCompletion();
 #endif

     // Pad zeros in the extra SFs along the N dimension, we do this to ensure there are no nan values in the padded SF

@@ -1717,7 +1717,7 @@ __global__ void finalizeMoeRoutingKernel(GemmOutputType const* expanded_permuted
     auto* reduced_row_ptr_v = reinterpret_cast<OutputElem*>(reduced_row_ptr);

 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.wait;");
+    cudaGridDependencySynchronize();
 #endif

 #pragma unroll

@@ -1757,7 +1757,7 @@ __global__ void finalizeMoeRoutingKernel(GemmOutputType const* expanded_permuted
         reduced_row_ptr_v[elem_index] = output_elem;
     }
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.launch_dependents;");
+    cudaTriggerProgrammaticLaunchCompletion();
 #endif
 }

@@ -1776,7 +1776,7 @@ __global__ void finalizeMoeRoutingNoFillingKernel(GemmOutputType const* expanded
     assert(unpadded_cols <= padded_cols);

 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.wait;");
+    cudaGridDependencySynchronize();
 #endif

     int64_t const num_valid_tokens = expert_first_token_offset[num_experts_per_node];

@@ -1865,7 +1865,7 @@ __global__ void finalizeMoeRoutingNoFillingKernel(GemmOutputType const* expanded
         }
     }
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.launch_dependents;");
+    cudaTriggerProgrammaticLaunchCompletion();
 #endif
 }

@@ -2062,7 +2062,7 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
     int64_t const num_valid_tokens = expert_first_token_offset[num_experts_per_node];

 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.wait;");
+    cudaGridDependencySynchronize();
 #endif
     for (int64_t token = blockIdx.x; token < num_valid_tokens; token += gridDim.x)
     {

@@ -2178,7 +2178,7 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
     }

 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.launch_dependents;");
+    cudaTriggerProgrammaticLaunchCompletion();
 #endif

     // Pad zeros in the extra SFs along the N dimension, we do this to ensure there are no nan values in the padded SF
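
All 22 changes in this file are the same mechanical substitution, and the surrounding code shows the placement discipline the MoE pipeline relies on: each kernel waits immediately before its first read of data produced by the upstream grid, so prologue work that touches only its own arguments still overlaps, and it triggers as soon as compute finishes, even while its own stores are in flight (see the "We are done with compute" comment above). A hedged sketch of that pattern with a hypothetical pipeline stage, not taken from this file:

```cuda
// Hypothetical stage illustrating "wait as late as possible, trigger as
// early as possible"; the real kernels are the MoE ones in this diff.
__global__ void pipelineStage(int const* upstream, int* downstream, int n)
{
    // Index math depends only on this kernel's own launch parameters,
    // so it can run while the previous grid is still finishing.
    int i = blockIdx.x * blockDim.x + threadIdx.x;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    cudaGridDependencySynchronize(); // just before the first dependent load
#endif
    int v = (i < n) ? upstream[i] : 0;
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    // Compute done; a dependent grid's own synchronize still guarantees
    // it sees the store below before reading it.
    cudaTriggerProgrammaticLaunchCompletion();
#endif
    if (i < n)
        downstream[i] = v + 1;
}
```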

cpp/tensorrt_llm/kernels/fusedLayernormKernels/low_latency_layernorm.cuh

Lines changed: 2 additions & 2 deletions

@@ -178,8 +178,8 @@ struct LowLatencyLayerNorm
 #if (defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 12))
         if constexpr (arch::is_major_v<9> || arch::is_major_v<10>)
         {
-            asm volatile("griddepcontrol.wait;\n");
-            asm volatile("griddepcontrol.launch_dependents;\n");
+            cudaGridDependencySynchronize();
+            cudaTriggerProgrammaticLaunchCompletion();
         }
 #endif
         load_to_register(&param.input[work_id * param.n], data, param.n);

cpp/tensorrt_llm/kernels/fusedLayernormKernels/ws_layernorm.cuh

Lines changed: 2 additions & 2 deletions

@@ -211,7 +211,7 @@ struct WarpSpecializedLayerNorm

         if constexpr (FIRST_RUN)
         {
-            asm volatile("griddepcontrol.wait;\n");
+            cudaGridDependencySynchronize();
         }

         for (int i = 0; i < Traits::M_BLOCK; i++)

@@ -817,7 +817,7 @@ struct WarpSpecializedLayerNorm
         {
             scheduler(lane_id, gridDim.x * gridDim.y * gridDim.z, param, shared);
             // PRE-EXIT after all tiles have been scheduled.
-            asm volatile("griddepcontrol.launch_dependents;\n");
+            cudaTriggerProgrammaticLaunchCompletion();
         }
         else if (warp_id == 1)
         {
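
Worth noting: the warp-specialized layer norm keeps its original PDL placement, and only the intrinsic spelling changes. The wait is issued only on what appears to be the first pass of a persistent loop (`FIRST_RUN`), and the pre-exit trigger comes from the scheduler warp once all tiles have been handed out, rather than from every warp at kernel end.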

cpp/tensorrt_llm/kernels/groupRmsNormKernels/groupRmsNormKernels.cu

Lines changed: 4 additions & 4 deletions

@@ -111,7 +111,7 @@ __global__ void GroupRMSNormBaseKernel(GroupRMSParams<n> params, int rounds)
     PackedType const* __restrict__ weight_ptr = nullptr;

 #if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.wait;");
+    cudaGridDependencySynchronize();
 #endif

     // Find which input current warp operates on

@@ -263,7 +263,7 @@ __global__ void GroupRMSNormBaseKernel(GroupRMSParams<n> params, int rounds)
     }

 #if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.launch_dependents;");
+    cudaTriggerProgrammaticLaunchCompletion();
 #endif
 }

@@ -305,7 +305,7 @@ __global__ void GroupRMSNormKernelLargeBatch(
     bool process_input_1 = warp_idx < warp_size_1;

 #if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.wait;");
+    cudaGridDependencySynchronize();
 #endif

     // Get input pointers

@@ -565,7 +565,7 @@ __global__ void GroupRMSNormKernelLargeBatch(
     }

 #if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.launch_dependents;");
+    cudaTriggerProgrammaticLaunchCompletion();
 #endif
 }

cpp/tensorrt_llm/kernels/llama4MinLatencyKernels/llama4Bf16Bf16Gemm.cu

Lines changed: 1 addition & 1 deletion

@@ -60,7 +60,7 @@ __global__ void llama4_bf16_bf16_gemm_kernel(int num_tokens,
         b_vec[chunk] = reinterpret_cast<aligned_bf16x4 const*>(B)[row * GEMM_K / VEC_SIZE + base_idx];
     }

-    asm volatile("griddepcontrol.wait;" ::: "memory");
+    cudaGridDependencySynchronize();

     // Process 5 chunks of 4 elements each
 #pragma unroll
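
One subtlety in these Llama4 kernels: the inline asm here carried an explicit `"memory"` clobber, which the intrinsic call does not spell out. The assumption behind the swap (consistent with the other files, where the PTX had no clobber at all) is that `cudaGridDependencySynchronize()` itself provides the necessary compiler and memory-ordering guarantees; that is a property of the CUDA intrinsic rather than anything this diff adds.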

cpp/tensorrt_llm/kernels/llama4MinLatencyKernels/llama4Fp8Bf16GemmAttnScalingPerBlockTemplate.cuh

Lines changed: 2 additions & 2 deletions

@@ -100,7 +100,7 @@ __launch_bounds__(BLOCK_SIZE) __global__ void llama4_fp8_bf16_gemm_attn_scaling_
 #endif

 #if ENABLE_ACQBULK
-    asm volatile("griddepcontrol.wait;" ::: "memory");
+    cudaGridDependencySynchronize();
 #endif

     // Processing 8 elements each

@@ -250,7 +250,7 @@ __launch_bounds__(BLOCK_SIZE) __global__ void llama4_fp8_bf16_gemm_attn_scaling_
     }

 #if ENABLE_PREEXIT
-    asm volatile("griddepcontrol.launch_dependents;");
+    cudaTriggerProgrammaticLaunchCompletion();
 #endif
 #endif
 }

cpp/tensorrt_llm/kernels/llama4MinLatencyKernels/llama4Fp8Bf16GemmPerBlockTemplate.cuh

Lines changed: 2 additions & 2 deletions

@@ -89,7 +89,7 @@ __launch_bounds__(BLOCK_SIZE) __global__ void llama4_fp8_bf16_gemm_per_block_ker
 #endif

 #if ENABLE_ACQBULK
-    asm volatile("griddepcontrol.wait;" ::: "memory");
+    cudaGridDependencySynchronize();
 #endif

     // Processing 8 elements each

@@ -237,7 +237,7 @@ __launch_bounds__(BLOCK_SIZE) __global__ void llama4_fp8_bf16_gemm_per_block_ker
     }

 #if ENABLE_PREEXIT
-    asm volatile("griddepcontrol.launch_dependents;");
+    cudaTriggerProgrammaticLaunchCompletion();
 #endif
 #endif
 }
