Skip to content

Commit 90aab05

Browse files
authored
[fix] Fix Llama4 guardwords failures (#4844)
Signed-off-by: Yilin Fan <206948969+nv-yilinf@users.noreply.github.com>
1 parent 13f6833 commit 90aab05

File tree

6 files changed

+29
-29
lines changed

6 files changed

+29
-29
lines changed

cpp/tensorrt_llm/kernels/llama4MinLatencyKernels/llama4Bf16Bf16Gemm.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ __global__ void llama4_bf16_bf16_gemm_kernel(int num_tokens,
4545
int const row = blockIdx.x % NUM_EXPERTS; // Matrix row / Output element index
4646
int const tid = threadIdx.x; // Thread ID within the block
4747

48-
// FDL prefetch all B data
48+
// PDL prefetch all B data
4949
aligned_bf16x4 b_vec[GEMM_K / BLOCK_SIZE / VEC_SIZE];
5050
#pragma unroll
5151
for (int chunk = 0; chunk < GEMM_K / BLOCK_SIZE / VEC_SIZE; chunk++)
@@ -113,7 +113,7 @@ void llama4_bf16_bf16_gemm_launcher(
113113
int const grid_size = NUM_EXPERTS * num_tokens;
114114

115115
void* args[] = {(void*) &num_tokens, (void*) &A, (void*) &B, (void*) &C};
116-
launch_kernel_fdl(dim3(grid_size), dim3(BLOCK_SIZE), stream, (void*) llama4_bf16_bf16_gemm_kernel, args, 4);
116+
launch_kernel_pdl(dim3(grid_size), dim3(BLOCK_SIZE), stream, (void*) llama4_bf16_bf16_gemm_kernel, args, 4);
117117
}
118118

119119
void llama4_bf16_bf16_gemm_op(int num_tokens, void const* A, void const* B, void* C, cudaStream_t stream)

cpp/tensorrt_llm/kernels/llama4MinLatencyKernels/llama4Fp8Bf16Gemm.cu

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -38,56 +38,56 @@ void llama4_fp8_bf16_gemm_launcher(void const* A, void const* B, void* C, void c
3838
// When num_tokens == 1, the best tiling size is tile_token == 1 and tile_out == 1.
3939
dim3 const grid_size = dim3(div_up(hidden_out, 1), div_up(num_tokens, 1), 1);
4040
void* kernel_func = get_per_block_func_ptr_aligned_true_5120_(1, 1);
41-
launch_kernel_fdl(dim3(grid_size), dim3(BLOCK_SIZE), stream, kernel_func, args, 7);
41+
launch_kernel_pdl(dim3(grid_size), dim3(BLOCK_SIZE), stream, kernel_func, args, 7);
4242
}
4343
else if (num_tokens == 2)
4444
{
4545
// When num_tokens == 2, the best tiling size is tile_token == 2 and tile_out == 1.
4646
dim3 const grid_size = dim3(div_up(hidden_out, 1), div_up(num_tokens, 2), 1);
4747
void* kernel_func = get_per_block_func_ptr_aligned_true_5120_(2, 1);
48-
launch_kernel_fdl(dim3(grid_size), dim3(BLOCK_SIZE), stream, kernel_func, args, 7);
48+
launch_kernel_pdl(dim3(grid_size), dim3(BLOCK_SIZE), stream, kernel_func, args, 7);
4949
}
5050
else if (num_tokens == 3)
5151
{
5252
// When num_tokens == 3, the best tiling size is tile_token == 1 and tile_out == 4.
5353
dim3 const grid_size = dim3(div_up(hidden_out, 4), div_up(num_tokens, 1), 1);
5454
void* kernel_func = get_per_block_func_ptr_aligned_true_5120_(1, 4);
55-
launch_kernel_fdl(dim3(grid_size), dim3(BLOCK_SIZE), stream, kernel_func, args, 7);
55+
launch_kernel_pdl(dim3(grid_size), dim3(BLOCK_SIZE), stream, kernel_func, args, 7);
5656
}
5757
else if (num_tokens == 4)
5858
{
5959
// When num_tokens == 4, the best tiling size is tile_token == 2 and tile_out == 2.
6060
dim3 const grid_size = dim3(div_up(hidden_out, 2), div_up(num_tokens, 2), 1);
6161
void* kernel_func = get_per_block_func_ptr_aligned_true_5120_(2, 2);
62-
launch_kernel_fdl(dim3(grid_size), dim3(BLOCK_SIZE), stream, kernel_func, args, 7);
62+
launch_kernel_pdl(dim3(grid_size), dim3(BLOCK_SIZE), stream, kernel_func, args, 7);
6363
}
6464
else if (num_tokens == 5)
6565
{
6666
// When num_tokens == 5, the best tiling size is tile_token == 1 and tile_out == 4.
6767
dim3 const grid_size = dim3(div_up(hidden_out, 4), div_up(num_tokens, 1), 1);
6868
void* kernel_func = get_per_block_func_ptr_aligned_true_5120_(1, 4);
69-
launch_kernel_fdl(dim3(grid_size), dim3(BLOCK_SIZE), stream, kernel_func, args, 7);
69+
launch_kernel_pdl(dim3(grid_size), dim3(BLOCK_SIZE), stream, kernel_func, args, 7);
7070
}
7171
else if (num_tokens == 6)
7272
{
7373
// When num_tokens == 6, the best tiling size is tile_token == 3 and tile_out == 4.
7474
dim3 const grid_size = dim3(div_up(hidden_out, 4), div_up(num_tokens, 3), 1);
7575
void* kernel_func = get_per_block_func_ptr_aligned_true_5120_(3, 4);
76-
launch_kernel_fdl(dim3(grid_size), dim3(BLOCK_SIZE), stream, kernel_func, args, 7);
76+
launch_kernel_pdl(dim3(grid_size), dim3(BLOCK_SIZE), stream, kernel_func, args, 7);
7777
}
7878
else if (num_tokens == 7)
7979
{
8080
// When num_tokens == 7, the best tiling size is tile_token == 1 and tile_out == 4.
8181
dim3 const grid_size = dim3(div_up(hidden_out, 4), div_up(num_tokens, 1), 1);
8282
void* kernel_func = get_per_block_func_ptr_aligned_true_5120_(1, 4);
83-
launch_kernel_fdl(dim3(grid_size), dim3(BLOCK_SIZE), stream, kernel_func, args, 7);
83+
launch_kernel_pdl(dim3(grid_size), dim3(BLOCK_SIZE), stream, kernel_func, args, 7);
8484
}
8585
else if (num_tokens == 8)
8686
{
8787
// When num_tokens == 8, the best tiling size is tile_token == 2 and tile_out == 4.
8888
dim3 const grid_size = dim3(div_up(hidden_out, 4), div_up(num_tokens, 2), 1);
8989
void* kernel_func = get_per_block_func_ptr_aligned_true_5120_(2, 4);
90-
launch_kernel_fdl(dim3(grid_size), dim3(BLOCK_SIZE), stream, kernel_func, args, 7);
90+
launch_kernel_pdl(dim3(grid_size), dim3(BLOCK_SIZE), stream, kernel_func, args, 7);
9191
}
9292
else
9393
{
@@ -115,56 +115,56 @@ void llama4_fp8_bf16_gemm_attn_scaling_launcher(void const* A, void const* B, vo
115115
// When num_tokens == 1, the best tiling size is tile_token == 1 and tile_out == 1.
116116
dim3 const grid_size = dim3(div_up(hidden_out, 1), div_up(num_tokens, 1), 1);
117117
void* kernel_func = get_kernel_func(1, 1, pos_ids_int64);
118-
launch_kernel_fdl(dim3(grid_size), dim3(BLOCK_SIZE), stream, kernel_func, args, 11);
118+
launch_kernel_pdl(dim3(grid_size), dim3(BLOCK_SIZE), stream, kernel_func, args, 11);
119119
}
120120
else if (num_tokens == 2)
121121
{
122122
// When num_tokens == 2, the best tiling size is tile_token == 2 and tile_out == 2.
123123
dim3 const grid_size = dim3(div_up(hidden_out, 2), div_up(num_tokens, 2), 1);
124124
void* kernel_func = get_kernel_func(2, 2, pos_ids_int64);
125-
launch_kernel_fdl(dim3(grid_size), dim3(BLOCK_SIZE), stream, kernel_func, args, 11);
125+
launch_kernel_pdl(dim3(grid_size), dim3(BLOCK_SIZE), stream, kernel_func, args, 11);
126126
}
127127
else if (num_tokens == 3)
128128
{
129129
// When num_tokens == 3, the best tiling size is tile_token == 1 and tile_out == 4.
130130
dim3 const grid_size = dim3(div_up(hidden_out, 4), div_up(num_tokens, 1), 1);
131131
void* kernel_func = get_kernel_func(1, 4, pos_ids_int64);
132-
launch_kernel_fdl(dim3(grid_size), dim3(BLOCK_SIZE), stream, kernel_func, args, 11);
132+
launch_kernel_pdl(dim3(grid_size), dim3(BLOCK_SIZE), stream, kernel_func, args, 11);
133133
}
134134
else if (num_tokens == 4)
135135
{
136136
// When num_tokens == 4, the best tiling size is tile_token == 2 and tile_out == 2.
137137
dim3 const grid_size = dim3(div_up(hidden_out, 2), div_up(num_tokens, 2), 1);
138138
void* kernel_func = get_kernel_func(2, 2, pos_ids_int64);
139-
launch_kernel_fdl(dim3(grid_size), dim3(BLOCK_SIZE), stream, kernel_func, args, 11);
139+
launch_kernel_pdl(dim3(grid_size), dim3(BLOCK_SIZE), stream, kernel_func, args, 11);
140140
}
141141
else if (num_tokens == 5)
142142
{
143143
// When num_tokens == 5, the best tiling size is tile_token == 1 and tile_out == 4.
144144
dim3 const grid_size = dim3(div_up(hidden_out, 4), div_up(num_tokens, 1), 1);
145145
void* kernel_func = get_kernel_func(1, 4, pos_ids_int64);
146-
launch_kernel_fdl(dim3(grid_size), dim3(BLOCK_SIZE), stream, kernel_func, args, 11);
146+
launch_kernel_pdl(dim3(grid_size), dim3(BLOCK_SIZE), stream, kernel_func, args, 11);
147147
}
148148
else if (num_tokens == 6)
149149
{
150150
// When num_tokens == 6, the best tiling size is tile_token == 2 and tile_out == 4.
151151
dim3 const grid_size = dim3(div_up(hidden_out, 4), div_up(num_tokens, 2), 1);
152152
void* kernel_func = get_kernel_func(2, 4, pos_ids_int64);
153-
launch_kernel_fdl(dim3(grid_size), dim3(BLOCK_SIZE), stream, kernel_func, args, 11);
153+
launch_kernel_pdl(dim3(grid_size), dim3(BLOCK_SIZE), stream, kernel_func, args, 11);
154154
}
155155
else if (num_tokens == 7)
156156
{
157157
// When num_tokens == 7, the best tiling size is tile_token == 1 and tile_out == 4.
158158
dim3 const grid_size = dim3(div_up(hidden_out, 4), div_up(num_tokens, 1), 1);
159159
void* kernel_func = get_kernel_func(1, 4, pos_ids_int64);
160-
launch_kernel_fdl(dim3(grid_size), dim3(BLOCK_SIZE), stream, kernel_func, args, 11);
160+
launch_kernel_pdl(dim3(grid_size), dim3(BLOCK_SIZE), stream, kernel_func, args, 11);
161161
}
162162
else if (num_tokens == 8)
163163
{
164164
// When num_tokens == 8, the best tiling size is tile_token == 2 and tile_out == 4.
165165
dim3 const grid_size = dim3(div_up(hidden_out, 4), div_up(num_tokens, 2), 1);
166166
void* kernel_func = get_kernel_func(2, 4, pos_ids_int64);
167-
launch_kernel_fdl(dim3(grid_size), dim3(BLOCK_SIZE), stream, kernel_func, args, 11);
167+
launch_kernel_pdl(dim3(grid_size), dim3(BLOCK_SIZE), stream, kernel_func, args, 11);
168168
}
169169
else
170170
{

cpp/tensorrt_llm/kernels/llama4MinLatencyKernels/llama4Fp8Fp8GemmSwiGLU.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ void dispatch_llama4_fp8_fp8_gemm_swiglu_hidden_in(void const* __restrict__ A, v
8080

8181
void* args[] = {(void*) &A, (void*) &B, (void*) &C, (void*) &in_scale, (void*) &out_scale_inv, (void*) &num_tokens,
8282
(void*) &hidden_in, (void*) &hidden_out};
83-
launch_kernel_fdl(grid_size, dim3(BLOCK_SIZE), stream, func_ptr, args, 8);
83+
launch_kernel_pdl(grid_size, dim3(BLOCK_SIZE), stream, func_ptr, args, 8);
8484
}
8585

8686
template <int TILE_TOKEN>

cpp/tensorrt_llm/kernels/llama4MinLatencyKernels/llama4MinLatencyMoEOp.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ namespace tensorrt_llm::kernels::llama4_min_latency::llama4_moe
3939
#define TOPK_VEC_SIZE 4
4040
static_assert(NUM_EXPERTS == TOPK_VEC_SIZE * WARP_SIZE, "NUM_EXPERTS must be equal to TOPK_VEC_SIZE * WARP_SIZE");
4141

42-
// This is the hand-optimized kernel by Po-Han.
42+
// This is the hand-optimized kernel.
4343
// The computation is:
4444
// C = silu(AxB_gated * in_scale * sigmoid(logit)) * (AxB_linear * in_scale * sigmoid(logit)) * out_scale_inv
4545
// The out_scale_inv cannot be fused with in_scale because silu() is non-linear.
@@ -213,10 +213,10 @@ void launch_llama4_moe_fc13_swiglu_fp8_kernel(int num_tokens, int num_experts,
213213

214214
void* args[] = {(void*) &num_tokens, (void*) &A, (void*) &B, (void*) &logits, (void*) &C, (void*) &exp_idx,
215215
(void*) &in_scales, (void*) &out_scale_inv};
216-
launch_kernel_fdl(dim3(grid_size), dim3(BLOCK_SIZE), stream, (void*) llama4_moe_fc13_swiglu_fp8_kernel, args, 8);
216+
launch_kernel_pdl(dim3(grid_size), dim3(BLOCK_SIZE), stream, (void*) llama4_moe_fc13_swiglu_fp8_kernel, args, 8);
217217
}
218218

219-
// This is the hand-optimized kernel by Po-Han.
219+
// This is the hand-optimized kernel.
220220
__global__ void llama4_moe_fc2_fp8_kernel(int num_tokens,
221221
__nv_fp8_e4m3 const* __restrict__ A, // Input tensor A [num_tokens][INTER_SIZE]
222222
__nv_fp8_e4m3 const* __restrict__ B, // Input tensor B [num_experts][HIDDEN_SIZE][INTER_SIZE]
@@ -329,7 +329,7 @@ void launch_llama4_moe_fc2_fp8_kernel(int num_tokens, int num_experts,
329329

330330
void* args[]
331331
= {(void*) &num_tokens, (void*) &A, (void*) &B, (void*) &exp_idx, (void*) &C, (void*) &scaling_factors};
332-
launch_kernel_fdl(dim3(grid_size), dim3(BLOCK_SIZE), stream, (void*) llama4_moe_fc2_fp8_kernel, args, 6);
332+
launch_kernel_pdl(dim3(grid_size), dim3(BLOCK_SIZE), stream, (void*) llama4_moe_fc2_fp8_kernel, args, 6);
333333
}
334334

335335
void run_moe_llama4_tp8ep1_min_latency(int num_tokens, int num_experts,

cpp/tensorrt_llm/kernels/llama4MinLatencyKernels/llama4Utils.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ constexpr bool ENABLE_PREEXIT = 0;
6666

6767
} // namespace llama4_fp8_fp8_gemm_swiglu
6868

69-
inline void launch_kernel_fdl(
69+
inline void launch_kernel_pdl(
7070
dim3 grid_dim, dim3 block_dim, cudaStream_t stream, void* kernel_func, void* args[], int num_args)
7171
{
7272
cudaLaunchConfig_t config;

cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.cu

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1774,7 +1774,7 @@ __global__ void __launch_bounds__(WarpSize) routingIndicesWarpKernel(KernelParam
17741774
params.mPtrNumNonExitingCtas[0] = numNonExitingCtas;
17751775
}
17761776

1777-
#if !defined(FDL_PROFILE) || FDL_PROFILE == 0
1777+
#if !defined(PDL_PROFILE) || PDL_PROFILE == 0
17781778
// we can trigger the next kernel at this point
17791779
if constexpr (KernelParams::UsePdl)
17801780
{
@@ -2059,7 +2059,7 @@ __global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(Nu
20592059
// We can't do it earlier because FC1 depends on the mPtrCtaIdxXyToBatchIdx,
20602060
// mPtrCtaIdxXyToMnLimit, mPtrNumNonExitingCtas and mPtrTotalNumPaddedTokens
20612061
// TODO: this is not sufficient to ensure visibility in the next kernel!
2062-
#if !defined(FDL_PROFILE) || FDL_PROFILE == 0
2062+
#if !defined(PDL_PROFILE) || PDL_PROFILE == 0
20632063
if constexpr (KernelParams::UsePdl)
20642064
{
20652065
cudaTriggerProgrammaticLaunchCompletion();
@@ -2517,7 +2517,7 @@ __global__ void __launch_bounds__(NumThreadsHist) routingIndicesOffsetsKernel(Ke
25172517
// Trigger secondary kernel.
25182518
// Note: this does not guarantee the visibility of prior writes unless the consumer executes a
25192519
// dependency sync.
2520-
#if !defined(FDL_PROFILE) || FDL_PROFILE == 0
2520+
#if !defined(PDL_PROFILE) || PDL_PROFILE == 0
25212521
if constexpr (KernelParams::UsePdl)
25222522
{
25232523
cudaTriggerProgrammaticLaunchCompletion();
@@ -3183,7 +3183,7 @@ __global__ void __cluster_dims__(NumBlocksPerCluster, 1, 1) __launch_bounds__(Nu
31833183
// We can't do it earlier because FC1 depends on the mPtrCtaIdxXyToBatchIdx,
31843184
// mPtrCtaIdxXyToMnLimit, mPtrNumNonExitingCtas and mPtrTotalNumPaddedTokens
31853185
// TODO: this is not sufficient to ensure visibility in the next kernel!
3186-
#if !defined(FDL_PROFILE) || FDL_PROFILE == 0
3186+
#if !defined(PDL_PROFILE) || PDL_PROFILE == 0
31873187
if constexpr (KernelParams::UsePdl)
31883188
{
31893189
cudaTriggerProgrammaticLaunchCompletion();
@@ -3665,7 +3665,7 @@ __global__ void __launch_bounds__(NumThreadsHist) routingIndicesOffsetsKernel(Ke
36653665
// Trigger secondary kernel.
36663666
// Note: this does not guarantee the visibility of prior writes unless the consumer executes a
36673667
// dependency sync.
3668-
#if !defined(FDL_PROFILE) || FDL_PROFILE == 0
3668+
#if !defined(PDL_PROFILE) || PDL_PROFILE == 0
36693669
if constexpr (KernelParams::UsePdl)
36703670
{
36713671
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))

0 commit comments

Comments
 (0)