|
/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tensorrt_llm/common/envUtils.h"
#include "tensorrt_llm/kernels/helixKernels.h"

#include <cstdint>
#include <cstdio>

#include <cooperative_groups.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

using namespace tensorrt_llm::common;

namespace cg = cooperative_groups;

namespace tensorrt_llm
{
namespace kernels
{
static constexpr int WARP_SIZE = 32;

// Utility: warp-level corrected sum
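// Given per-source softmax statistics (running max in maxVal, exp-sum in sumVal), this computes
// the weight each partial attention output must be scaled by so that the weighted sum matches
// the full softmax:
//   correctedVal[i] = sumVal[i] * exp(maxVal[i] - warp_max) / sum_j(sumVal[j] * exp(maxVal[j] - warp_max))
// Both the max and the sum are reduced across the whole warp, so the N values held per lane
// together cover up to N * WARP_SIZE sources.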
template <int N>
__device__ inline void warpReduceCorrectedSum(float (&correctedVal)[N], float (&maxVal)[N], float (&sumVal)[N])
{
    float warp_max = maxVal[0];
#pragma unroll
    for (int nn = 1; nn < N; ++nn)
        warp_max = fmaxf(warp_max, maxVal[nn]);
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 1000 && defined(__CUDA_ARCH_FEAT_SM100_ALL))
    asm("redux.sync.max.f32 %0, %1, 0xffffffff;\n" : "=f"(warp_max) : "f"(warp_max));
#else
#pragma unroll
    for (int offset = 1; offset < WARP_SIZE; offset *= 2)
        warp_max = fmaxf(warp_max, __shfl_xor_sync(0xffffffff, warp_max, offset));
#endif
    float global_sum = 0.F;
    float corrected_max_exp[N];
#pragma unroll
    for (int nn = 0; nn < N; ++nn)
    {
        corrected_max_exp[nn] = sumVal[nn] * expf(maxVal[nn] - warp_max);
        global_sum += corrected_max_exp[nn];
    }
#pragma unroll
    for (int offset = 1; offset < WARP_SIZE; offset *= 2)
        global_sum += __shfl_xor_sync(0xffffffff, global_sum, offset);
    auto norm = 1.F / global_sum;
#pragma unroll
    for (int nn = 0; nn < N; ++nn)
        correctedVal[nn] = corrected_max_exp[nn] * norm;
}

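// MAX_CP_VAL_PER_THREAD: (max, sum) pairs handled per lane of the correction warp, so
// MAX_CP = WARP_SIZE * MAX_CP_VAL_PER_THREAD is the largest supported context-parallel size.
// BYTES_O_PER_THREAD: each accumulation thread loads/stores 16 bytes (one float4) per CP rank.
// NUM_PRE_LOAD: number of CP ranks whose partial outputs are kept in registers per pipeline stage.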
static constexpr int MAX_CP_VAL_PER_THREAD = 8;
static constexpr int MAX_CP = WARP_SIZE * MAX_CP_VAL_PER_THREAD;
static constexpr int BYTES_O_PER_THREAD = 16;
static constexpr int NUM_PRE_LOAD = 8;

// Kernel: fused helix post-processing
// output: [num_tokens, num_heads * kv_lora_rank] (T)
// gathered_o: [cp_size, num_tokens, num_heads * kv_lora_rank] (T)
// gathered_stats: [cp_size, num_tokens, num_heads, 2] (fp32)
// note: we explicitly avoid using restrict here, to avoid getting ld.global.nc
// which may have longer latency
template <typename T>
__global__ void helix_postprocess_kernel(
    T* output, T const* gathered_o, float2 const* gathered_stats, int cp_size, int kv_lora_rank)
{
    // Each block processes one (token, head)
    // gridDim.x: num_tokens, gridDim.y: num_heads
    // there are two separate types of warps:
    // warp 0 calculates the correction values (one per CP rank)
    // all other warps pre-load the gathered_o elements for the current token/head
    // and once warp 0 is done, all other warps can start accumulating the output
    static constexpr int NUM_O_PER_THREAD = BYTES_O_PER_THREAD / sizeof(T);

    int tok_idx = blockIdx.x;
    int head_idx = blockIdx.y;
    int num_tokens = gridDim.x;
    int num_heads = gridDim.y;

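    // cp_size rounded up to a multiple of NUM_PRE_LOAD so the pipelined loop below always
    // processes full batches; out-of-range ranks contribute zeros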
    int const cp_size_aligned = ((cp_size + NUM_PRE_LOAD - 1) / NUM_PRE_LOAD) * NUM_PRE_LOAD;
    __shared__ float smem_correction[MAX_CP];

    int lane_idx = threadIdx.x % WARP_SIZE;
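    // broadcast from lane 0 so warp_idx is uniform across the warp
    // (likely a hint that lets the compiler treat the branch below as non-divergent)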
    int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / WARP_SIZE, 0);
    // here we have to wait for memory operations of the previous kernel to complete
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
    cudaGridDependencySynchronize();
#endif

    if (warp_idx == 0)
    {
        // the warp collectively calculates the correction values
        float max_values[MAX_CP_VAL_PER_THREAD];
        float sum_values[MAX_CP_VAL_PER_THREAD];
#pragma unroll
        for (int cp_val_idx = 0; cp_val_idx < MAX_CP_VAL_PER_THREAD; ++cp_val_idx)
        {
            auto cp_idx = cp_val_idx * WARP_SIZE + lane_idx;
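            // linear offset into gathered_stats, laid out as [cp_size, num_tokens, num_heads] float2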
            auto stats_offset = cp_idx * num_tokens * num_heads + tok_idx * num_heads + head_idx;
            float2 stats = cp_idx < cp_size ? gathered_stats[stats_offset] : make_float2(-INFINITY, 0.F);
            max_values[cp_val_idx] = stats.x;
            sum_values[cp_val_idx] = stats.y;
        }
        float corrected_values[MAX_CP_VAL_PER_THREAD];
        warpReduceCorrectedSum(corrected_values, max_values, sum_values);
#pragma unroll
        for (int cp_val_idx = 0; cp_val_idx < MAX_CP_VAL_PER_THREAD; ++cp_val_idx)
        {
            auto cp_idx = cp_val_idx * WARP_SIZE + lane_idx;
            smem_correction[cp_idx] = corrected_values[cp_val_idx];
        }
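        // publish the correction weights to the accumulation warps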
        cg::this_thread_block().sync();
    }
    else
    {
        // all other warps pre-load the gathered_o elements for the current token/head
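        // gathered_o is laid out as [cp_size, num_tokens, num_heads * kv_lora_rank]; each thread
        // (excluding warp 0) owns one 16-byte slice of the current (token, head) row, and
        // gathered_16b_stride below is the distance between consecutive CP ranks in float4 units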
        auto const* gathered_o_off = gathered_o + tok_idx * num_heads * kv_lora_rank + head_idx * kv_lora_rank;
        // we subtract WARP_SIZE because the first warp is not participating here
        gathered_o_off += (threadIdx.x - WARP_SIZE) * NUM_O_PER_THREAD;
        float4 const* gathered_o_16b = reinterpret_cast<float4 const*>(gathered_o_off);
        auto gathered_16b_stride = (num_tokens * num_heads * kv_lora_rank) / NUM_O_PER_THREAD;
        T vals[NUM_PRE_LOAD][NUM_O_PER_THREAD];
#pragma unroll
        for (int cp_idx = 0; cp_idx < NUM_PRE_LOAD && cp_idx < cp_size; ++cp_idx)
        {
            auto val
                = cp_idx < cp_size ? gathered_o_16b[cp_idx * gathered_16b_stride] : make_float4(0.F, 0.F, 0.F, 0.F);
            *reinterpret_cast<float4*>(vals[cp_idx]) = val;
        }
        float final_sum[NUM_O_PER_THREAD];
#pragma unroll
        for (int o_idx = 0; o_idx < NUM_O_PER_THREAD; ++o_idx)
        {
            final_sum[o_idx] = 0.F;
        }
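        // wait until warp 0 has written the correction weights to shared memory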
        cg::this_thread_block().sync();

        // here we can trigger the dependent kernels to start
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
        cudaTriggerProgrammaticLaunchCompletion();
#endif

        float corr_vals[NUM_PRE_LOAD];
#pragma unroll
        for (int cp_idx = 0; cp_idx < NUM_PRE_LOAD && cp_idx < cp_size; ++cp_idx)
        {
            corr_vals[cp_idx] = smem_correction[cp_idx];
        }

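        // software-pipelined main loop: accumulate the NUM_PRE_LOAD ranks already in registers
        // while issuing the loads for the next NUM_PRE_LOAD ranks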
        for (int cp_idx_base = NUM_PRE_LOAD; cp_idx_base < cp_size_aligned; cp_idx_base += NUM_PRE_LOAD)
        {
#pragma unroll
            for (int cp_idx = 0; cp_idx < NUM_PRE_LOAD; ++cp_idx)
            {
#pragma unroll
                for (int o_idx = 0; o_idx < NUM_O_PER_THREAD; ++o_idx)
                {
                    final_sum[o_idx] += static_cast<float>(vals[cp_idx][o_idx]) * corr_vals[cp_idx];
                }
            }
#pragma unroll
            for (int cp_idx = 0; cp_idx < NUM_PRE_LOAD; ++cp_idx)
            {
                *reinterpret_cast<float4*>(vals[cp_idx]) = cp_idx_base + cp_idx < cp_size
                    ? gathered_o_16b[(cp_idx_base + cp_idx) * gathered_16b_stride]
                    : make_float4(0.F, 0.F, 0.F, 0.F);
                corr_vals[cp_idx] = cp_idx_base + cp_idx < cp_size ? smem_correction[cp_idx_base + cp_idx] : 0.F;
            }
        }
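        // drain the last batch of pre-loaded values (this also covers cp_size <= NUM_PRE_LOAD,
        // where the pipelined loop above does not execute)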
#pragma unroll
        for (int cp_idx = 0; cp_idx < NUM_PRE_LOAD && cp_idx < cp_size; ++cp_idx)
        {
#pragma unroll
            for (int o_idx = 0; o_idx < NUM_O_PER_THREAD; ++o_idx)
            {
                final_sum[o_idx] += static_cast<float>(vals[cp_idx][o_idx]) * corr_vals[cp_idx];
            }
        }
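        // convert the accumulated sums back to T and store one 16-byte slice per thread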
        T output_typed[NUM_O_PER_THREAD];
#pragma unroll
        for (int o_idx = 0; o_idx < NUM_O_PER_THREAD; ++o_idx)
        {
            output_typed[o_idx] = static_cast<T>(final_sum[o_idx]);
        }
        auto* output_off = output + tok_idx * num_heads * kv_lora_rank + head_idx * kv_lora_rank;
        output_off += (threadIdx.x - WARP_SIZE) * NUM_O_PER_THREAD;
        *reinterpret_cast<float4*>(output_off) = *reinterpret_cast<float4*>(output_typed);
    }
}

template <typename T>
void helixPostProcess(HelixPostProcParams<T> const& params, cudaStream_t stream)
{
    // Check that gathered_o is 16-byte aligned (the kernel uses 16-byte vectorized loads/stores)
    TLLM_CHECK_WITH_INFO(reinterpret_cast<uintptr_t>(params.gathered_o) % 16 == 0,
        "gathered_o must be 16-byte aligned for vectorized loads");
    // Check that kv_lora_rank * sizeof(T) is a multiple of 16
    TLLM_CHECK_WITH_INFO((params.kv_lora_rank * sizeof(T)) % 16 == 0,
        "kv_lora_rank * sizeof(T) must be a multiple of 16 for vectorized loads");
    // Check that cp_size does not exceed the maximum supported CP size
    TLLM_CHECK_WITH_INFO(params.cp_size <= MAX_CP, "cp_size exceeds the maximum supported CP size (MAX_CP)");

    auto* kernel_instance = &helix_postprocess_kernel<T>;
    cudaLaunchConfig_t config;
    config.gridDim = dim3(params.num_tokens, params.num_heads);
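    // one warp for the correction weights plus one thread per 16 bytes of a (token, head) output
    // row; e.g. kv_lora_rank = 512 with fp16/bf16 gives 32 + 64 = 96 threads per block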
    config.blockDim = WARP_SIZE + params.kv_lora_rank * sizeof(T) / 16;
    config.dynamicSmemBytes = 0;
    config.stream = stream;
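    // Programmatic dependent launch: when the PDL env toggle is set, this kernel may be launched
    // so that it overlaps with the producer kernel; the cudaGridDependencySynchronize() and
    // cudaTriggerProgrammaticLaunchCompletion() calls inside the kernel provide the ordering.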
    cudaLaunchAttribute attrs[1];
    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attrs[0].val.programmaticStreamSerializationAllowed = tensorrt_llm::common::getEnvEnablePDL();
    config.numAttrs = 1;
    config.attrs = attrs;
    TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, kernel_instance, params.output, params.gathered_o,
        params.gathered_stats, params.cp_size, params.kv_lora_rank));
}

#define INSTANTIATE_POST_PROC(T) \
    template void helixPostProcess<T>(HelixPostProcParams<T> const& params, cudaStream_t stream);

INSTANTIATE_POST_PROC(__half);
INSTANTIATE_POST_PROC(__nv_bfloat16);

} // namespace kernels
} // namespace tensorrt_llm