// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project

#pragma once

#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>

#include <cassert>
#include <cstdint>      // uintptr_t (alignment helpers)
#include <type_traits>  // std::is_same_v (packed conversion helpers)

#ifdef USE_ROCM
  #include <hip/hip_runtime.h>
#else
  #include <cuda_bf16.h>
  #include <cuda_fp16.h>
  #include <cuda_runtime.h>
#endif

// Device-side: SM100+ architecture with CUDA 12.9+ toolkit, which
// together enable 256-bit (v8.u32) PTX load/store instructions.
// Use for PTX instruction selection with architecture fallback paths.
#if !defined(USE_ROCM) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000 && \
    defined(CUDA_VERSION) && CUDA_VERSION >= 12090
  #define VLLM_256B_PTX_ENABLED 1
#else
  #define VLLM_256B_PTX_ENABLED 0
#endif

namespace vllm {

// ============================================================
// Types and traits
// ============================================================

// 256-bit (32-byte) aligned vector type: 8 x uint32_t.
// alignas(32) provides the 32-byte alignment that the v8.u32 PTX
// load/store helpers below require of their operands.
struct alignas(32) u32x8_t {
  uint32_t d[8];  // raw 32-bit lanes; addressed individually in the asm blocks
};

// VecTraits — select between 128-bit (int4) and 256-bit
// (u32x8_t) vector types at compile time.
// The primary template is declaration-only, so instantiating it with a
// non-specialized argument is a compile-time error.
template <bool support_256>
struct VecTraits;

// 256-bit path (SM100+ with CUDA 12.9+): 32-byte u32x8_t vectors.
template <>
struct VecTraits<true> {
  static constexpr int ARCH_MAX_VEC_SIZE = 32;  // vector width in bytes
  using vec_t = u32x8_t;
};

// 128-bit fallback path: 16-byte int4 vectors.
template <>
struct VecTraits<false> {
  static constexpr int ARCH_MAX_VEC_SIZE = 16;  // vector width in bytes
  using vec_t = int4;
};

// PackedTypeConverter — map between CUDA scalar and packed types:
// half <-> half2, __nv_bfloat16 <-> __nv_bfloat162, float <-> float2.
// The mapping is bidirectional for native CUDA types; the PyTorch
// wrapper types (c10::Half, c10::BFloat16) map one way only, to their
// packed CUDA counterpart.
template <typename T>
struct PackedTypeConverter {
  // Dependent-false assert: only fires if this primary template is
  // actually instantiated with an unsupported type.
  static_assert(sizeof(T) == 0,
                "PackedTypeConverter is not specialized for this type.");
};

// packed -> scalar
template <>
struct PackedTypeConverter<half2> {
  using Type = half;
};

// scalar -> packed
template <>
struct PackedTypeConverter<half> {
  using Type = half2;
};

// packed -> scalar
template <>
struct PackedTypeConverter<__nv_bfloat162> {
  using Type = __nv_bfloat16;
};

// scalar -> packed
template <>
struct PackedTypeConverter<__nv_bfloat16> {
  using Type = __nv_bfloat162;
};

// scalar -> packed
template <>
struct PackedTypeConverter<float> {
  using Type = float2;
};

// packed -> scalar
template <>
struct PackedTypeConverter<float2> {
  using Type = float;
};

// PyTorch half scalar -> packed CUDA type
template <>
struct PackedTypeConverter<c10::Half> {
  using Type = half2;
};

// PyTorch bfloat16 scalar -> packed CUDA type
template <>
struct PackedTypeConverter<c10::BFloat16> {
  using Type = __nv_bfloat162;
};

// CUDATypeConverter — map PyTorch scalar types to CUDA scalar types:
// c10::Half -> half, c10::BFloat16 -> __nv_bfloat16.
// The primary template is the identity mapping, so native CUDA types
// (float, half, __nv_bfloat16, ...) pass through unchanged.
template <typename T>
struct CUDATypeConverter {
  using Type = T;
};

template <>
struct CUDATypeConverter<c10::Half> {
  using Type = half;
};

template <>
struct CUDATypeConverter<c10::BFloat16> {
  using Type = __nv_bfloat16;
};

// PackedVec — typed vector container for packed element access.
// Derives alignment and element count from VecTraits.
// Type is the CUDA scalar type (e.g. half, __nv_bfloat16); elements are
// stored in packed (x2) form via PackedTypeConverter, so e.g.
// PackedVec<half, true> is 32 bytes holding 8 half2 values (16 scalars),
// and PackedVec<half, false> is 16 bytes holding 4 half2 values.
// NOTE(review): instantiating with an already-packed Type (e.g. half2)
// would store scalars instead — presumably only scalar Types are used;
// confirm at call sites.
template <class Type, bool use_256b>
struct alignas(VecTraits<use_256b>::ARCH_MAX_VEC_SIZE) PackedVec {
  // Element count = vector width in bytes / size of one packed element.
  static constexpr int NUM_ELTS =
      VecTraits<use_256b>::ARCH_MAX_VEC_SIZE /
      sizeof(typename PackedTypeConverter<Type>::Type);
  typename PackedTypeConverter<Type>::Type elts[NUM_ELTS];
};

// ============================================================
// Load / store primitives
// ============================================================

// 256-bit load — SM100+ only (PTX v8 instructions).
// Loads 32 bytes from global memory into val. Uses ld.global.nc (the
// non-coherent / read-only path), so the pointed-to data must not be
// written by this kernel while it may still be read through this helper.
// ptr must be 32-byte aligned.
// On builds without 256-bit PTX support this traps via assert (which is
// compiled out under NDEBUG, leaving val unmodified).
__device__ __forceinline__ void ld256(u32x8_t& val, const u32x8_t* ptr) {
#if VLLM_256B_PTX_ENABLED
  asm volatile("ld.global.nc.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%8];\n"
               : "=r"(val.d[0]), "=r"(val.d[1]), "=r"(val.d[2]), "=r"(val.d[3]),
                 "=r"(val.d[4]), "=r"(val.d[5]), "=r"(val.d[6]), "=r"(val.d[7])
               : "l"(ptr));
#else
  assert(false && "ld256 requires SM100+ with CUDA 12.9+");
#endif
}

// 256-bit store — SM100+ only (PTX v8 instructions).
// Stores the 32 bytes of val to global memory at ptr (32-byte aligned).
// The "memory" clobber keeps the compiler from reordering surrounding
// memory accesses across the store.
// NOTE(review): val is never written, so it could be `const u32x8_t&`.
__device__ __forceinline__ void st256(u32x8_t& val, u32x8_t* ptr) {
#if VLLM_256B_PTX_ENABLED
  asm volatile("st.global.v8.u32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};\n"
               :
               : "l"(ptr), "r"(val.d[0]), "r"(val.d[1]), "r"(val.d[2]),
                 "r"(val.d[3]), "r"(val.d[4]), "r"(val.d[5]), "r"(val.d[6]),
                 "r"(val.d[7])
               : "memory");
#else
  assert(false && "st256 requires SM100+ with CUDA 12.9+");
#endif
}

// Generic 256-bit load for any 32-byte type (e.g. PackedVec): views the
// destination and source as u32x8_t and defers to the non-template
// ld256 overload above. The non-template overload is preferred when the
// operands are already u32x8_t.
template <typename T>
__device__ __forceinline__ void ld256(T& val, const T* ptr) {
  static_assert(sizeof(T) == 32, "ld256 requires a 32-byte type");
  u32x8_t& raw_dst = reinterpret_cast<u32x8_t&>(val);
  const u32x8_t* raw_src = reinterpret_cast<const u32x8_t*>(ptr);
  ld256(raw_dst, raw_src);
}

// Generic 256-bit store for any 32-byte type: views the value and the
// destination as u32x8_t and defers to the non-template st256 overload.
template <typename T>
__device__ __forceinline__ void st256(T& val, T* ptr) {
  static_assert(sizeof(T) == 32, "st256 requires a 32-byte type");
  u32x8_t& raw_val = reinterpret_cast<u32x8_t&>(val);
  u32x8_t* raw_dst = reinterpret_cast<u32x8_t*>(ptr);
  st256(raw_val, raw_dst);
}

// 128-bit load of any 16-byte type, routed through __ldg so the load is
// issued with the read-only cache hint.
template <typename T>
__device__ __forceinline__ void ld128(T& val, const T* ptr) {
  static_assert(sizeof(T) == 16, "ld128 requires a 16-byte type");
  const int4* src = reinterpret_cast<const int4*>(ptr);
  int4* dst = reinterpret_cast<int4*>(&val);
  *dst = __ldg(src);
}

// 128-bit store of any 16-byte type as a single int4 copy.
template <typename T>
__device__ __forceinline__ void st128(T& val, T* ptr) {
  static_assert(sizeof(T) == 16, "st128 requires a 16-byte type");
  int4* dst = reinterpret_cast<int4*>(ptr);
  const int4 bits = *reinterpret_cast<int4*>(&val);
  *dst = bits;
}

// 256-bit cache-streaming (.cs) load — SM100+ only.
// The .cs cache operator marks the data as streaming (evict-first),
// for data that is expected to be touched once.
// addr must be 32-byte aligned. Returns zero-initialized {} on
// unsupported builds after the (NDEBUG-removable) assert.
__forceinline__ __device__ u32x8_t ld256_cs(const u32x8_t* addr) {
#if VLLM_256B_PTX_ENABLED
  u32x8_t val;
  asm volatile("ld.global.cs.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%8];"
               : "=r"(val.d[0]), "=r"(val.d[1]), "=r"(val.d[2]), "=r"(val.d[3]),
                 "=r"(val.d[4]), "=r"(val.d[5]), "=r"(val.d[6]), "=r"(val.d[7])
               : "l"(addr));
  return val;
#else
  assert(false && "ld256_cs requires SM100+ with CUDA 12.9+");
  return {};
#endif
}

// 256-bit cache-streaming (.cs) store — SM100+ only.
// Stores val to 32-byte-aligned addr with the streaming (evict-first)
// cache hint.
// NOTE(review): unlike st256 above, this asm has no "memory" clobber,
// so the compiler may reorder it relative to other memory accesses —
// presumably acceptable at current call sites; confirm.
__forceinline__ __device__ void st256_cs(u32x8_t* addr, u32x8_t val) {
#if VLLM_256B_PTX_ENABLED
  asm volatile(
      "st.global.cs.v8.u32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};" ::"l"(addr),
      "r"(val.d[0]), "r"(val.d[1]), "r"(val.d[2]), "r"(val.d[3]), "r"(val.d[4]),
      "r"(val.d[5]), "r"(val.d[6]), "r"(val.d[7]));
#else
  assert(false && "st256_cs requires SM100+ with CUDA 12.9+");
#endif
}

// 32-bit cache-streaming (.cs) load — SM100+ only.
// NOTE(review): ld.global.cs.b32 itself does not require SM100; the
// gate on VLLM_256B_PTX_ENABLED presumably just keeps these helpers on
// the same code path as the 256-bit ones — confirm intent.
__forceinline__ __device__ int ld32_cs(const int* addr) {
#if VLLM_256B_PTX_ENABLED
  int val;
  asm volatile("ld.global.cs.b32 %0, [%1];" : "=r"(val) : "l"(addr));
  return val;
#else
  assert(false && "ld32_cs requires SM100+ with CUDA 12.9+");
  return 0;  // deterministic fallback when the assert is compiled out
#endif
}

// 32-bit cache-streaming (.cs) store — SM100+ only (see note on
// ld32_cs about the gating).
__forceinline__ __device__ void st32_cs(int* addr, int val) {
#if VLLM_256B_PTX_ENABLED
  asm volatile("st.global.cs.b32 [%0], %1;" ::"l"(addr), "r"(val));
#else
  assert(false && "st32_cs requires SM100+ with CUDA 12.9+");
#endif
}

// Predicated 256-bit cache-global (.cg) load — SM100+ only.
// If pred is true, loads 32 bytes from ptr (must be 32-byte aligned)
// bypassing L1 (.cg caches at L2); if pred is false, val is set to all
// zeros and no memory access is made. The eight mov.u32 instructions
// zero the destination registers before the @pr-guarded load so the
// outputs are defined on both paths.
__device__ __forceinline__ void ld256_cg_or_zero(u32x8_t& val, const void* ptr,
                                                 bool pred) {
#if VLLM_256B_PTX_ENABLED
  asm volatile(
      "{\n"
      " .reg .pred pr;\n"
      " setp.ne.u32 pr, %8, 0;\n"  // pr = (pred != 0); %8 is (int)pred
      " mov.u32 %0, 0;\n"
      " mov.u32 %1, 0;\n"
      " mov.u32 %2, 0;\n"
      " mov.u32 %3, 0;\n"
      " mov.u32 %4, 0;\n"
      " mov.u32 %5, 0;\n"
      " mov.u32 %6, 0;\n"
      " mov.u32 %7, 0;\n"
      " @pr ld.global.cg.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%9];\n"
      "}\n"
      : "=r"(val.d[0]), "=r"(val.d[1]), "=r"(val.d[2]), "=r"(val.d[3]),
        "=r"(val.d[4]), "=r"(val.d[5]), "=r"(val.d[6]), "=r"(val.d[7])
      : "r"((int)pred), "l"(ptr));
#else
  assert(false && "ld256_cg_or_zero requires SM100+ with CUDA 12.9+");
#endif
}

// Predicated 128-bit cache-global (.cg) load; returns zero if pred is
// false. Same structure as ld256_cg_or_zero but for a uint4 (16 bytes,
// 16-byte-aligned ptr).
// NOTE(review): ld.global.cg.v4.u32 does not itself require SM100; the
// VLLM_256B_PTX_ENABLED gate presumably keeps this on the same
// SM100-only call path — confirm intent.
__device__ __forceinline__ void ld128_cg_or_zero(uint4& val, const void* ptr,
                                                 bool pred) {
#if VLLM_256B_PTX_ENABLED
  uint32_t r0, r1, r2, r3;

  asm volatile(
      "{\n"
      " .reg .pred pr;\n"
      " setp.ne.u32 pr, %4, 0;\n"  // pr = (pred != 0); %4 is (int)pred
      " mov.u32 %0, 0;\n"          // zero outputs so the pred=false path
      " mov.u32 %1, 0;\n"          // yields a defined value
      " mov.u32 %2, 0;\n"
      " mov.u32 %3, 0;\n"
      " @pr ld.global.cg.v4.u32 {%0,%1,%2,%3}, [%5];\n"
      "}\n"
      : "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3)
      : "r"((int)pred), "l"(ptr));

  val = uint4{r0, r1, r2, r3};
#else
  assert(false && "ld128_cg_or_zero requires SM100+ with CUDA 12.9+");
#endif
}

// ============================================================
// Alignment helpers
// ============================================================

// True iff ptr lies on a 16-byte boundary.
__host__ __device__ __forceinline__ bool is_16byte_aligned(const void* ptr) {
  return reinterpret_cast<uintptr_t>(ptr) % 16 == 0;
}

// True iff ptr lies on a 32-byte boundary.
__host__ __device__ __forceinline__ bool is_32byte_aligned(const void* ptr) {
  return reinterpret_cast<uintptr_t>(ptr) % 32 == 0;
}

// ============================================================
// Packed type conversion and arithmetic
// ============================================================

// Widen a packed pair to float2.
// Supported: __nv_bfloat162, __half2, float2 (identity copy).
// Fix: the original constexpr-if chain had no final else, so an
// unsupported packed_t instantiation fell off the end of a non-void
// function (undefined behavior, warning only). A dependent-false
// static_assert — the same trick as the PackedTypeConverter primary
// template — now turns that into a clear compile-time error.
template <typename packed_t>
__device__ __forceinline__ float2 cast_to_float2(const packed_t& val) {
  if constexpr (std::is_same_v<packed_t, __nv_bfloat162>) {
    return __bfloat1622float2(val);
  } else if constexpr (std::is_same_v<packed_t, __half2>) {
    return __half22float2(val);
  } else if constexpr (std::is_same_v<packed_t, float2>) {
    return val;
  } else {
    static_assert(sizeof(packed_t) == 0,
                  "cast_to_float2 is not supported for this type.");
  }
}

// Narrow a float2 to the requested packed type with round-to-nearest.
// Supported: __nv_bfloat162, __half2, float2 (identity copy).
// Fix: added a dependent-false static_assert else-branch; previously an
// unsupported packed_t instantiation fell off the end of a non-void
// function (undefined behavior).
template <typename packed_t>
__device__ __forceinline__ packed_t cast_to_packed(const float2& val) {
  if constexpr (std::is_same_v<packed_t, __nv_bfloat162>) {
    return __float22bfloat162_rn(val);
  } else if constexpr (std::is_same_v<packed_t, __half2>) {
    return __float22half2_rn(val);
  } else if constexpr (std::is_same_v<packed_t, float2>) {
    return val;
  } else {
    static_assert(sizeof(packed_t) == 0,
                  "cast_to_packed is not supported for this type.");
  }
}

// Elementwise product of two packed values.
// bf16x2 / half2 use the hardware packed multiply __hmul2; float2
// multiplies the two lanes individually.
// Fix: added a dependent-false static_assert else-branch; previously an
// unsupported packed_t instantiation fell off the end of a non-void
// function (undefined behavior).
template <typename packed_t>
__device__ __forceinline__ packed_t packed_mul(const packed_t& x,
                                               const packed_t& y) {
  if constexpr (std::is_same_v<packed_t, __nv_bfloat162> ||
                std::is_same_v<packed_t, __half2>) {
    return __hmul2(x, y);
  } else if constexpr (std::is_same_v<packed_t, float2>) {
    return make_float2(x.x * y.x, x.y * y.y);
  } else {
    static_assert(sizeof(packed_t) == 0,
                  "packed_mul is not supported for this type.");
  }
}

}  // namespace vllm