@@ -115,8 +115,6 @@ struct LowLatencyLayerNorm

uint32_t work_id = blockIdx.x;

FusedOperator fused_operator(param);

constexpr auto PACKED_PER_N_BLOCK = Traits::N_BLOCK / N_THREADS / Traits::PACKED_ELEMS_PER_COMPUTE;

typename Traits::AccumulatorType data[PACKED_PER_N_BLOCK][Traits::PACKED_ELEMS_PER_COMPUTE];
@@ -139,7 +137,7 @@ struct LowLatencyLayerNorm
for (int i = 0; i < PACKED_PER_N_BLOCK; i++)
{
auto offset = (thread_id + i * N_THREADS) * Traits::PACKED_ELEMS_PER_COMPUTE;
if (offset <= sz)
if (offset < sz)
{
data[i] = *reinterpret_cast<PackedType const*>(&g_data[offset]);
}
@@ -155,6 +153,14 @@ struct LowLatencyLayerNorm

static_assert(Traits::OUTPUT_SCALE != SCALE_TYPE::VECTOR);

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12))
if constexpr (arch::is_major_v<9> || arch::is_major_v<10>)
{
cudaGridDependencySynchronize();
}
#endif
FusedOperator fused_operator(param);

if constexpr (Traits::BIAS == SCALE_TYPE::VECTOR)
{
load_to_register(param.bias, r_bias, param.n);
@@ -175,13 +181,6 @@ struct LowLatencyLayerNorm
load_to_register(param.beta, r_beta, param.n);
}

#if (defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 12))
if constexpr (arch::is_major_v<9> || arch::is_major_v<10>)
{
cudaGridDependencySynchronize();
cudaTriggerProgrammaticLaunchCompletion();
}
#endif
load_to_register(&param.input[work_id * param.n], data, param.n);

if constexpr (Traits::RESIDUAL)
@@ -259,12 +258,12 @@ struct LowLatencyLayerNorm
if constexpr (!Traits::RMS_NORM)
{
mean = var_and_mean[1] / param.n;
variance = rsqrtf(
var_and_mean[0] / param.n - var_and_mean[1] * var_and_mean[1] + (Traits::AccumulatorType)(1e-5));
variance = rsqrtf(var_and_mean[0] / param.n - var_and_mean[1] * var_and_mean[1]
+ (Traits::AccumulatorType)(param.layernorm_eps));
}
else
{
variance = rsqrtf(var_and_mean[0] / param.n + (Traits::AccumulatorType)(1e-5));
variance = rsqrtf(var_and_mean[0] / param.n + (Traits::AccumulatorType)(param.layernorm_eps));
}

for (int i = 0; i < PACKED_PER_N_BLOCK; i++)
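For reference, a minimal sketch of the quantity these hunks now compute with a configurable epsilon (param.layernorm_eps) instead of the hard-coded 1e-5. Names are illustrative and assume the accumulators have already been reduced to per-row moments E[x] and E[x^2]:

// Hedged sketch, not part of this file.
// LayerNorm: 1 / sqrt(E[x^2] - E[x]^2 + eps)
// RMSNorm:   1 / sqrt(E[x^2] + eps)
__device__ inline float inv_std_layernorm(float mean_x, float mean_x2, float eps)
{
    return rsqrtf(mean_x2 - mean_x * mean_x + eps);
}

__device__ inline float inv_std_rmsnorm(float mean_x2, float eps)
{
    return rsqrtf(mean_x2 + eps);
}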
@@ -333,6 +332,14 @@ struct LowLatencyLayerNorm
{
__shared__ Shared shared;
compute(param, &shared);
__syncthreads();
asm volatile("membar.gl;" : : : "memory");
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12))
if constexpr (arch::is_major_v<9> || arch::is_major_v<10>)
{
cudaTriggerProgrammaticLaunchCompletion();
}
#endif
}
};
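Since this file now reorders the programmatic dependent launch (PDL) calls (grid-dependency sync before any read of upstream outputs; the launch-completion trigger only after the final stores are fenced), here is a minimal stand-alone sketch of that pattern. Kernel and buffer names are illustrative, not from this repository, and the sketch assumes sm_90+ with CUDA 12+.

// Hedged sketch of the PDL ordering used above (illustrative names).
__global__ void pdl_consumer_sketch(float const* in, float* out, int n)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12))
    // Block until the upstream grid's global writes are visible to this kernel.
    cudaGridDependencySynchronize();
#endif
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
    {
        out[i] = in[i] * 2.0f; // stand-in for the real math
    }
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12))
    // Fence this block's global writes, then allow the dependent grid to start launching.
    __syncthreads();
    asm volatile("membar.gl;" : : : "memory");
    cudaTriggerProgrammaticLaunchCompletion();
#endif
}

On the host side, the downstream kernel would typically be launched with the programmatic stream serialization launch attribute so it can overlap with the tail of this one; without it, ordinary stream ordering still applies.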

71 changes: 46 additions & 25 deletions cpp/tensorrt_llm/kernels/fusedLayernormKernels/ws_layernorm.cuh
@@ -201,25 +201,35 @@ struct WarpSpecializedLayerNorm
}
// if (blockIdx.x == 0) printf("Pushed tile %d to MATH.\n", m_base);

if constexpr (FIRST_RUN)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12))
if constexpr (arch::is_major_v<9> || arch::is_major_v<10>)
{
// Ensure upstream kernel writes are visible before reading dependent activation/residual data.
cudaGridDependencySynchronize();
}
#endif
}
const uint32_t eff_m_block
= std::min(static_cast<uint32_t>(Traits::M_BLOCK), static_cast<uint32_t>(param.m - m_base));
const auto tx
= (Traits::M_BLOCK * param.n * sizeof(typename Traits::InputType) * (Traits::RESIDUAL ? 2 : 1))
+ (FIRST_RUN ? sizeof(AuxData) / Traits::N_BLOCK * param.n : 0);
= (eff_m_block * param.n * sizeof(typename Traits::InputType) * (Traits::RESIDUAL ? 2 : 1))
+ (FIRST_RUN ? (sizeof(AuxData) / Traits::N_BLOCK * param.n) : 0);

auto vec_buffer_ptr = input_vec_fifo_w.tmaReserve(tx);

// if (blockIdx.x == 0) printf("SMEM buffer ready, start loading tile %d.\n", m_base);

if constexpr (FIRST_RUN)
{
cudaGridDependencySynchronize();
}

for (int i = 0; i < Traits::M_BLOCK; i++)
{
load_a_vec(&param.input[(m_base + i) * param.n],
__nvvm_get_smem_pointer(&shared->input_vec[vec_buffer_ptr][0][i * Traits::N_BLOCK]),
param.n * sizeof(typename Traits::InputType),
__nvvm_get_smem_pointer(input_vec_fifo_w.barrier_ptr(vec_buffer_ptr)));
if (i < eff_m_block) [[likely]]
{
load_a_vec(&param.input[(m_base + i) * param.n],
__nvvm_get_smem_pointer(&shared->input_vec[vec_buffer_ptr][0][i * Traits::N_BLOCK]),
param.n * sizeof(typename Traits::InputType),
__nvvm_get_smem_pointer(input_vec_fifo_w.barrier_ptr(vec_buffer_ptr)));
}
}

// Use templated lambdas to defer resolving the symbols like "param.residual".
Expand All @@ -231,10 +241,13 @@ struct WarpSpecializedLayerNorm
{
for (int i = 0; i < Traits::M_BLOCK; i++)
{
load_a_vec(&param.residual[(m_base + i) * param.n],
__nvvm_get_smem_pointer(&shared->input_vec[vec_buffer_ptr][1][i * Traits::N_BLOCK]),
param.n * sizeof(typename Traits::InputType),
__nvvm_get_smem_pointer(input_vec_fifo_w.barrier_ptr(vec_buffer_ptr)));
if (i < eff_m_block) [[likely]]
{
load_a_vec(&param.residual[(m_base + i) * param.n],
__nvvm_get_smem_pointer(&shared->input_vec[vec_buffer_ptr][1][i * Traits::N_BLOCK]),
param.n * sizeof(typename Traits::InputType),
__nvvm_get_smem_pointer(input_vec_fifo_w.barrier_ptr(vec_buffer_ptr)));
}
}
}(param);
}
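To illustrate the partial-tile handling these hunks add (clamping the last tile's row count and sizing the expected TMA transaction bytes accordingly), a small host-side sketch with made-up sizes; the real values come from Traits and the kernel parameters:

// Hedged sketch, not part of this file.
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main()
{
    uint32_t const M_BLOCK = 8;    // rows per tile (illustrative)
    uint32_t const n = 4096;       // elements per row
    uint32_t const elem_bytes = 2; // e.g. 16-bit inputs
    uint32_t const m = 21;         // total rows; the last tile is partial

    for (uint32_t m_base = 0; m_base < m; m_base += M_BLOCK)
    {
        uint32_t const eff_m_block = std::min(M_BLOCK, m - m_base);
        // Expected transaction bytes for this tile, ignoring residual and aux data.
        uint32_t const tx = eff_m_block * n * elem_bytes;
        std::printf("m_base=%2u eff_m_block=%u tx=%u bytes\n", m_base, eff_m_block, tx);
    }
    return 0;
}

With m = 21 and M_BLOCK = 8, the last tile covers rows 16..20, so eff_m_block is 5 and only those rows are loaded and stored.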
@@ -423,6 +436,13 @@ struct WarpSpecializedLayerNorm

using FusedOperator = GetFusedOperator<typename Traits::FusedOperator>;

#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12))
if constexpr (arch::is_major_v<9> || arch::is_major_v<10>)
{
// Ensure upstream kernel writes are visible before reading dependent activation/residual data.
cudaGridDependencySynchronize();
}
#endif
FusedOperator fused_operator(param);

static_assert(Traits::PERSISTENT_MODE || Traits::MATH_WARPGROUPS == 1);
@@ -446,6 +466,9 @@
{
m_base = block_id;
}
const uint32_t eff_m_block
= std::min(static_cast<uint32_t>(Traits::M_BLOCK), static_cast<uint32_t>(param.m - m_base));

// if (blockIdx.x == 0 && thread_id == 0) printf("MATH got tile %d.\n", m_base);

// Peek for data ready.
@@ -613,11 +636,12 @@ struct WarpSpecializedLayerNorm
{
mean[m_offset] /= param.n;
variance[m_offset] = rsqrtf(variance[m_offset] / param.n - mean[m_offset] * mean[m_offset]
+ (Traits::AccumulatorType)(1e-5));
+ (Traits::AccumulatorType)(param.layernorm_eps));
}
else
{
variance[m_offset] = rsqrtf(variance[m_offset] / param.n + (Traits::AccumulatorType)(1e-5));
variance[m_offset]
= rsqrtf(variance[m_offset] / param.n + (Traits::AccumulatorType)(param.layernorm_eps));
}
}

@@ -659,8 +683,7 @@ struct WarpSpecializedLayerNorm
}
}

#pragma unroll Traits::M_BLOCK
for (int m_offset = 0; m_offset < Traits::M_BLOCK; m_offset++)
for (int m_offset = 0; m_offset < eff_m_block; m_offset++)
{
auto m = m_base + m_offset;

@@ -801,23 +824,19 @@ struct WarpSpecializedLayerNorm
shared->init(threadIdx.x == 0);

__syncthreads();
#if (defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 12))
#if (defined(__CUDA_ARCH_FEAT_SM90_ALL) || defined(__CUDA_ARCH_FEAT_SM100_ALL))
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12)
if constexpr (arch::is_major_v<9> || arch::is_major_v<10>)
{
auto block_id = blockIdx.x;
auto warp_id = threadIdx.x / 32;
auto lane_id = threadIdx.x % 32;
auto tid_in_wg = threadIdx.x % 128;

if (warp_id < 4)
{
asm volatile("{setmaxnreg.dec.sync.aligned.u32 56; \n\t}");
if (warp_id == 0)
{
scheduler(lane_id, gridDim.x * gridDim.y * gridDim.z, param, shared);
// PRE-EXIT after all tiles have been scheduled.
cudaTriggerProgrammaticLaunchCompletion();
}
else if (warp_id == 1)
{
Expand All @@ -829,8 +848,10 @@ struct WarpSpecializedLayerNorm
asm volatile("{setmaxnreg.inc.sync.aligned.u32 224; \n\t}");
compute(block_id, threadIdx.x / 128 - 1, tid_in_wg, param, shared);
}
__syncthreads();
asm volatile("membar.gl;" : : : "memory");
cudaTriggerProgrammaticLaunchCompletion();
}
#endif
#endif
}
};
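For orientation, a hedged sketch of the warp-specialized role split this entry point uses: one control warpgroup shrinks its register budget with setmaxnreg and splits into a scheduler warp and a loader warp, the remaining warpgroups grow theirs and run the math, and all threads fence their global writes before signalling programmatic launch completion. The role functions are placeholders, and the asm assumes an sm_90a (or newer) build.

// Hedged sketch of the warp-role layout (placeholder role bodies).
__global__ void ws_roles_sketch()
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12))
    unsigned const warp_id = threadIdx.x / 32;
    if (warp_id < 4) // control warpgroup
    {
        asm volatile("{setmaxnreg.dec.sync.aligned.u32 56; \n\t}");
        if (warp_id == 0)
        {
            // scheduler: hand out tile indices to the math warpgroups
        }
        else if (warp_id == 1)
        {
            // loader: issue TMA copies into the shared-memory FIFO
        }
    }
    else // math warpgroups
    {
        asm volatile("{setmaxnreg.inc.sync.aligned.u32 224; \n\t}");
        // compute: consume tiles from shared memory and write results
    }
    // Fence all global writes from this block, then signal launch completion.
    __syncthreads();
    asm volatile("membar.gl;" : : : "memory");
    cudaTriggerProgrammaticLaunchCompletion();
#endif
}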
1 change: 1 addition & 0 deletions cpp/tensorrt_llm/thop/CMakeLists.txt
@@ -66,6 +66,7 @@ add_library(
fp8Quantize.cpp
dsv3FusedAGemmOp.cpp
fusedQKNormRopeOp.cpp
fusedAddRMSNormQuant.cpp
fusedTopkSoftmax.cpp
gatherTreeOp.cpp
groupRmsNormOp.cpp