
Commit 86e45f8

update torch_ext API and debugging test for FusedAddRMSNorm

update #define for hopper & blackwell

Signed-off-by: JtaoPeng <[email protected]>

1 parent: dc32bac

8 files changed, +476 −31 lines changed


cpp/tensorrt_llm/kernels/fusedLayernormKernels/low_latency_layernorm.cuh

Lines changed: 3 additions & 3 deletions
@@ -139,7 +139,7 @@ struct LowLatencyLayerNorm
     for (int i = 0; i < PACKED_PER_N_BLOCK; i++)
     {
         auto offset = (thread_id + i * N_THREADS) * Traits::PACKED_ELEMS_PER_COMPUTE;
-        if (offset <= sz)
+        if (offset < sz)
         {
             data[i] = *reinterpret_cast<PackedType const*>(&g_data[offset]);
         }
@@ -260,11 +260,11 @@ struct LowLatencyLayerNorm
     {
         mean = var_and_mean[1] / param.n;
         variance = rsqrtf(
-            var_and_mean[0] / param.n - var_and_mean[1] * var_and_mean[1] + (Traits::AccumulatorType)(1e-5));
+            var_and_mean[0] / param.n - var_and_mean[1] * var_and_mean[1] + (Traits::AccumulatorType)(param.layernorm_eps));
     }
     else
     {
-        variance = rsqrtf(var_and_mean[0] / param.n + (Traits::AccumulatorType)(1e-5));
+        variance = rsqrtf(var_and_mean[0] / param.n + (Traits::AccumulatorType)(param.layernorm_eps));
     }

     for (int i = 0; i < PACKED_PER_N_BLOCK; i++)
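The two rsqrtf branches above are the LayerNorm path (mean subtracted) and the RMSNorm path (mean skipped); the change threads param.layernorm_eps through instead of the hard-coded 1e-5, and tightens the packed-load bound to offset < sz. A small Python reference of the corresponding textbook formulas, handy when debugging the eps plumbing (hypothetical helper, not part of this commit; the kernel's reduction and accumulator details differ):

import torch

def reference_inv_stddev(x: torch.Tensor, eps: float, rms_norm: bool) -> torch.Tensor:
    # Inverse standard deviation over the hidden dimension, with configurable eps.
    x = x.float()
    mean_sq = x.pow(2).mean(dim=-1)                     # E[x^2]
    if rms_norm:
        return torch.rsqrt(mean_sq + eps)               # RMSNorm: no mean subtraction
    mean = x.mean(dim=-1)
    return torch.rsqrt(mean_sq - mean * mean + eps)     # LayerNorm: E[x^2] - E[x]^2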

cpp/tensorrt_llm/kernels/fusedLayernormKernels/ws_layernorm.cuh

Lines changed: 25 additions & 16 deletions
@@ -201,9 +201,11 @@ struct WarpSpecializedLayerNorm
         }
         // if (blockIdx.x == 0) printf("Pushed tile %d to MATH.\n", m_base);

+        const uint32_t eff_m_block
+            = std::min(static_cast<uint32_t>(Traits::M_BLOCK), static_cast<uint32_t>(param.m - m_base));
         const auto tx
-            = (Traits::M_BLOCK * param.n * sizeof(typename Traits::InputType) * (Traits::RESIDUAL ? 2 : 1))
-            + (FIRST_RUN ? sizeof(AuxData) / Traits::N_BLOCK * param.n : 0);
+            = (eff_m_block * param.n * sizeof(typename Traits::InputType) * (Traits::RESIDUAL ? 2 : 1))
+            + (FIRST_RUN ? (sizeof(AuxData) / Traits::N_BLOCK * param.n) : 0);

         auto vec_buffer_ptr = input_vec_fifo_w.tmaReserve(tx);

@@ -216,10 +218,13 @@ struct WarpSpecializedLayerNorm

         for (int i = 0; i < Traits::M_BLOCK; i++)
         {
-            load_a_vec(&param.input[(m_base + i) * param.n],
-                __nvvm_get_smem_pointer(&shared->input_vec[vec_buffer_ptr][0][i * Traits::N_BLOCK]),
-                param.n * sizeof(typename Traits::InputType),
-                __nvvm_get_smem_pointer(input_vec_fifo_w.barrier_ptr(vec_buffer_ptr)));
+            if (i < eff_m_block) [[likely]]
+            {
+                load_a_vec(&param.input[(m_base + i) * param.n],
+                    __nvvm_get_smem_pointer(&shared->input_vec[vec_buffer_ptr][0][i * Traits::N_BLOCK]),
+                    param.n * sizeof(typename Traits::InputType),
+                    __nvvm_get_smem_pointer(input_vec_fifo_w.barrier_ptr(vec_buffer_ptr)));
+            }
         }

         // Use templated lambdas to defer resolving the symbols like "param.residual".
@@ -231,10 +236,13 @@ struct WarpSpecializedLayerNorm
         {
             for (int i = 0; i < Traits::M_BLOCK; i++)
             {
-                load_a_vec(&param.residual[(m_base + i) * param.n],
-                    __nvvm_get_smem_pointer(&shared->input_vec[vec_buffer_ptr][1][i * Traits::N_BLOCK]),
-                    param.n * sizeof(typename Traits::InputType),
-                    __nvvm_get_smem_pointer(input_vec_fifo_w.barrier_ptr(vec_buffer_ptr)));
+                if (i < eff_m_block) [[likely]]
+                {
+                    load_a_vec(&param.residual[(m_base + i) * param.n],
+                        __nvvm_get_smem_pointer(&shared->input_vec[vec_buffer_ptr][1][i * Traits::N_BLOCK]),
+                        param.n * sizeof(typename Traits::InputType),
+                        __nvvm_get_smem_pointer(input_vec_fifo_w.barrier_ptr(vec_buffer_ptr)));
+                }
             }
         }(param);
     }
@@ -446,6 +454,9 @@ struct WarpSpecializedLayerNorm
         {
             m_base = block_id;
         }
+        const uint32_t eff_m_block
+            = std::min(static_cast<uint32_t>(Traits::M_BLOCK), static_cast<uint32_t>(param.m - m_base));
+
         // if (blockIdx.x == 0 && thread_id == 0) printf("MATH got tile %d.\n", m_base);

         // Peek for data ready.
@@ -613,11 +624,11 @@ struct WarpSpecializedLayerNorm
         {
             mean[m_offset] /= param.n;
             variance[m_offset] = rsqrtf(variance[m_offset] / param.n - mean[m_offset] * mean[m_offset]
-                + (Traits::AccumulatorType)(1e-5));
+                + (Traits::AccumulatorType)(param.layernorm_eps));
         }
         else
         {
-            variance[m_offset] = rsqrtf(variance[m_offset] / param.n + (Traits::AccumulatorType)(1e-5));
+            variance[m_offset] = rsqrtf(variance[m_offset] / param.n + (Traits::AccumulatorType)(param.layernorm_eps));
         }
     }

@@ -660,7 +671,7 @@ struct WarpSpecializedLayerNorm
         }

#pragma unroll Traits::M_BLOCK
-        for (int m_offset = 0; m_offset < Traits::M_BLOCK; m_offset++)
+        for (int m_offset = 0; m_offset < eff_m_block; m_offset++)
         {
             auto m = m_base + m_offset;

@@ -801,8 +812,7 @@ struct WarpSpecializedLayerNorm
     shared->init(threadIdx.x == 0);

     __syncthreads();
-#if (defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 12))
-#if (defined(__CUDA_ARCH_FEAT_SM90_ALL) || defined(__CUDA_ARCH_FEAT_SM100_ALL))
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12)
     if constexpr (arch::is_major_v<9> || arch::is_major_v<10>)
     {
         auto block_id = blockIdx.x;
@@ -830,7 +840,6 @@ struct WarpSpecializedLayerNorm
             compute(block_id, threadIdx.x / 128 - 1, tid_in_wg, param, shared);
         }
     }
-#endif
#endif
 }
};
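The recurring eff_m_block = min(M_BLOCK, param.m - m_base) clamp bounds the last tile along M, so the TMA transaction size, the per-row vector loads, and the math loop all stop at param.m instead of touching a full M_BLOCK past the end. A tiny Python sketch of the same tiling arithmetic (illustrative M_BLOCK value; the real value comes from Traits):

def tile_rows(m: int, m_block: int):
    # Yields (m_base, eff_m_block) per tile; the tail tile is partial when m % m_block != 0.
    for m_base in range(0, m, m_block):
        yield m_base, min(m_block, m - m_base)

# e.g. m=21, m_block=8 -> (0, 8), (8, 8), (16, 5)
print(list(tile_rows(21, 8)))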

cpp/tensorrt_llm/thop/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -66,6 +66,7 @@ add_library(
   fp8Quantize.cpp
   dsv3FusedAGemmOp.cpp
   fusedQKNormRopeOp.cpp
+  fusedAddRMSNormQuant.cpp
   fusedTopkSoftmax.cpp
   gatherTreeOp.cpp
   groupRmsNormOp.cpp
cpp/tensorrt_llm/thop/fusedAddRMSNormQuant.cpp

Lines changed: 197 additions & 0 deletions
@@ -0,0 +1,197 @@
+/*
+ * Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "tensorrt_llm/common/cudaUtils.h"
+#include "tensorrt_llm/kernels/fusedLayernormKernels/layernorm_param.h"
+#include "tensorrt_llm/kernels/fusedLayernormKernels/ws_layernorm.h"
+#include "tensorrt_llm/kernels/quantization.h"
+#include "tensorrt_llm/thop/thUtils.h"
+
+#include <ATen/Functions.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/EmptyTensor.h>
+
+#include <cuda_bf16.h>
+#include <cuda_fp16.h>
+
+#include <cstdint>
+#include <optional>
+#include <tuple>
+#include <unordered_map>
+
+TRTLLM_NAMESPACE_BEGIN
+
+namespace torch_ext
+{
+
+// Fused Add + RMSNorm + FP4 Quantization kernel
+// input: [M, N] - input tensor (fp16/bf16)
+// residual: [M, N] - residual tensor (fp16/bf16)
+// gamma: [N] - RMSNorm weight (fp16/bf16)
+// sf_scale: [1] - optional scale factor for FP4 quantization (float)
+// use_rms_norm: bool - if true use RMSNorm, else use LayerNorm
+// Returns:
+//   normed_output: [M, N/8] - FP4 quantized normalized output (uint32_t, packed)
+//   output: [M, N] - pre-norm output (input + residual), same dtype as input
+//   sf_out: scale factors for FP4 (uint8_t), swizzled layout
+//
+// NOTE: This kernel requires SM90 (Hopper) or SM100 (Blackwell) GPU architecture.
+// NOTE: Hidden dimension N must be >= 2048 and <= 16384.
+std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_add_rms_norm_quant(at::Tensor const& input,
+    at::Tensor const& residual, at::Tensor const& gamma, std::optional<at::Tensor> const& sf_scale, bool use_rms_norm,
+    double eps)
+{
+    CHECK_TH_CUDA(input);
+    CHECK_CONTIGUOUS(input);
+    CHECK_TH_CUDA(residual);
+    CHECK_CONTIGUOUS(residual);
+    CHECK_TH_CUDA(gamma);
+    CHECK_CONTIGUOUS(gamma);
+
+    // Check GPU architecture - kernel requires SM90+ (Hopper/Blackwell)
+    auto const device = input.get_device();
+    cudaDeviceProp props;
+    AT_CUDA_CHECK(cudaGetDeviceProperties(&props, device));
+    TORCH_CHECK(props.major >= 9,
+        "fused_add_rms_norm_quant requires SM90 (Hopper) or newer GPU architecture. "
+        "Current device: sm_",
+        props.major, props.minor);
+
+    auto const& inputShape = input.sizes();
+    auto const& rank = inputShape.size();
+
+    TORCH_CHECK(rank == 2, "input should be 2D tensor [M, N].");
+    TORCH_CHECK(residual.sizes() == inputShape, "residual shape must match input shape.");
+
+    int64_t const m = inputShape[0];
+    int64_t const n = inputShape[1];
+    // Some warp-specialized kernels may issue vectorized stores that assume M is padded.
+    // Allocate a bit of extra space to avoid out-of-bounds writes when M is not a multiple of 8.
+    int64_t const m_padded = ((m + 15) / 16) * 16;
+
+    TORCH_CHECK(gamma.sizes()[0] == n, "gamma size must match hidden dimension N.");
+    TORCH_CHECK(n >= 2048, "Hidden dimension N must be >= 2048 (kernel constraint).");
+    TORCH_CHECK(n <= 16384, "Hidden dimension N must be <= 16384.");
+    TORCH_CHECK(n % 16 == 0, "Hidden dimension N must be divisible by 16 for FP4 quantization.");
+
+    // Validate sf_scale if provided
+    float* sfScalePtr = nullptr;
+    if (sf_scale.has_value())
+    {
+        CHECK_INPUT(sf_scale.value(), torch::kFloat32);
+        sfScalePtr = sf_scale.value().data_ptr<float>();
+    }
+
+    // Allocate output tensors
+    // normed_output: FP4 packed output [M, N/8] as uint32_t (8 FP4 values packed per uint32)
+    // NOTE: allocate [M_padded, ...] to avoid OOB writes; return a view of [M, ...] to keep API stable.
+    at::Tensor normed_output_padded
+        = at::detail::empty_cuda({m_padded, n / 8}, torch::kInt32, input.device(), std::nullopt);
+    at::Tensor normed_output = (m_padded == m) ? normed_output_padded : normed_output_padded.narrow(0, 0, m);
+
+    // output: pre-norm output (input + residual) [M, N], same dtype as input
+    // NOTE: allocate [M_padded, ...] to avoid OOB writes; return a view of [M, ...] to keep API stable.
+    at::Tensor output_padded = at::detail::empty_cuda({m_padded, n}, input.scalar_type(), input.device(), std::nullopt);
+    at::Tensor output = (m_padded == m) ? output_padded : output_padded.narrow(0, 0, m);
+
+    // sf_out: scale factors for FP4, swizzled layout
+    // sfVecSize = 16 for FP4 quantization (16 FP4 values share one scale factor)
+    int64_t const sfVecSize = 16;
+    int64_t const sfSize = tensorrt_llm::computeSwizzledLayoutSFSize(m, n / sfVecSize);
+    at::Tensor sf_out = at::detail::empty_cuda({sfSize}, SF_DTYPE, input.device(), std::nullopt);
+
+    // Get number of SMs for persistent kernel
+    static int const multiProcessorCount = tensorrt_llm::common::getMultiProcessorCount();
+
+    // Allocate counters for warp-specialized kernel using PyTorch allocator.
+    //
+    // NOTE: We cache this tensor to avoid per-call allocations. We use `thread_local` so
+    // concurrent calls from different threads don't share the same counters buffer (which
+    // could cause races across different CUDA streams).
+    static thread_local std::unordered_map<int, at::Tensor> counters_tensor_cache;
+    auto& counters_tensor = counters_tensor_cache[device];
+    int64_t const counters_bytes = static_cast<int64_t>(sizeof(tensorrt_llm::kernels::WarpSpecializedCounters));
+    if (!counters_tensor.defined() || counters_tensor.numel() != counters_bytes)
+    {
+        counters_tensor = at::detail::empty_cuda({counters_bytes}, torch::kByte, input.device(), std::nullopt);
+        counters_tensor.zero_();
+    }
+    auto* counters
+        = reinterpret_cast<tensorrt_llm::kernels::WarpSpecializedCounters*>(counters_tensor.mutable_data_ptr());
+
+    auto stream = at::cuda::getCurrentCUDAStream(device);
+
+#define LAUNCH_FUSED_ADD_RMS_NORM_QUANT(T)                                                           \
+    do                                                                                               \
+    {                                                                                                \
+        using Param = tensorrt_llm::kernels::GeneralFP4AddBiasResidualPreLayerNormParam<T>;          \
+        tensorrt_llm::kernels::WarpSpecializedParam<Param> param;                                    \
+        param.normed_output = reinterpret_cast<uint32_t*>(normed_output.data_ptr());                 \
+        param.output = reinterpret_cast<T*>(output.data_ptr());                                      \
+        param.input = const_cast<T*>(reinterpret_cast<T const*>(input.data_ptr()));                  \
+        param.sf_scale = sfScalePtr;                                                                 \
+        param.sf_out = reinterpret_cast<uint32_t*>(sf_out.data_ptr());                               \
+        param.residual = reinterpret_cast<T const*>(residual.data_ptr());                            \
+        param.bias = nullptr;                                                                        \
+        param.gamma = reinterpret_cast<T const*>(gamma.data_ptr());                                  \
+        param.beta = nullptr;                                                                        \
+        param.m = static_cast<int>(m);                                                               \
+        param.n = static_cast<int>(n);                                                               \
+        param.layernorm_eps = static_cast<float>(eps);                                               \
+        param.stream = stream;                                                                       \
+        param.counters = counters;                                                                   \
+        tensorrt_llm::kernels::invokeWSLayerNorm<Param>(param, use_rms_norm, multiProcessorCount);   \
+    } while (0)
+
+    if (input.scalar_type() == at::ScalarType::Half)
+    {
+        LAUNCH_FUSED_ADD_RMS_NORM_QUANT(half);
+    }
+    else if (input.scalar_type() == at::ScalarType::BFloat16)
+    {
+#ifdef ENABLE_BF16
+        LAUNCH_FUSED_ADD_RMS_NORM_QUANT(__nv_bfloat16);
+#else
+        C10_THROW_ERROR(NotImplementedError, "BFloat16 must be enabled for fused_add_rms_norm_quant with bf16 input.");
+#endif
+    }
+    else
+    {
+        C10_THROW_ERROR(
+            NotImplementedError, "fused_add_rms_norm_quant only supports input tensor with dtypes fp16/bf16.");
+    }
+
+#undef LAUNCH_FUSED_ADD_RMS_NORM_QUANT
+
+    // No explicit sync needed - kernel runs asynchronously on the stream
+    return std::make_tuple(normed_output, output, sf_out);
+}
+
+} // namespace torch_ext
+
+TRTLLM_NAMESPACE_END
+
+TORCH_LIBRARY_FRAGMENT(trtllm, m)
+{
+    m.def(
+        "fused_add_rms_norm_quant(Tensor input, Tensor residual, Tensor gamma, "
+        "Tensor? sf_scale, bool use_rms_norm=True, float eps=1e-5) -> (Tensor, Tensor, Tensor)");
+}
+
+TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
+{
+    m.impl("fused_add_rms_norm_quant", &tensorrt_llm::torch_ext::fused_add_rms_norm_quant);
+}
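In terms of the math, the op adds the residual, normalizes the sum with gamma and eps, and FP4-quantizes the normalized values (8 packed per int32 word, one uint8 scale factor per 16 values in a swizzled layout). A hedged Python reference of the add + normalize part only, e.g. for debugging; the exact FP4 rounding, scale-factor handling, and swizzle are kernel-specific and not reproduced here:

import torch

def reference_add_norm(input, residual, gamma, eps=1e-5, use_rms_norm=True):
    # Pre-norm sum is the second tensor returned by fused_add_rms_norm_quant.
    output = input + residual
    x = output.float()
    mean_sq = x.pow(2).mean(-1, keepdim=True)
    if use_rms_norm:
        normed = x * torch.rsqrt(mean_sq + eps)
    else:
        mean = x.mean(-1, keepdim=True)
        normed = (x - mean) * torch.rsqrt(mean_sq - mean * mean + eps)
    # Assumed gamma scaling before quantization; verify against the kernel when comparing.
    return output, (normed * gamma.float()).to(input.dtype)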

tensorrt_llm/_torch/custom_ops/torch_custom_ops.py

Lines changed: 53 additions & 0 deletions
@@ -1869,3 +1869,56 @@ def record_stream(tensor: torch.Tensor, stream_id: int) -> None:
     stream = get_stream(stream_id)
     assert stream is not None
     tensor.record_stream(stream)
+
+
+def fused_add_rms_norm_quant(
+    input: torch.Tensor,
+    residual: torch.Tensor,
+    gamma: torch.Tensor,
+    sf_scale: Optional[torch.Tensor],
+    use_rms_norm: bool = True,
+    eps: float = 1e-5,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Fused Add + RMSNorm/LayerNorm + FP4 Quantization kernel.
+
+    Args:
+        input: [M, N] input tensor (fp16/bf16)
+        residual: [M, N] residual tensor (fp16/bf16)
+        gamma: [N] normalization weight (fp16/bf16)
+        sf_scale: [1] optional scale factor for FP4 quantization (float32)
+        use_rms_norm: if True use RMSNorm, else use LayerNorm
+        eps: epsilon for normalization
+
+    Returns:
+        normed_output_fp4: [M, N/8] FP4 quantized normalized output (int32, packed)
+        output: [M, N] pre-norm output (input + residual), same dtype as input
+        sf_out: scale factors for FP4 quantization (uint8), swizzled layout
+
+    Note:
+        This kernel requires SM90 (Hopper) or SM100 (Blackwell) GPU.
+        Hidden dimension N must be >= 2048 and <= 16384.
+    """
+    return torch.ops.trtllm.fused_add_rms_norm_quant(input, residual, gamma,
+                                                     sf_scale, use_rms_norm,
+                                                     eps)
+
+
+@torch.library.register_fake("trtllm::fused_add_rms_norm_quant")
+def _fused_add_rms_norm_quant_fake(
+    input: torch.Tensor,
+    residual: torch.Tensor,
+    gamma: torch.Tensor,
+    sf_scale: Optional[torch.Tensor],
+    use_rms_norm: bool = True,
+    eps: float = 1e-5,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    m, n = input.shape
+    # normed_output_fp4: [M, N/8] as int32 (8 FP4 values packed per int32)
+    normed_output_fp4 = input.new_empty((m, n // 8), dtype=torch.int32)
+    # output: [M, N] pre-norm output, same dtype as input
+    output = input.new_empty((m, n), dtype=input.dtype)
+    # sf_out: scale factors, swizzled layout
+    sf_vec_size = 16
+    sf_size = ((m + 127) // 128) * 128 * ((n // sf_vec_size + 3) // 4) * 4
+    sf_out = input.new_empty((sf_size, ), dtype=torch.uint8)
+    return normed_output_fp4, output, sf_out
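A minimal usage sketch of the new wrapper. It assumes a TensorRT-LLM build that registers trtllm::fused_add_rms_norm_quant, an SM90+ GPU, and the module path of the file above; shapes respect the N >= 2048 constraint:

import torch
from tensorrt_llm._torch.custom_ops.torch_custom_ops import fused_add_rms_norm_quant

m, n = 8, 4096
x = torch.randn(m, n, dtype=torch.bfloat16, device="cuda")
res = torch.randn_like(x)
gamma = torch.ones(n, dtype=torch.bfloat16, device="cuda")
sf_scale = torch.ones(1, dtype=torch.float32, device="cuda")

normed_fp4, prenorm, sf_out = fused_add_rms_norm_quant(x, res, gamma, sf_scale,
                                                       use_rms_norm=True, eps=1e-6)

assert normed_fp4.shape == (m, n // 8) and normed_fp4.dtype == torch.int32
assert prenorm.shape == (m, n) and prenorm.dtype == x.dtype
# sf_out follows the swizzled layout: rows padded to 128, n/16 scales padded to a multiple of 4.
print(sf_out.numel())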
