Commit 42c1856

unify PDL's ACKBLK and PREEXIT with CUDA API

Signed-off-by: Zhenhuan Chen <zhenhuanc@nvidia.com>
1 parent: b3d468f
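
For context: judging from the diffs below, ACKBLK and PREEXIT are this codebase's names for the two halves of Programmatic Dependent Launch (PDL). ACKBLK corresponds to the PTX instruction griddepcontrol.wait, which blocks until the grids this kernel depends on have made their memory operations visible, and PREEXIT corresponds to griddepcontrol.launch_dependents, which signals that dependent grids in the same stream may begin launching. CUDA 11.8 and later expose both as device runtime calls, and this commit replaces the inline PTX with those calls one-for-one. A minimal sketch of the pattern, using a hypothetical kernel rather than one from this commit:

#include <cuda_runtime.h>

// Hypothetical kernel showing the PDL pattern applied throughout this commit:
// synchronize on the upstream grid before reading its results, then signal
// dependent grids while the final stores are still in flight.
__global__ void pdlStyleKernel(float const* in, float* out, int64_t n)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    // Was: asm volatile("griddepcontrol.wait;");  (ACKBLK)
    cudaGridDependencySynchronize();
#endif
    for (int64_t i = threadIdx.x + blockIdx.x * blockDim.x; i < n; i += blockDim.x * gridDim.x)
    {
        out[i] = 2.0f * in[i];
    }
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    // Was: asm volatile("griddepcontrol.launch_dependents;");  (PREEXIT)
    cudaTriggerProgrammaticLaunchCompletion();
#endif
}

The substitution is behavior-preserving: the device runtime calls map to the same griddepcontrol instructions, but avoid raw inline PTX.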

File tree: 12 files changed, +42 -42 lines

cpp/tensorrt_llm/common/cudaFp8Utils.cu
Lines changed: 2 additions & 2 deletions

@@ -42,7 +42,7 @@ template <QuantizeMode QUANTIZE_MODE, bool QUANTIZE, typename T_OUT, typename T_
 __global__ void scaleMatrix(T_OUT* output, T_S const* input_scale, T_IN const* input, int64_t numel, int64_t lda)
 {
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.wait;");
+    cudaGridDependencySynchronize();
 #endif

     for (int64_t i = threadIdx.x + blockIdx.x * blockDim.x; i < numel; i += blockDim.x * gridDim.x)
@@ -62,7 +62,7 @@ __global__ void scaleMatrix(T_OUT* output, T_S const* input_scale, T_IN const* i
         }
     }
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.launch_dependents;");
+    cudaTriggerProgrammaticLaunchCompletion();
 #endif
 }
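
A note on scope (not visible in this diff): the device-side calls above only take effect when the dependent kernel is launched with the programmatic stream serialization attribute; without it, launches in a stream serialize as usual and the calls are effectively no-ops. A hedged host-side sketch, reusing the hypothetical pdlStyleKernel from the sketch above:

#include <cuda_runtime.h>

// Launch the dependent kernel with PDL enabled so it may start executing
// (e.g. run its prologue) before the preceding kernel in the same stream has
// fully completed; cudaGridDependencySynchronize() inside the kernel then
// guards the actual data dependency.
void launchDependentWithPdl(float const* in, float* out, int64_t n, cudaStream_t stream)
{
    cudaLaunchConfig_t config = {};
    config.gridDim = dim3(128);
    config.blockDim = dim3(256);
    config.stream = stream;

    cudaLaunchAttribute attrs[1];
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrs[0].val.programmaticStreamSerializationAllowed = 1;
    config.attrs = attrs;
    config.numAttrs = 1;

    cudaLaunchKernelEx(&config, pdlStyleKernel, in, out, n);
}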

cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/grid_dependency_control.h
Lines changed: 2 additions & 2 deletions

@@ -46,7 +46,7 @@ CUTLASS_DEVICE
 void launch_dependent_grids()
 {
 #if (defined(CUTLASS_GDC_ENABLED))
-    asm volatile("griddepcontrol.launch_dependents;");
+    cudaTriggerProgrammaticLaunchCompletion();
 #endif
 }

@@ -57,7 +57,7 @@ CUTLASS_DEVICE
 void wait_on_dependent_grids()
 {
 #if (defined(CUTLASS_GDC_ENABLED))
-    asm volatile("griddepcontrol.wait;");
+    cudaGridDependencySynchronize();
 #endif
 }
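
These wrappers let CUTLASS-extension kernels opt into grid dependency control without touching PTX or the CUDA API directly: when CUTLASS_GDC_ENABLED is not defined, both collapse to empty functions. A hedged usage sketch (kernel, workload, and include path are hypothetical; see grid_dependency_control.h for the real declarations):

// Sketch only; namespace qualification omitted.
#include "cutlass_extensions/arch/grid_dependency_control.h"

__global__ void gdcStyleKernel(float const* in, float* out, int n)
{
    // Wait for the grids this kernel depends on; wraps
    // cudaGridDependencySynchronize(), a no-op without CUTLASS_GDC_ENABLED.
    wait_on_dependent_grids();

    for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < n; i += blockDim.x * gridDim.x)
    {
        out[i] = in[i] + 1.0f;
    }

    // Allow grids that depend on this one to start launching; wraps
    // cudaTriggerProgrammaticLaunchCompletion().
    launch_dependent_grids();
}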

cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu
Lines changed: 22 additions & 22 deletions

@@ -161,7 +161,7 @@ __global__ void buildMinLatencyActiveExpertMapsKernel(int* num_active_experts_pe
     int const cluster_size, int const num_experts_smem)
 {
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.wait;");
+    cudaGridDependencySynchronize();
 #endif
     // Use one block to process the min latency case
     int tid = threadIdx.x;
@@ -271,7 +271,7 @@ __global__ void buildMinLatencyActiveExpertMapsKernel(int* num_active_experts_pe
         }
     }
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.launch_dependents;");
+    cudaTriggerProgrammaticLaunchCompletion();
 #endif
 }

@@ -330,7 +330,7 @@ __global__ void fusedBuildExpertMapsSortFirstTokenKernel(int const* const token_

     // Wait PDL before reading token_selected_experts
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.wait;");
+    cudaGridDependencySynchronize();
 #endif

     // build expert map
@@ -371,7 +371,7 @@ __global__ void fusedBuildExpertMapsSortFirstTokenKernel(int const* const token_

     // We are done with compute, launch the dependent kernels while the stores are in flight
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.launch_dependents;");
+    cudaTriggerProgrammaticLaunchCompletion();
 #endif

     // write to shared memory and global memory
@@ -576,7 +576,7 @@ __global__ void blockExpertPrefixSumKernel(int const* token_selected_experts, in
     int const token_id = block_id * kNumTokensPerBlock + threadIdx.x;

 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.wait;");
+    cudaGridDependencySynchronize();
 #endif

     int expanded_token_id = -1;
@@ -609,7 +609,7 @@ __global__ void blockExpertPrefixSumKernel(int const* token_selected_experts, in
     }

 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.launch_dependents;");
+    cudaTriggerProgrammaticLaunchCompletion();
 #endif
 }

@@ -669,7 +669,7 @@ __global__ void globalExpertPrefixSumLargeKernel(int const* blocked_expert_count
     int cnt = 0;

 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.wait;");
+    cudaGridDependencySynchronize();
 #endif

     // Note: Because of limited registers, cannot store thread-level prefix sum or enable #pragma unroll
@@ -703,7 +703,7 @@ __global__ void globalExpertPrefixSumLargeKernel(int const* blocked_expert_count
     }

 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.launch_dependents;");
+    cudaTriggerProgrammaticLaunchCompletion();
 #endif
 }

@@ -715,7 +715,7 @@ __global__ void globalExpertPrefixSumKernel(int const* blocked_expert_counts, in
     __shared__ typename BlockScan::TempStorage temp_storage;

 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.wait;");
+    cudaGridDependencySynchronize();
 #endif

     int const cnt = threadIdx.x < num_experts_per_node * num_blocks_per_seq ? blocked_expert_counts[threadIdx.x] : 0;
@@ -736,7 +736,7 @@ __global__ void globalExpertPrefixSumKernel(int const* blocked_expert_counts, in
     }

 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.launch_dependents;");
+    cudaTriggerProgrammaticLaunchCompletion();
 #endif
 }

@@ -807,7 +807,7 @@ __global__ void mergeExpertPrefixSumKernel(int const* blocked_expert_counts, int
     int const token_id = block_id * blockDim.x + threadIdx.x;

 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.wait;");
+    cudaGridDependencySynchronize();
 #endif

     int const cnt = blocked_expert_counts[target_expert_id * num_blocks_per_seq + block_id];
@@ -822,7 +822,7 @@ __global__ void mergeExpertPrefixSumKernel(int const* blocked_expert_counts, int
     }

 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.launch_dependents;");
+    cudaTriggerProgrammaticLaunchCompletion();
 #endif
 }

@@ -1256,7 +1256,7 @@ __global__ void computeStridesTmaWarpSpecializedKernel(int64_t const* expert_fir
     }

 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.wait;");
+    cudaGridDependencySynchronize();
 #endif

     // Both gemms use the same token offset
@@ -1331,7 +1331,7 @@ __global__ void computeStridesTmaWarpSpecializedKernel(int64_t const* expert_fir
         bias2, gemm2_output, router_scales, permuted_row_to_unpermuted_row, expert);

 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.launch_dependents;");
+    cudaTriggerProgrammaticLaunchCompletion();
 #endif
 }

@@ -1392,7 +1392,7 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp
         "Only NVFP4, MXFP8 and WINT4_AFP8 supports outputting a different format as part of the expansion");

 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.wait;");
+    cudaGridDependencySynchronize();
 #endif

     constexpr int VecSize = is_nvfp4 ? TmaWarpSpecializedGroupedGemmInput::NVFP4BlockScaleVectorSize
@@ -1522,7 +1522,7 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp
     }

 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.launch_dependents;");
+    cudaTriggerProgrammaticLaunchCompletion();
 #endif

     // Pad zeros in the extra SFs along the N dimension, we do this to ensure there are no nan values in the padded SF
@@ -1714,7 +1714,7 @@ __global__ void finalizeMoeRoutingKernel(GemmOutputType const* expanded_permuted
     auto* reduced_row_ptr_v = reinterpret_cast<OutputElem*>(reduced_row_ptr);

 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.wait;");
+    cudaGridDependencySynchronize();
 #endif

 #pragma unroll
@@ -1754,7 +1754,7 @@ __global__ void finalizeMoeRoutingKernel(GemmOutputType const* expanded_permuted
         reduced_row_ptr_v[elem_index] = output_elem;
     }
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.launch_dependents;");
+    cudaTriggerProgrammaticLaunchCompletion();
 #endif
 }

@@ -1773,7 +1773,7 @@ __global__ void finalizeMoeRoutingNoFillingKernel(GemmOutputType const* expanded
     assert(unpadded_cols <= padded_cols);

 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.wait;");
+    cudaGridDependencySynchronize();
 #endif

     int64_t const num_valid_tokens = expert_first_token_offset[num_experts_per_node];
@@ -1862,7 +1862,7 @@ __global__ void finalizeMoeRoutingNoFillingKernel(GemmOutputType const* expanded
         }
     }
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.launch_dependents;");
+    cudaTriggerProgrammaticLaunchCompletion();
 #endif
 }

@@ -2059,7 +2059,7 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
     int64_t const num_valid_tokens = expert_first_token_offset[num_experts_per_node];

 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.wait;");
+    cudaGridDependencySynchronize();
 #endif
     for (int64_t token = blockIdx.x; token < num_valid_tokens; token += gridDim.x)
     {
@@ -2175,7 +2175,7 @@ __global__ void doActivationKernel(T* output, GemmOutputType const* gemm_result,
     }

 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.launch_dependents;");
+    cudaTriggerProgrammaticLaunchCompletion();
 #endif

     // Pad zeros in the extra SFs along the N dimension, we do this to ensure there are no nan values in the padded SF

cpp/tensorrt_llm/kernels/groupRmsNormKernels/groupRmsNormKernels.cu
Lines changed: 4 additions & 4 deletions

@@ -108,7 +108,7 @@ __global__ void GroupRMSNormBaseKernel(GroupRMSParams<n> params, int rounds)
     PackedType const* __restrict__ weight_ptr = nullptr;

 #if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.wait;");
+    cudaGridDependencySynchronize();
 #endif

     // Find which input current warp operates on
@@ -260,7 +260,7 @@ __global__ void GroupRMSNormBaseKernel(GroupRMSParams<n> params, int rounds)
     }

 #if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.launch_dependents;");
+    cudaTriggerProgrammaticLaunchCompletion();
 #endif
 }

@@ -302,7 +302,7 @@ __global__ void GroupRMSNormKernelLargeBatch(
     bool process_input_1 = warp_idx < warp_size_1;

 #if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.wait;");
+    cudaGridDependencySynchronize();
 #endif

     // Get input pointers
@@ -562,7 +562,7 @@ __global__ void GroupRMSNormKernelLargeBatch(
     }

 #if (__CUDACC_VER_MAJOR__ >= 12 && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.launch_dependents;");
+    cudaTriggerProgrammaticLaunchCompletion();
 #endif
 }

cpp/tensorrt_llm/kernels/llama4MinLatencyKernels/llama4Fp8Bf16GemmAttnScalingPerBlockTemplate.cuh
Lines changed: 1 addition & 1 deletion

@@ -247,7 +247,7 @@ __launch_bounds__(BLOCK_SIZE) __global__ void llama4_fp8_bf16_gemm_attn_scaling_
     }

 #if ENABLE_PREEXIT
-    asm volatile("griddepcontrol.launch_dependents;");
+    cudaTriggerProgrammaticLaunchCompletion();
 #endif
 #endif
 }

cpp/tensorrt_llm/kernels/llama4MinLatencyKernels/llama4Fp8Bf16GemmPerBlockTemplate.cuh
Lines changed: 1 addition & 1 deletion

@@ -234,7 +234,7 @@ __launch_bounds__(BLOCK_SIZE) __global__ void llama4_fp8_bf16_gemm_per_block_ker
     }

 #if ENABLE_PREEXIT
-    asm volatile("griddepcontrol.launch_dependents;");
+    cudaTriggerProgrammaticLaunchCompletion();
 #endif
 #endif
 }

cpp/tensorrt_llm/kernels/llama4MinLatencyKernels/llama4Fp8Bf16GemmPerWarpTemplate.cuh
Lines changed: 1 addition & 1 deletion

@@ -260,7 +260,7 @@ __launch_bounds__(BLOCK_SIZE) __global__ void llama4_fp8_bf16_gemm_per_warp_kern
     }

 #if ENABLE_PREEXIT
-    asm volatile("griddepcontrol.launch_dependents;");
+    cudaTriggerProgrammaticLaunchCompletion();
 #endif
 #endif
 }

cpp/tensorrt_llm/kernels/llama4MinLatencyKernels/llama4Fp8Fp8GemmSwiGLUPerBlockTemplate.cuh
Lines changed: 1 addition & 1 deletion

@@ -273,7 +273,7 @@ __launch_bounds__(BLOCK_SIZE) __global__ void llama4_fp8_fp8_gemm_swiglu_per_blo
     }

 #if ENABLE_PREEXIT
-    asm volatile("griddepcontrol.launch_dependents;");
+    cudaTriggerProgrammaticLaunchCompletion();
 #endif
 #endif
 }

cpp/tensorrt_llm/kernels/llama4MinLatencyKernels/llama4MinLatencyMoEOp.cu
Lines changed: 2 additions & 2 deletions

@@ -188,7 +188,7 @@ __global__ void llama4_moe_fc13_swiglu_fp8_kernel(int num_tokens,
     }

 #if ENABLE_PREEXIT
-    asm volatile("griddepcontrol.launch_dependents;");
+    cudaTriggerProgrammaticLaunchCompletion();
 #endif
 #endif
 }
@@ -307,7 +307,7 @@ __global__ void llama4_moe_fc2_fp8_kernel(int num_tokens,
     }

 #if ENABLE_PREEXIT
-    asm volatile("griddepcontrol.launch_dependents;");
+    cudaTriggerProgrammaticLaunchCompletion();
 #endif
 #endif
 }

cpp/tensorrt_llm/kernels/mlaKernels.cu
Lines changed: 2 additions & 2 deletions

@@ -387,7 +387,7 @@ __global__ void applyMLARopeAndAssignQKVKernelGeneration(T* qkv_output, T* q_pe,
     // Block/Head idx.
     size_t const head_idx = blockIdx.y;
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.wait;");
+    cudaGridDependencySynchronize();
 #endif

     if (blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0)
@@ -595,7 +595,7 @@ __global__ void applyMLARopeAndAssignQKVKernelGeneration(T* qkv_output, T* q_pe,
     }

 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
-    asm volatile("griddepcontrol.launch_dependents;");
+    cudaTriggerProgrammaticLaunchCompletion();
 #endif

     // The implementation of the parallel scan in the thread block (see CUB for details).
