Commit acff6ae

update torch_ext API and debugging test for FusedAddRMSNorm
update #define for Hopper & Blackwell

Signed-off-by: JtaoPeng <jintaop@nvidia.com>

1 parent 4c498bf

File tree: 8 files changed, +525 -48 lines

cpp/tensorrt_llm/kernels/fusedLayernormKernels/low_latency_layernorm.cuh

Lines changed: 18 additions & 12 deletions
@@ -115,8 +115,6 @@ struct LowLatencyLayerNorm
 
         uint32_t work_id = blockIdx.x;
 
-        FusedOperator fused_operator(param);
-
         constexpr auto PACKED_PER_N_BLOCK = Traits::N_BLOCK / N_THREADS / Traits::PACKED_ELEMS_PER_COMPUTE;
 
         typename Traits::AccumulatorType data[PACKED_PER_N_BLOCK][Traits::PACKED_ELEMS_PER_COMPUTE];
@@ -139,7 +137,7 @@ struct LowLatencyLayerNorm
         for (int i = 0; i < PACKED_PER_N_BLOCK; i++)
         {
             auto offset = (thread_id + i * N_THREADS) * Traits::PACKED_ELEMS_PER_COMPUTE;
-            if (offset <= sz)
+            if (offset < sz)
             {
                 data[i] = *reinterpret_cast<PackedType const*>(&g_data[offset]);
             }
@@ -155,6 +153,17 @@ struct LowLatencyLayerNorm
 
         static_assert(Traits::OUTPUT_SCALE != SCALE_TYPE::VECTOR);
 
+
+#if (defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 12))
+        if constexpr (arch::is_major_v<9> || arch::is_major_v<10>)
+        {
+            // Ensure upstream kernel writes are visible before reading dependent activation/residual data.
+            cudaGridDependencySynchronize();
+            // cudaTriggerProgrammaticLaunchCompletion();
+        }
+#endif
+        FusedOperator fused_operator(param);
+
         if constexpr (Traits::BIAS == SCALE_TYPE::VECTOR)
         {
             load_to_register(param.bias, r_bias, param.n);
@@ -175,13 +184,6 @@ struct LowLatencyLayerNorm
             load_to_register(param.beta, r_beta, param.n);
         }
 
-#if (defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 12))
-        if constexpr (arch::is_major_v<9> || arch::is_major_v<10>)
-        {
-            cudaGridDependencySynchronize();
-            cudaTriggerProgrammaticLaunchCompletion();
-        }
-#endif
         load_to_register(&param.input[work_id * param.n], data, param.n);
 
         if constexpr (Traits::RESIDUAL)
@@ -260,11 +262,11 @@ struct LowLatencyLayerNorm
         {
             mean = var_and_mean[1] / param.n;
             variance = rsqrtf(
-                var_and_mean[0] / param.n - var_and_mean[1] * var_and_mean[1] + (Traits::AccumulatorType)(1e-5));
+                var_and_mean[0] / param.n - var_and_mean[1] * var_and_mean[1] + (Traits::AccumulatorType)(param.layernorm_eps));
         }
         else
         {
-            variance = rsqrtf(var_and_mean[0] / param.n + (Traits::AccumulatorType)(1e-5));
+            variance = rsqrtf(var_and_mean[0] / param.n + (Traits::AccumulatorType)(param.layernorm_eps));
         }
 
         for (int i = 0; i < PACKED_PER_N_BLOCK; i++)
@@ -332,7 +334,11 @@ struct LowLatencyLayerNorm
     static __device__ void run(const Param param)
     {
         __shared__ Shared shared;
+        // cudaGridDependencySynchronize();
         compute(param, &shared);
+        __syncthreads();
+        asm volatile("membar.gl;" : : : "memory");
+        cudaTriggerProgrammaticLaunchCompletion();
     }
 };
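The relocated cudaGridDependencySynchronize() / cudaTriggerProgrammaticLaunchCompletion() calls are the device side of CUDA Programmatic Dependent Launch (PDL): a dependent grid may start before the upstream grid finishes, so it must block before its first read of upstream data, and it should signal completion only after its own global writes are ordered (hence the __syncthreads() and membar.gl before the trigger in run()). These calls only take effect when the host launches the dependent kernel with programmatic stream serialization enabled. A minimal standalone sketch of that contract, assuming an sm_90 (Hopper) target and CUDA 12; the producer/consumer kernels and the launch_pair helper are illustrative, not code from this repository:

#include <cuda_runtime.h>

// Illustrative producer/consumer pair, not repository code.
__global__ void producer(float* out, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
    {
        out[i] = 1.0f;
    }
    // Let the dependent grid start its prologue early (sm_90+).
    cudaTriggerProgrammaticLaunchCompletion();
}

__global__ void consumer(float const* in, float* out, int n)
{
    // Prologue work that does NOT read `in` could run here.
    // Block until all of producer's global-memory writes are visible.
    cudaGridDependencySynchronize();
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
    {
        out[i] = in[i] * 2.0f;
    }
}

// Launch the consumer so its start may overlap the producer's epilogue.
void launch_pair(float* a, float* b, int n, cudaStream_t stream)
{
    int const block = 256;
    int const grid = (n + block - 1) / block;
    producer<<<grid, block, 0, stream>>>(a, n);

    cudaLaunchAttribute attr{};
    attr.id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attr.val.programmaticStreamSerializationAllowed = 1;

    cudaLaunchConfig_t cfg{};
    cfg.gridDim = grid;
    cfg.blockDim = block;
    cfg.stream = stream;
    cfg.attrs = &attr;
    cfg.numAttrs = 1;
    cudaLaunchKernelEx(&cfg, consumer, static_cast<float const*>(a), b, n);
}

In this commit the synchronize moves ahead of the FusedOperator construction (which, per the added comment, may read data produced upstream), while the trigger moves from the prologue to the end of run(), after the writes are fenced, so the next kernel in the stream cannot observe partial output.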

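Separately, the guard change from offset <= sz to offset < sz fixes an off-by-one: offsets advance by the packed width, so offset == sz named a packed load starting one element past the end of the row. A standalone illustration of the boundary case, assuming sz counts elements and is a multiple of the pack width:

#include <cstdio>

int main()
{
    int const sz = 8;   // elements in a row (assumed pack-aligned)
    int const PACK = 4; // stand-in for Traits::PACKED_ELEMS_PER_COMPUTE
    for (int offset = 0; offset <= sz; offset += PACK)
    {
        // old guard admits offset == 8, i.e. a 4-wide load that begins
        // one element past the end of the row; the new guard rejects it
        bool old_guard = (offset <= sz);
        bool new_guard = (offset < sz);
        std::printf("offset %d: old=%d new=%d\n", offset, old_guard, new_guard);
    }
    return 0;
}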
cpp/tensorrt_llm/kernels/fusedLayernormKernels/ws_layernorm.cuh

Lines changed: 55 additions & 24 deletions
@@ -139,6 +139,12 @@ struct WarpSpecializedLayerNorm
                 scheduled_tiles++;
                 // if (blockIdx.x == 0) printf("Pushed tile %d to DMA.\n", tile_id);
             }
+            // #if (defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 12))
+            // if constexpr (arch::is_major_v<9> || arch::is_major_v<10>)
+            // {
+            //     cudaTriggerProgrammaticLaunchCompletion();
+            // }
+            // #endif
             sched2dma_w.push(0xffffffff);
             // if (blockIdx.x == 0) printf("Pushed tile -1 to DMA.\n");
             if (atomicAdd(&(param.counters->cta_completion_ctr), 1) == grid_sz - 1)
@@ -151,6 +157,12 @@ struct WarpSpecializedLayerNorm
         else
         {
             scheduled_tiles = 1;
+            // #if (defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 12))
+            // if constexpr (arch::is_major_v<9> || arch::is_major_v<10>)
+            // {
+            //     cudaTriggerProgrammaticLaunchCompletion();
+            // }
+            // #endif
         }
         return scheduled_tiles;
     }
@@ -201,25 +213,30 @@ struct WarpSpecializedLayerNorm
             }
             // if (blockIdx.x == 0) printf("Pushed tile %d to MATH.\n", m_base);
 
+            if constexpr (FIRST_RUN)
+            {
+                cudaGridDependencySynchronize();
+            }
+            const uint32_t eff_m_block
+                = std::min(static_cast<uint32_t>(Traits::M_BLOCK), static_cast<uint32_t>(param.m - m_base));
             const auto tx
-                = (Traits::M_BLOCK * param.n * sizeof(typename Traits::InputType) * (Traits::RESIDUAL ? 2 : 1))
-                + (FIRST_RUN ? sizeof(AuxData) / Traits::N_BLOCK * param.n : 0);
+                = (eff_m_block * param.n * sizeof(typename Traits::InputType) * (Traits::RESIDUAL ? 2 : 1))
+                + (FIRST_RUN ? (sizeof(AuxData) / Traits::N_BLOCK * param.n) : 0);
 
             auto vec_buffer_ptr = input_vec_fifo_w.tmaReserve(tx);
 
             // if (blockIdx.x == 0) printf("SMEM buffer ready, start loading tile %d.\n", m_base);
 
-            if constexpr (FIRST_RUN)
-            {
-                cudaGridDependencySynchronize();
-            }
 
             for (int i = 0; i < Traits::M_BLOCK; i++)
             {
-                load_a_vec(&param.input[(m_base + i) * param.n],
-                    __nvvm_get_smem_pointer(&shared->input_vec[vec_buffer_ptr][0][i * Traits::N_BLOCK]),
-                    param.n * sizeof(typename Traits::InputType),
-                    __nvvm_get_smem_pointer(input_vec_fifo_w.barrier_ptr(vec_buffer_ptr)));
+                if (i < eff_m_block) [[likely]]
+                {
+                    load_a_vec(&param.input[(m_base + i) * param.n],
+                        __nvvm_get_smem_pointer(&shared->input_vec[vec_buffer_ptr][0][i * Traits::N_BLOCK]),
+                        param.n * sizeof(typename Traits::InputType),
+                        __nvvm_get_smem_pointer(input_vec_fifo_w.barrier_ptr(vec_buffer_ptr)));
+                }
             }
 
             // Use templated lambdas to defer resolving the symbols like "param.residual".
@@ -231,10 +248,13 @@ struct WarpSpecializedLayerNorm
             {
                 for (int i = 0; i < Traits::M_BLOCK; i++)
                 {
-                    load_a_vec(&param.residual[(m_base + i) * param.n],
-                        __nvvm_get_smem_pointer(&shared->input_vec[vec_buffer_ptr][1][i * Traits::N_BLOCK]),
-                        param.n * sizeof(typename Traits::InputType),
-                        __nvvm_get_smem_pointer(input_vec_fifo_w.barrier_ptr(vec_buffer_ptr)));
+                    if (i < eff_m_block) [[likely]]
+                    {
+                        load_a_vec(&param.residual[(m_base + i) * param.n],
+                            __nvvm_get_smem_pointer(&shared->input_vec[vec_buffer_ptr][1][i * Traits::N_BLOCK]),
+                            param.n * sizeof(typename Traits::InputType),
+                            __nvvm_get_smem_pointer(input_vec_fifo_w.barrier_ptr(vec_buffer_ptr)));
+                    }
                 }
             }(param);
         }
@@ -423,6 +443,13 @@ struct WarpSpecializedLayerNorm
 
         using FusedOperator = GetFusedOperator<typename Traits::FusedOperator>;
 
+#if (defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 12))
+        if constexpr (arch::is_major_v<9> || arch::is_major_v<10>)
+        {
+            // Ensure upstream kernel writes are visible before reading dependent activation/residual data.
+            cudaGridDependencySynchronize();
+        }
+#endif
         FusedOperator fused_operator(param);
 
         static_assert(Traits::PERSISTENT_MODE || Traits::MATH_WARPGROUPS == 1);
@@ -446,6 +473,9 @@ struct WarpSpecializedLayerNorm
         {
             m_base = block_id;
         }
+        const uint32_t eff_m_block
+            = std::min(static_cast<uint32_t>(Traits::M_BLOCK), static_cast<uint32_t>(param.m - m_base));
+
         // if (blockIdx.x == 0 && thread_id == 0) printf("MATH got tile %d.\n", m_base);
 
         // Peek for data ready.
@@ -613,11 +643,11 @@ struct WarpSpecializedLayerNorm
                 {
                     mean[m_offset] /= param.n;
                     variance[m_offset] = rsqrtf(variance[m_offset] / param.n - mean[m_offset] * mean[m_offset]
-                        + (Traits::AccumulatorType)(1e-5));
+                        + (Traits::AccumulatorType)(param.layernorm_eps));
                 }
                 else
                 {
-                    variance[m_offset] = rsqrtf(variance[m_offset] / param.n + (Traits::AccumulatorType)(1e-5));
+                    variance[m_offset] = rsqrtf(variance[m_offset] / param.n + (Traits::AccumulatorType)(param.layernorm_eps));
                 }
             }
 
@@ -659,8 +689,7 @@ struct WarpSpecializedLayerNorm
             }
         }
 
-#pragma unroll Traits::M_BLOCK
-        for (int m_offset = 0; m_offset < Traits::M_BLOCK; m_offset++)
+        for (int m_offset = 0; m_offset < eff_m_block; m_offset++)
         {
             auto m = m_base + m_offset;
 
@@ -801,23 +830,22 @@ struct WarpSpecializedLayerNorm
         shared->init(threadIdx.x == 0);
 
         __syncthreads();
-#if (defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 12))
-#if (defined(__CUDA_ARCH_FEAT_SM90_ALL) || defined(__CUDA_ARCH_FEAT_SM100_ALL))
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12)
         if constexpr (arch::is_major_v<9> || arch::is_major_v<10>)
         {
             auto block_id = blockIdx.x;
             auto warp_id = threadIdx.x / 32;
             auto lane_id = threadIdx.x % 32;
             auto tid_in_wg = threadIdx.x % 128;
-
+            // cudaGridDependencySynchronize();
             if (warp_id < 4)
             {
                 asm volatile("{setmaxnreg.dec.sync.aligned.u32 56; \n\t}");
                 if (warp_id == 0)
                 {
                     scheduler(lane_id, gridDim.x * gridDim.y * gridDim.z, param, shared);
-                    // PRE-EXIT after all tiles have been scheduled.
-                    cudaTriggerProgrammaticLaunchCompletion();
+                    // PRE-EXIT after all tiles have been scheduled.
+                    // cudaTriggerProgrammaticLaunchCompletion();
                 }
                 else if (warp_id == 1)
                 {
@@ -829,8 +857,11 @@ struct WarpSpecializedLayerNorm
                 asm volatile("{setmaxnreg.inc.sync.aligned.u32 224; \n\t}");
                 compute(block_id, threadIdx.x / 128 - 1, tid_in_wg, param, shared);
             }
+            __syncthreads();
+            asm volatile("membar.gl;" : : : "memory");
+            cudaTriggerProgrammaticLaunchCompletion();
+            // cudaTriggerProgrammaticLaunchCompletion();
         }
-#endif
 #endif
     }
 };
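The new eff_m_block clamp handles ragged tails: when param.m is not a multiple of Traits::M_BLOCK, the final tile holds fewer than M_BLOCK rows, yet the DMA path previously sized its TMA transaction (tx) and issued row loads for a full tile. A standalone sketch of the clamp arithmetic, with made-up sizes (M_BLOCK = 4, m = 10):

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main()
{
    constexpr uint32_t M_BLOCK = 4; // rows per tile (illustrative value)
    uint32_t const m = 10;          // total rows; not a multiple of M_BLOCK
    for (uint32_t m_base = 0; m_base < m; m_base += M_BLOCK)
    {
        // Same clamp as the diff: the last tile only has m - m_base rows,
        // so TMA byte counts and row loops must not assume a full M_BLOCK.
        uint32_t eff_m_block = std::min(M_BLOCK, m - m_base);
        std::printf("tile @ row %u: %u effective rows\n", m_base, eff_m_block);
    }
    return 0;
}

The math path applies the same clamp, iterating m_offset only up to eff_m_block; dropping the #pragma unroll hint follows from the trip count now being dynamic.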

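The 1e-5 to param.layernorm_eps change in both kernels threads the normalization epsilon through the parameter struct instead of hard-coding it, so callers can match whatever epsilon the model was trained with (RMSNorm implementations commonly use 1e-6, for example). A minimal one-thread-per-row reference kernel showing the plumbing, with a simplified stand-in Param rather than the repository's types:

#include <cuda_runtime.h>

// Simplified stand-ins for the kernel's parameter/accumulator types.
struct Param
{
    float const* input;  // [m, n]
    float* output;       // [m, n]
    int n;
    float layernorm_eps; // was a hard-coded 1e-5f before this commit
};

// One thread per row; a reference for the epsilon plumbing only.
__global__ void rmsnorm_rows(Param p, int m)
{
    int row = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= m)
    {
        return;
    }
    float const* x = p.input + row * p.n;
    float* y = p.output + row * p.n;
    float sum_sq = 0.f;
    for (int i = 0; i < p.n; ++i)
    {
        sum_sq += x[i] * x[i];
    }
    // Epsilon now comes from Param instead of a compile-time constant.
    float inv_rms = rsqrtf(sum_sq / p.n + p.layernorm_eps);
    for (int i = 0; i < p.n; ++i)
    {
        y[i] = x[i] * inv_rms;
    }
}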
cpp/tensorrt_llm/thop/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -66,6 +66,7 @@ add_library(
   fp8Quantize.cpp
   dsv3FusedAGemmOp.cpp
   fusedQKNormRopeOp.cpp
+  fusedAddRMSNormQuant.cpp
   fusedTopkSoftmax.cpp
   gatherTreeOp.cpp
   groupRmsNormOp.cpp
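fusedAddRMSNormQuant.cpp is the new torch_ext binding behind the "update torch_ext API" part of the commit message; its contents are not shown in this diff. The sketch below is a hypothetical illustration of how such an op is typically registered, with an assumed schema, an assumed op name, and an unfused ATen reference body standing in for the real CUDA dispatch (the quantized-output half is omitted):

#include <ATen/ATen.h>
#include <torch/library.h>
#include <tuple>

namespace
{
// Hypothetical signature: fused residual-add + RMSNorm (+ quantization).
std::tuple<at::Tensor, at::Tensor> fused_add_rmsnorm_quant(
    at::Tensor input, at::Tensor residual, at::Tensor gamma, double eps)
{
    // Reference composition only; the real op dispatches the fused kernel
    // and would also produce a quantized output.
    auto x = input + residual;
    auto inv_rms = (x * x).mean({-1}, /*keepdim=*/true).add_(eps).rsqrt_();
    auto y = x * inv_rms * gamma;
    return {y, x}; // (normalized output, updated residual): assumed contract
}
} // namespace

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def(
        "fused_add_rmsnorm_quant(Tensor input, Tensor residual, Tensor gamma, float eps)"
        " -> (Tensor, Tensor)");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("fused_add_rmsnorm_quant", &fused_add_rmsnorm_quant);
}

Note how the eps argument in the schema lines up with the param.layernorm_eps plumbing added to the kernels above.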
