
Commit e421f96 (1 parent: 9d1f2a9)

update torch_ext API and debugging test for FusedAddRMSNorm

update #define for hopper & blackwell

Signed-off-by: JtaoPeng <[email protected]>

File tree

4 files changed: +228 −3 lines

cpp/tensorrt_llm/kernels/fusedLayernormKernels/ws_layernorm.cuh

Lines changed: 1 addition & 3 deletions
@@ -798,8 +798,7 @@ struct WarpSpecializedLayerNorm
         shared->init(threadIdx.x == 0);

         __syncthreads();
-#if (defined(__CUDA_ARCH__) && (__CUDACC_VER_MAJOR__ >= 12))
-#if (defined(__CUDA_ARCH_FEAT_SM90_ALL) || defined(__CUDA_ARCH_FEAT_SM100_ALL))
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12)
         if constexpr (arch::is_major_v<9> || arch::is_major_v<10>)
         {
             auto block_id = blockIdx.x;
@@ -827,7 +826,6 @@ struct WarpSpecializedLayerNorm
             compute(block_id, threadIdx.x / 128 - 1, tid_in_wg, param, shared);
         }
     }
-#endif
 #endif
     }
};
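The preprocessor guard now keys off __CUDA_ARCH__ >= 900 (with CUDA 12) instead of the sm_90a/sm_100a feature macros, consistent with the arch::is_major_v<9> || arch::is_major_v<10> check that follows. On the Python side, a caller can apply the same SM90+ gate before using the new op; this is a sketch, not part of this commit:

import torch

def supports_fused_add_rms_norm_quant(device: int = 0) -> bool:
    # The kernel path is compiled only for SM90 (Hopper) and newer, per the guard above.
    major, _minor = torch.cuda.get_device_capability(device)
    return major >= 9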

cpp/tensorrt_llm/thop/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -66,6 +66,7 @@ add_library(
   fp8Quantize.cpp
   dsv3FusedAGemmOp.cpp
   fusedQKNormRopeOp.cpp
+  fusedAddRMSNormQuant.cpp
   fusedTopkSoftmax.cpp
   gatherTreeOp.cpp
   groupRmsNormOp.cpp
cpp/tensorrt_llm/thop/fusedAddRMSNormQuant.cpp

Lines changed: 177 additions & 0 deletions
@@ -0,0 +1,177 @@
/*
 * Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/fusedLayernormKernels/layernorm_param.h"
#include "tensorrt_llm/kernels/fusedLayernormKernels/ws_layernorm.h"
#include "tensorrt_llm/kernels/quantization.h"
#include "tensorrt_llm/thop/thUtils.h"

#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/EmptyTensor.h>

#include <cuda_bf16.h>
#include <cuda_fp16.h>

#include <cstdint>
#include <optional>
#include <tuple>

namespace torch_ext
{

// Fused Add + RMSNorm + FP4 Quantization kernel
// input: [M, N] - input tensor (fp16/bf16)
// residual: [M, N] - residual tensor (fp16/bf16)
// gamma: [N] - RMSNorm weight (fp16/bf16)
// sf_scale: [1] - optional scale factor for FP4 quantization (float)
// use_rms_norm: bool - if true use RMSNorm, else use LayerNorm
// Returns:
//   normed_output: [M, N/8] - FP4 quantized normalized output (uint32_t, packed)
//   output: [M, N] - pre-norm output (input + residual), same dtype as input
//   sf_out: scale factors for FP4 (uint8_t), swizzled layout
//
// NOTE: This kernel requires SM90 (Hopper) or SM100 (Blackwell) GPU architecture.
// NOTE: Hidden dimension N must be >= 2048 and <= 16384.
std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_add_rms_norm_quant(at::Tensor const& input,
    at::Tensor const& residual, at::Tensor const& gamma, std::optional<at::Tensor> const& sf_scale, bool use_rms_norm)
{
    CHECK_TH_CUDA(input);
    CHECK_CONTIGUOUS(input);
    CHECK_TH_CUDA(residual);
    CHECK_CONTIGUOUS(residual);
    CHECK_TH_CUDA(gamma);
    CHECK_CONTIGUOUS(gamma);

    // Check GPU architecture - kernel requires SM90+ (Hopper/Blackwell)
    auto const device = input.get_device();
    cudaDeviceProp props;
    AT_CUDA_CHECK(cudaGetDeviceProperties(&props, device));
    TORCH_CHECK(props.major >= 9,
        "fused_add_rms_norm_quant requires SM90 (Hopper) or newer GPU architecture. "
        "Current device: sm_",
        props.major, props.minor);

    auto const& inputShape = input.sizes();
    auto const& rank = inputShape.size();

    TORCH_CHECK(rank == 2, "input should be 2D tensor [M, N].");
    TORCH_CHECK(residual.sizes() == inputShape, "residual shape must match input shape.");

    int64_t const m = inputShape[0];
    int64_t const n = inputShape[1];

    TORCH_CHECK(gamma.sizes()[0] == n, "gamma size must match hidden dimension N.");
    TORCH_CHECK(n >= 2048, "Hidden dimension N must be >= 2048 (kernel constraint).");
    TORCH_CHECK(n <= 16384, "Hidden dimension N must be <= 16384.");
    TORCH_CHECK(n % 16 == 0, "Hidden dimension N must be divisible by 16 for FP4 quantization.");

    // Validate sf_scale if provided
    float* sfScalePtr = nullptr;
    if (sf_scale.has_value())
    {
        CHECK_INPUT(sf_scale.value(), torch::kFloat32);
        sfScalePtr = sf_scale.value().data_ptr<float>();
    }

    // Allocate output tensors
    // normed_output: FP4 packed output [M, N/8] as uint32_t (8 FP4 values packed per uint32)
    auto normed_output = at::detail::empty_cuda({m, n / 8}, torch::kInt32, input.device(), std::nullopt);

    // output: pre-norm output (input + residual) [M, N], same dtype as input
    auto output = at::detail::empty_cuda({m, n}, input.scalar_type(), input.device(), std::nullopt);

    // sf_out: scale factors for FP4, swizzled layout
    // sfVecSize = 16 for FP4 quantization (16 FP4 values share one scale factor)
    int64_t const sfVecSize = 16;
    int64_t const sfSize = tensorrt_llm::computeSwizzledLayoutSFSize(m, n / sfVecSize);
    auto sf_out = at::detail::empty_cuda({sfSize}, SF_DTYPE, input.device(), std::nullopt);

    // Get number of SMs for persistent kernel
    static int const multiProcessorCount = tensorrt_llm::common::getMultiProcessorCount();

    // Allocate counters for warp-specialized kernel using PyTorch allocator
    // This avoids cudaMalloc/cudaFree overhead and ensures proper cleanup
    auto counters_tensor
        = at::detail::empty_cuda({static_cast<int64_t>(sizeof(tensorrt_llm::kernels::WarpSpecializedCounters))},
            torch::kByte, input.device(), std::nullopt);
    counters_tensor.zero_();
    auto* counters
        = reinterpret_cast<tensorrt_llm::kernels::WarpSpecializedCounters*>(counters_tensor.mutable_data_ptr());

    auto stream = at::cuda::getCurrentCUDAStream(device);

#define LAUNCH_FUSED_ADD_RMS_NORM_QUANT(T)                                                                             \
    do                                                                                                                 \
    {                                                                                                                  \
        using Param = tensorrt_llm::kernels::GeneralFP4AddBiasResidualPreLayerNormParam<T>;                           \
        tensorrt_llm::kernels::WarpSpecializedParam<Param> param;                                                      \
        param.normed_output = reinterpret_cast<uint32_t*>(normed_output.data_ptr());                                   \
        param.output = reinterpret_cast<T*>(output.data_ptr());                                                        \
        param.input = const_cast<T*>(reinterpret_cast<T const*>(input.data_ptr()));                                    \
        param.sf_scale = sfScalePtr;                                                                                   \
        param.sf_out = reinterpret_cast<uint32_t*>(sf_out.data_ptr());                                                 \
        param.residual = reinterpret_cast<T const*>(residual.data_ptr());                                              \
        param.bias = nullptr;                                                                                          \
        param.gamma = reinterpret_cast<T const*>(gamma.data_ptr());                                                    \
        param.beta = nullptr;                                                                                          \
        param.m = static_cast<int>(m);                                                                                 \
        param.n = static_cast<int>(n);                                                                                 \
        param.layernorm_eps = 1e-5f;                                                                                   \
        param.stream = stream;                                                                                         \
        param.counters = counters;                                                                                     \
        tensorrt_llm::kernels::invokeWSLayerNorm<Param>(param, use_rms_norm, multiProcessorCount);                     \
    } while (0)

    if (input.scalar_type() == at::ScalarType::Half)
    {
        LAUNCH_FUSED_ADD_RMS_NORM_QUANT(half);
    }
    else if (input.scalar_type() == at::ScalarType::BFloat16)
    {
#ifdef ENABLE_BF16
        LAUNCH_FUSED_ADD_RMS_NORM_QUANT(__nv_bfloat16);
#else
        C10_THROW_ERROR(NotImplementedError, "BFloat16 must be enabled for fused_add_rms_norm_quant with bf16 input.");
#endif
    }
    else
    {
        C10_THROW_ERROR(
            NotImplementedError, "fused_add_rms_norm_quant only supports input tensor with dtypes fp16/bf16.");
    }

#undef LAUNCH_FUSED_ADD_RMS_NORM_QUANT

    // No explicit sync needed - kernel runs asynchronously on the stream
    // counters_tensor will be freed when it goes out of scope (after stream sync at Python level)

    return std::make_tuple(normed_output, output, sf_out);
}

} // namespace torch_ext

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def(
        "fused_add_rms_norm_quant(Tensor input, Tensor residual, Tensor gamma, "
        "Tensor? sf_scale, bool use_rms_norm=True) -> (Tensor, Tensor, Tensor)");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("fused_add_rms_norm_quant", &torch_ext::fused_add_rms_norm_quant);
}
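The op's documented outputs make an eager-mode cross-check straightforward: the second return value is just input + residual, and the first is that sum RMS-normalized (eps = 1e-5 per the launch parameters) before FP4 quantization. Below is a minimal unfused reference in plain PyTorch; it is a sanity-check sketch, not the debugging test added in this commit, and the FP4 packing and swizzled scale factors have no simple eager equivalent here:

import torch

def ref_add_rms_norm(x: torch.Tensor, residual: torch.Tensor,
                     gamma: torch.Tensor, eps: float = 1e-5):
    # Pre-norm output: elementwise residual add, kept in the input dtype.
    prenorm = x + residual
    # RMSNorm over the hidden dimension; fp32 accumulation is an assumption of
    # this sketch, the kernel's internal precision may differ.
    acc = prenorm.float()
    inv_rms = torch.rsqrt(acc.pow(2).mean(dim=-1, keepdim=True) + eps)
    normed = acc * inv_rms * gamma.float()
    return prenorm, normed.to(x.dtype)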

tensorrt_llm/_torch/custom_ops/torch_custom_ops.py

Lines changed: 49 additions & 0 deletions
@@ -1621,3 +1621,52 @@ def record_stream(tensor: torch.Tensor, stream_id: int) -> None:
     stream = get_stream(stream_id)
     assert stream is not None
     tensor.record_stream(stream)
+
+
+def fused_add_rms_norm_quant(
+    input: torch.Tensor,
+    residual: torch.Tensor,
+    gamma: torch.Tensor,
+    sf_scale: Optional[torch.Tensor],
+    use_rms_norm: bool = True,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Fused Add + RMSNorm/LayerNorm + FP4 Quantization kernel.
+
+    Args:
+        input: [M, N] input tensor (fp16/bf16)
+        residual: [M, N] residual tensor (fp16/bf16)
+        gamma: [N] normalization weight (fp16/bf16)
+        sf_scale: [1] optional scale factor for FP4 quantization (float32)
+        use_rms_norm: if True use RMSNorm, else use LayerNorm
+
+    Returns:
+        normed_output_fp4: [M, N/8] FP4 quantized normalized output (int32, packed)
+        output: [M, N] pre-norm output (input + residual), same dtype as input
+        sf_out: scale factors for FP4 quantization (uint8), swizzled layout
+
+    Note:
+        This kernel requires SM90 (Hopper) or SM100 (Blackwell) GPU.
+        Hidden dimension N must be >= 2048 and <= 16384.
+    """
+    return torch.ops.trtllm.fused_add_rms_norm_quant(input, residual, gamma,
+                                                     sf_scale, use_rms_norm)
+
+
+@torch.library.register_fake("trtllm::fused_add_rms_norm_quant")
+def _fused_add_rms_norm_quant_fake(
+    input: torch.Tensor,
+    residual: torch.Tensor,
+    gamma: torch.Tensor,
+    sf_scale: Optional[torch.Tensor],
+    use_rms_norm: bool = True,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    m, n = input.shape
+    # normed_output_fp4: [M, N/8] as int32 (8 FP4 values packed per int32)
+    normed_output_fp4 = input.new_empty((m, n // 8), dtype=torch.int32)
+    # output: [M, N] pre-norm output, same dtype as input
+    output = input.new_empty((m, n), dtype=input.dtype)
+    # sf_out: scale factors, swizzled layout
+    sf_vec_size = 16
+    sf_size = ((m + 127) // 128) * 128 * ((n // sf_vec_size + 3) // 4) * 4
+    sf_out = input.new_empty((sf_size, ), dtype=torch.uint8)
+    return normed_output_fp4, output, sf_out
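
A minimal call sketch against the registered schema (shapes follow the documented constraints; the sf_scale value is an arbitrary placeholder, and the sf_out size check reuses the swizzled-layout formula from the fake implementation above):

import torch

m, n = 128, 4096  # N must be in [2048, 16384] and divisible by 16
x = torch.randn(m, n, dtype=torch.float16, device="cuda")
residual = torch.randn_like(x)
gamma = torch.ones(n, dtype=torch.float16, device="cuda")
sf_scale = torch.ones(1, dtype=torch.float32, device="cuda")  # placeholder global scale

normed_fp4, prenorm, sf_out = torch.ops.trtllm.fused_add_rms_norm_quant(
    x, residual, gamma, sf_scale, True)

assert normed_fp4.shape == (m, n // 8)   # 8 FP4 values packed per int32
assert prenorm.shape == (m, n)           # input + residual, fp16
assert sf_out.numel() == ((m + 127) // 128) * 128 * ((n // 16 + 3) // 4) * 4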
