/*
 * Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/fusedLayernormKernels/layernorm_param.h"
#include "tensorrt_llm/kernels/fusedLayernormKernels/ws_layernorm.h"
#include "tensorrt_llm/kernels/quantization.h"
#include "tensorrt_llm/thop/thUtils.h"

#include <ATen/Functions.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/EmptyTensor.h>

#include <cuda_bf16.h>
#include <cuda_fp16.h>

#include <cstdint>
#include <optional>
#include <tuple>
#include <unordered_map>

TRTLLM_NAMESPACE_BEGIN

namespace torch_ext
{

// Fused Add + RMSNorm + FP4 Quantization kernel
// input: [M, N] - input tensor (fp16/bf16)
// residual: [M, N] - residual tensor (fp16/bf16)
// gamma: [N] - RMSNorm/LayerNorm weight (fp16/bf16)
// sf_scale: [1] - optional scale factor for FP4 quantization (float)
// use_rms_norm: bool - if true use RMSNorm, else use LayerNorm
// eps: small constant added for numerical stability
// Returns:
//   normed_output: [M, N/8] - FP4 quantized normalized output (uint32_t, packed)
//   output: [M, N] - pre-norm output (input + residual), same dtype as input
//   sf_out: scale factors for FP4 (uint8_t), swizzled layout
//
// NOTE: This kernel requires an SM90 (Hopper) or newer GPU architecture (e.g., SM100 Blackwell).
// NOTE: The hidden dimension N must satisfy 2048 <= N <= 16384 and be divisible by 16.
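//
// A minimal usage sketch from the Python side, assuming the trtllm extension
// library that registers this op has been loaded (tensor names and shapes below
// are purely illustrative):
//   x = torch.randn(32, 4096, dtype=torch.float16, device="cuda")
//   res = torch.randn_like(x)
//   gamma = torch.ones(4096, dtype=torch.float16, device="cuda")
//   normed_fp4, prenorm, sf = torch.ops.trtllm.fused_add_rms_norm_quant(
//       x, res, gamma, None, True, 1e-5)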
std::tuple<at::Tensor, at::Tensor, at::Tensor> fused_add_rms_norm_quant(at::Tensor const& input,
    at::Tensor const& residual, at::Tensor const& gamma, std::optional<at::Tensor> const& sf_scale, bool use_rms_norm,
    double eps)
{
    CHECK_TH_CUDA(input);
    CHECK_CONTIGUOUS(input);
    CHECK_TH_CUDA(residual);
    CHECK_CONTIGUOUS(residual);
    CHECK_TH_CUDA(gamma);
    CHECK_CONTIGUOUS(gamma);

    // Check GPU architecture - kernel requires SM90+ (Hopper/Blackwell)
    auto const device = input.get_device();
    cudaDeviceProp props;
    AT_CUDA_CHECK(cudaGetDeviceProperties(&props, device));
    TORCH_CHECK(props.major >= 9,
        "fused_add_rms_norm_quant requires SM90 (Hopper) or newer GPU architecture. "
        "Current device: sm_",
        props.major, props.minor);

    auto const& inputShape = input.sizes();
    auto const& rank = inputShape.size();

    TORCH_CHECK(rank == 2, "input should be 2D tensor [M, N].");
    TORCH_CHECK(residual.sizes() == inputShape, "residual shape must match input shape.");

    int64_t const m = inputShape[0];
    int64_t const n = inputShape[1];
    // Some warp-specialized kernels may issue vectorized stores that assume the row count is padded.
    // Pad M up to a multiple of 16 and allocate the extra rows to avoid out-of-bounds writes.
    int64_t const m_padded = ((m + 15) / 16) * 16;
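    // A couple of concrete values for the padding arithmetic: m = 1 -> m_padded = 16,
    // m = 100 -> m_padded = 112, and m = 128 stays at 128.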

    TORCH_CHECK(gamma.sizes()[0] == n, "gamma size must match hidden dimension N.");
    TORCH_CHECK(n >= 2048, "Hidden dimension N must be >= 2048 (kernel constraint).");
    TORCH_CHECK(n <= 16384, "Hidden dimension N must be <= 16384.");
    TORCH_CHECK(n % 16 == 0, "Hidden dimension N must be divisible by 16 for FP4 quantization.");

    // Validate sf_scale if provided
    float* sfScalePtr = nullptr;
    if (sf_scale.has_value())
    {
        CHECK_INPUT(sf_scale.value(), torch::kFloat32);
        sfScalePtr = sf_scale.value().data_ptr<float>();
    }

    // Allocate output tensors
    // normed_output: FP4 packed output [M, N/8] as uint32_t (8 FP4 values packed per uint32)
    // NOTE: allocate [M_padded, ...] to avoid OOB writes; return a view of [M, ...] to keep API stable.
    at::Tensor normed_output_padded
        = at::detail::empty_cuda({m_padded, n / 8}, torch::kInt32, input.device(), std::nullopt);
    at::Tensor normed_output = (m_padded == m) ? normed_output_padded : normed_output_padded.narrow(0, 0, m);
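    // For example, with n = 4096 each row of normed_output holds 4096 / 8 = 512 uint32
    // words, i.e. 2048 bytes of packed FP4 data (two 4-bit values per byte).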

    // output: pre-norm output (input + residual) [M, N], same dtype as input
    // NOTE: allocate [M_padded, ...] to avoid OOB writes; return a view of [M, ...] to keep API stable.
    at::Tensor output_padded = at::detail::empty_cuda({m_padded, n}, input.scalar_type(), input.device(), std::nullopt);
    at::Tensor output = (m_padded == m) ? output_padded : output_padded.narrow(0, 0, m);

    // sf_out: scale factors for FP4, swizzled layout
    // sfVecSize = 16 for FP4 quantization (16 FP4 values share one scale factor)
    int64_t const sfVecSize = 16;
    int64_t const sfSize = tensorrt_llm::computeSwizzledLayoutSFSize(m, n / sfVecSize);
    at::Tensor sf_out = at::detail::empty_cuda({sfSize}, SF_DTYPE, input.device(), std::nullopt);
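    // For example, with n = 4096 there are 4096 / 16 = 256 scale factors per row. The
    // swizzled layout pads this grid (typically rows to a multiple of 128 and SF columns
    // to a multiple of 4), so sfSize may exceed m * (n / sfVecSize);
    // computeSwizzledLayoutSFSize accounts for that padding.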

    // Get number of SMs for persistent kernel
    static int const multiProcessorCount = tensorrt_llm::common::getMultiProcessorCount();

    // Allocate counters for warp-specialized kernel using PyTorch allocator.
    //
    // NOTE: We cache this tensor to avoid per-call allocations. We use `thread_local` so
    // concurrent calls from different threads don't share the same counters buffer (which
    // could cause races across different CUDA streams).
    static thread_local std::unordered_map<int, at::Tensor> counters_tensor_cache;
    auto& counters_tensor = counters_tensor_cache[device];
    int64_t const counters_bytes = static_cast<int64_t>(sizeof(tensorrt_llm::kernels::WarpSpecializedCounters));
    if (!counters_tensor.defined() || counters_tensor.numel() != counters_bytes)
    {
        counters_tensor = at::detail::empty_cuda({counters_bytes}, torch::kByte, input.device(), std::nullopt);
        counters_tensor.zero_();
    }
    auto* counters
        = reinterpret_cast<tensorrt_llm::kernels::WarpSpecializedCounters*>(counters_tensor.mutable_data_ptr());

    auto stream = at::cuda::getCurrentCUDAStream(device);

#define LAUNCH_FUSED_ADD_RMS_NORM_QUANT(T) \
    do \
    { \
        using Param = tensorrt_llm::kernels::GeneralFP4AddBiasResidualPreLayerNormParam<T>; \
        tensorrt_llm::kernels::WarpSpecializedParam<Param> param; \
        param.normed_output = reinterpret_cast<uint32_t*>(normed_output.data_ptr()); \
        param.output = reinterpret_cast<T*>(output.data_ptr()); \
        param.input = const_cast<T*>(reinterpret_cast<T const*>(input.data_ptr())); \
        param.sf_scale = sfScalePtr; \
        param.sf_out = reinterpret_cast<uint32_t*>(sf_out.data_ptr()); \
        param.residual = reinterpret_cast<T const*>(residual.data_ptr()); \
        param.bias = nullptr; \
        param.gamma = reinterpret_cast<T const*>(gamma.data_ptr()); \
        param.beta = nullptr; \
        param.m = static_cast<int>(m); \
        param.n = static_cast<int>(n); \
        param.layernorm_eps = static_cast<float>(eps); \
        param.stream = stream; \
        param.counters = counters; \
        tensorrt_llm::kernels::invokeWSLayerNorm<Param>(param, use_rms_norm, multiProcessorCount); \
    } while (0)

    if (input.scalar_type() == at::ScalarType::Half)
    {
        LAUNCH_FUSED_ADD_RMS_NORM_QUANT(half);
    }
    else if (input.scalar_type() == at::ScalarType::BFloat16)
    {
#ifdef ENABLE_BF16
        LAUNCH_FUSED_ADD_RMS_NORM_QUANT(__nv_bfloat16);
#else
        C10_THROW_ERROR(NotImplementedError, "BFloat16 must be enabled for fused_add_rms_norm_quant with bf16 input.");
#endif
    }
    else
    {
        C10_THROW_ERROR(
            NotImplementedError, "fused_add_rms_norm_quant only supports input tensor with dtypes fp16/bf16.");
    }

#undef LAUNCH_FUSED_ADD_RMS_NORM_QUANT

    // No explicit sync needed - kernel runs asynchronously on the stream
    return std::make_tuple(normed_output, output, sf_out);
}

} // namespace torch_ext

TRTLLM_NAMESPACE_END

TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
    m.def(
        "fused_add_rms_norm_quant(Tensor input, Tensor residual, Tensor gamma, "
        "Tensor? sf_scale, bool use_rms_norm=True, float eps=1e-5) -> (Tensor, Tensor, Tensor)");
}

TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
    m.impl("fused_add_rms_norm_quant", &tensorrt_llm::torch_ext::fused_add_rms_norm_quant);
}