Skip to content

Commit a201ad7

Browse files
[Refactor][Kernel] Add global helper to deduplicate vectorized memory ops (vllm-project#35105)
Signed-off-by: LopezCastroRoberto <rocastro@redhat.com> Signed-off-by: LopezCastroRoberto <roberto.lopez.castro@udc.es> Signed-off-by: Roberto L. Castro <38211239+LopezCastroRoberto@users.noreply.github.com>
1 parent e369198 commit a201ad7

File tree

6 files changed

+474
-372
lines changed

6 files changed

+474
-372
lines changed

csrc/activation_kernels.cu

Lines changed: 90 additions & 205 deletions
Large diffs are not rendered by default.

csrc/cuda_vec_utils.cuh

Lines changed: 334 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,334 @@
1+
// SPDX-License-Identifier: Apache-2.0
2+
// SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3+
4+
#pragma once
5+
6+
#include <c10/util/BFloat16.h>
7+
#include <c10/util/Half.h>
8+
#include <cassert>
9+
10+
#ifdef USE_ROCM
11+
#include <hip/hip_runtime.h>
12+
#else
13+
#include <cuda_bf16.h>
14+
#include <cuda_fp16.h>
15+
#include <cuda_runtime.h>
16+
#endif
17+
18+
// Device-side: SM100+ architecture with CUDA 12.9+ toolkit, which
19+
// together enable 256-bit (v8.u32) PTX load/store instructions.
20+
// Use for PTX instruction selection with architecture fallback paths.
21+
#if !defined(USE_ROCM) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000 && \
22+
defined(CUDA_VERSION) && CUDA_VERSION >= 12090
23+
#define VLLM_256B_PTX_ENABLED 1
24+
#else
25+
#define VLLM_256B_PTX_ENABLED 0
26+
#endif
27+
28+
namespace vllm {
29+
30+
// ============================================================
31+
// Types and traits
32+
// ============================================================
33+
34+
// Vector of eight 32-bit words, aligned to 32 bytes so a single 256-bit
// (v8.u32) PTX load/store can move the whole struct at once.
struct alignas(32) u32x8_t {
  uint32_t d[8];
};
38+
39+
// VecTraits: compile-time choice of the widest supported vector type.
// support_256 == true  -> u32x8_t (256-bit transactions, SM100+ path)
// support_256 == false -> int4    (128-bit transactions, fallback path)
template <bool support_256>
struct VecTraits;

template <>
struct VecTraits<true> {
  // Maximum vector width in bytes for this configuration.
  static constexpr int ARCH_MAX_VEC_SIZE = 32;
  using vec_t = u32x8_t;
};

template <>
struct VecTraits<false> {
  // Maximum vector width in bytes for this configuration.
  static constexpr int ARCH_MAX_VEC_SIZE = 16;
  using vec_t = int4;
};
55+
56+
// PackedTypeConverter: map a CUDA scalar type to its 2-element packed
// counterpart and back (half <-> half2, __nv_bfloat16 <-> __nv_bfloat162,
// float <-> float2). PyTorch scalar wrappers (c10::Half, c10::BFloat16)
// map one way only, to the packed CUDA type. The unspecialized primary
// template fails to compile, so unsupported types are rejected early.
template <typename T>
struct PackedTypeConverter {
  static_assert(sizeof(T) == 0,
                "PackedTypeConverter is not specialized for this type.");
};

// half <-> half2
template <>
struct PackedTypeConverter<half> {
  using Type = half2;
};

template <>
struct PackedTypeConverter<half2> {
  using Type = half;
};

// __nv_bfloat16 <-> __nv_bfloat162
template <>
struct PackedTypeConverter<__nv_bfloat16> {
  using Type = __nv_bfloat162;
};

template <>
struct PackedTypeConverter<__nv_bfloat162> {
  using Type = __nv_bfloat16;
};

// float <-> float2
template <>
struct PackedTypeConverter<float> {
  using Type = float2;
};

template <>
struct PackedTypeConverter<float2> {
  using Type = float;
};

// PyTorch scalar wrappers -> packed CUDA types (one-way mappings).
template <>
struct PackedTypeConverter<c10::Half> {
  using Type = half2;
};

template <>
struct PackedTypeConverter<c10::BFloat16> {
  using Type = __nv_bfloat162;
};
103+
104+
// CUDATypeConverter: translate PyTorch scalar wrappers into the native
// CUDA scalar types (c10::Half -> half, c10::BFloat16 -> __nv_bfloat16).
// Every other type passes through unchanged via the primary template.
template <typename T>
struct CUDATypeConverter {
  using Type = T;
};

template <>
struct CUDATypeConverter<c10::Half> {
  using Type = half;
};

template <>
struct CUDATypeConverter<c10::BFloat16> {
  using Type = __nv_bfloat16;
};
120+
121+
// PackedVec: a typed container holding one full vector transaction worth of
// packed elements. Size and alignment come from VecTraits; the element type
// is the packed counterpart of the CUDA scalar `Type` (e.g. half -> half2),
// so one PackedVec fills one 128-bit or 256-bit memory access.
template <class Type, bool use_256b>
struct alignas(VecTraits<use_256b>::ARCH_MAX_VEC_SIZE) PackedVec {
  // Packed element type, e.g. half2 for Type == half.
  using packed_t = typename PackedTypeConverter<Type>::Type;
  // Number of packed elements that fit in one vector transaction.
  static constexpr int NUM_ELTS =
      VecTraits<use_256b>::ARCH_MAX_VEC_SIZE / sizeof(packed_t);
  packed_t elts[NUM_ELTS];
};
131+
132+
// ============================================================
133+
// Load / store primitives
134+
// ============================================================
135+
136+
// 256-bit global load through the read-only path (ld.global.nc.v8.u32).
// Valid only on SM100+ with CUDA 12.9+; on other targets this traps via
// assert in debug builds and must not be reached.
__device__ __forceinline__ void ld256(u32x8_t& val, const u32x8_t* ptr) {
#if VLLM_256B_PTX_ENABLED
  asm volatile("ld.global.nc.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%8];\n"
               : "=r"(val.d[0]), "=r"(val.d[1]), "=r"(val.d[2]),
                 "=r"(val.d[3]), "=r"(val.d[4]), "=r"(val.d[5]),
                 "=r"(val.d[6]), "=r"(val.d[7])
               : "l"(ptr));
#else
  assert(false && "ld256 requires SM100+ with CUDA 12.9+");
#endif
}
147+
148+
// 256-bit global store (st.global.v8.u32) — SM100+ with CUDA 12.9+ only.
// The "memory" clobber keeps the compiler from reordering surrounding
// memory accesses across the store.
__device__ __forceinline__ void st256(u32x8_t& val, u32x8_t* ptr) {
#if VLLM_256B_PTX_ENABLED
  asm volatile("st.global.v8.u32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};\n"
               :
               : "l"(ptr), "r"(val.d[0]), "r"(val.d[1]), "r"(val.d[2]),
                 "r"(val.d[3]), "r"(val.d[4]), "r"(val.d[5]), "r"(val.d[6]),
                 "r"(val.d[7])
               : "memory");
#else
  assert(false && "st256 requires SM100+ with CUDA 12.9+");
#endif
}
160+
161+
// Generic 256-bit load / store for any 32-byte type (e.g. PackedVec).
// These forward to the u32x8_t overloads; for u32x8_t itself, overload
// resolution still prefers the non-template versions above.
template <typename T>
__device__ __forceinline__ void ld256(T& val, const T* ptr) {
  static_assert(sizeof(T) == 32, "ld256 requires a 32-byte type");
  ld256(reinterpret_cast<u32x8_t&>(val), reinterpret_cast<const u32x8_t*>(ptr));
}

template <typename T>
__device__ __forceinline__ void st256(T& val, T* ptr) {
  static_assert(sizeof(T) == 32, "st256 requires a 32-byte type");
  st256(reinterpret_cast<u32x8_t&>(val), reinterpret_cast<u32x8_t*>(ptr));
}
174+
175+
// 128-bit load / store for any 16-byte type. Loads go through __ldg to
// hint the read-only data cache; stores are plain 128-bit writes.
template <typename T>
__device__ __forceinline__ void ld128(T& val, const T* ptr) {
  static_assert(sizeof(T) == 16, "ld128 requires a 16-byte type");
  *reinterpret_cast<int4*>(&val) = __ldg(reinterpret_cast<const int4*>(ptr));
}

template <typename T>
__device__ __forceinline__ void st128(T& val, T* ptr) {
  static_assert(sizeof(T) == 16, "st128 requires a 16-byte type");
  *reinterpret_cast<int4*>(ptr) = *reinterpret_cast<int4*>(&val);
}
187+
188+
// 256-bit cache-streaming (.cs, evict-first) load — SM100+ with
// CUDA 12.9+ only. Use for data that will not be re-read soon.
__forceinline__ __device__ u32x8_t ld256_cs(const u32x8_t* addr) {
#if VLLM_256B_PTX_ENABLED
  u32x8_t out;
  asm volatile("ld.global.cs.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%8];"
               : "=r"(out.d[0]), "=r"(out.d[1]), "=r"(out.d[2]),
                 "=r"(out.d[3]), "=r"(out.d[4]), "=r"(out.d[5]),
                 "=r"(out.d[6]), "=r"(out.d[7])
               : "l"(addr));
  return out;
#else
  assert(false && "ld256_cs requires SM100+ with CUDA 12.9+");
  return {};
#endif
}
202+
203+
// 256-bit cache-streaming (.cs, evict-first) store — SM100+ with
// CUDA 12.9+ only.
__forceinline__ __device__ void st256_cs(u32x8_t* addr, u32x8_t val) {
#if VLLM_256B_PTX_ENABLED
  // "memory" clobber: without it the compiler may reorder or elide other
  // memory accesses around this store (st256 above already declares it).
  asm volatile("st.global.cs.v8.u32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};"
               :
               : "l"(addr), "r"(val.d[0]), "r"(val.d[1]), "r"(val.d[2]),
                 "r"(val.d[3]), "r"(val.d[4]), "r"(val.d[5]), "r"(val.d[6]),
                 "r"(val.d[7])
               : "memory");
#else
  assert(false && "st256_cs requires SM100+ with CUDA 12.9+");
#endif
}
213+
214+
// 32-bit cache-streaming (.cs) load — guarded by the same SM100+ /
// CUDA 12.9+ macro as the 256-bit paths.
__forceinline__ __device__ int ld32_cs(const int* addr) {
#if VLLM_256B_PTX_ENABLED
  int out;
  asm volatile("ld.global.cs.b32 %0, [%1];" : "=r"(out) : "l"(addr));
  return out;
#else
  assert(false && "ld32_cs requires SM100+ with CUDA 12.9+");
  return 0;
#endif
}
225+
226+
// 32-bit cache-streaming (.cs) store — SM100+ with CUDA 12.9+ only.
__forceinline__ __device__ void st32_cs(int* addr, int val) {
#if VLLM_256B_PTX_ENABLED
  // "memory" clobber added for correctness/consistency with st256: a store
  // asm without it lets the compiler reorder surrounding memory accesses
  // across the store.
  asm volatile("st.global.cs.b32 [%0], %1;" ::"l"(addr), "r"(val) : "memory");
#else
  assert(false && "st32_cs requires SM100+ with CUDA 12.9+");
#endif
}
233+
234+
// Predicated 256-bit cache-global (.cg) load. When `pred` is false the
// destination is zero-filled and memory is not touched. SM100+ with
// CUDA 12.9+ only.
__device__ __forceinline__ void ld256_cg_or_zero(u32x8_t& val, const void* ptr,
                                                 bool pred) {
#if VLLM_256B_PTX_ENABLED
  // Zero all eight result registers unconditionally, then overwrite them
  // with the load only when the predicate register is set.
  asm volatile(
      "{\n"
      " .reg .pred pr;\n"
      " setp.ne.u32 pr, %8, 0;\n"
      " mov.u32 %0, 0;\n"
      " mov.u32 %1, 0;\n"
      " mov.u32 %2, 0;\n"
      " mov.u32 %3, 0;\n"
      " mov.u32 %4, 0;\n"
      " mov.u32 %5, 0;\n"
      " mov.u32 %6, 0;\n"
      " mov.u32 %7, 0;\n"
      " @pr ld.global.cg.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%9];\n"
      "}\n"
      : "=r"(val.d[0]), "=r"(val.d[1]), "=r"(val.d[2]), "=r"(val.d[3]),
        "=r"(val.d[4]), "=r"(val.d[5]), "=r"(val.d[6]), "=r"(val.d[7])
      : "r"((int)pred), "l"(ptr));
#else
  assert(false && "ld256_cg_or_zero requires SM100+ with CUDA 12.9+");
#endif
}
260+
261+
// Predicated 128-bit cache-global (.cg) load. When `pred` is false the
// destination is zero-filled and memory is not touched. SM100+ with
// CUDA 12.9+ only.
__device__ __forceinline__ void ld128_cg_or_zero(uint4& val, const void* ptr,
                                                 bool pred) {
#if VLLM_256B_PTX_ENABLED
  // Write straight into the uint4 components; no temporaries needed.
  asm volatile(
      "{\n"
      " .reg .pred pr;\n"
      " setp.ne.u32 pr, %4, 0;\n"
      " mov.u32 %0, 0;\n"
      " mov.u32 %1, 0;\n"
      " mov.u32 %2, 0;\n"
      " mov.u32 %3, 0;\n"
      " @pr ld.global.cg.v4.u32 {%0,%1,%2,%3}, [%5];\n"
      "}\n"
      : "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w)
      : "r"((int)pred), "l"(ptr));
#else
  assert(false && "ld128_cg_or_zero requires SM100+ with CUDA 12.9+");
#endif
}
284+
285+
// ============================================================
286+
// Alignment helpers
287+
// ============================================================
288+
289+
// True when `ptr` sits on a 16-byte boundary (safe for 128-bit accesses).
__host__ __device__ __forceinline__ bool is_16byte_aligned(const void* ptr) {
  return (reinterpret_cast<uintptr_t>(ptr) % 16) == 0;
}

// True when `ptr` sits on a 32-byte boundary (safe for 256-bit accesses).
__host__ __device__ __forceinline__ bool is_32byte_aligned(const void* ptr) {
  return (reinterpret_cast<uintptr_t>(ptr) % 32) == 0;
}
296+
297+
// ============================================================
298+
// Packed type conversion and arithmetic
299+
// ============================================================
300+
301+
// Widen a packed 2-element value to float2.
// Supported packed_t: __nv_bfloat162, __half2, float2. Any other type is
// now rejected at compile time; the original `if constexpr` chain had no
// final else, so an unsupported type fell off the end of a non-void
// function with no return value (undefined behavior).
template <typename packed_t>
__device__ __forceinline__ float2 cast_to_float2(const packed_t& val) {
  if constexpr (std::is_same_v<packed_t, __nv_bfloat162>) {
    return __bfloat1622float2(val);
  } else if constexpr (std::is_same_v<packed_t, __half2>) {
    return __half22float2(val);
  } else if constexpr (std::is_same_v<packed_t, float2>) {
    return val;
  } else {
    // Dependent-false static_assert: fires only when instantiated with an
    // unsupported type.
    static_assert(sizeof(packed_t) == 0,
                  "cast_to_float2 is not supported for this type.");
  }
}
311+
312+
// Narrow a float2 to the packed 2-element type `packed_t` (round-to-nearest
// for the half/bfloat16 conversions).
// Supported packed_t: __nv_bfloat162, __half2, float2. Any other type is
// now rejected at compile time; the original `if constexpr` chain had no
// final else, so an unsupported type fell off the end of a non-void
// function with no return value (undefined behavior).
template <typename packed_t>
__device__ __forceinline__ packed_t cast_to_packed(const float2& val) {
  if constexpr (std::is_same_v<packed_t, __nv_bfloat162>) {
    return __float22bfloat162_rn(val);
  } else if constexpr (std::is_same_v<packed_t, __half2>) {
    return __float22half2_rn(val);
  } else if constexpr (std::is_same_v<packed_t, float2>) {
    return val;
  } else {
    // Dependent-false static_assert: fires only when instantiated with an
    // unsupported type.
    static_assert(sizeof(packed_t) == 0,
                  "cast_to_packed is not supported for this type.");
  }
}
322+
323+
// Element-wise product of two packed 2-element values (__hmul2 for the
// half/bfloat16 pairs, plain float math for float2).
// Supported packed_t: __nv_bfloat162, __half2, float2. Any other type is
// now rejected at compile time; the original `if constexpr` chain had no
// final else, so an unsupported type fell off the end of a non-void
// function with no return value (undefined behavior).
template <typename packed_t>
__device__ __forceinline__ packed_t packed_mul(const packed_t& x,
                                               const packed_t& y) {
  if constexpr (std::is_same_v<packed_t, __nv_bfloat162> ||
                std::is_same_v<packed_t, __half2>) {
    return __hmul2(x, y);
  } else if constexpr (std::is_same_v<packed_t, float2>) {
    return make_float2(x.x * y.x, x.y * y.y);
  } else {
    // Dependent-false static_assert: fires only when instantiated with an
    // unsupported type.
    static_assert(sizeof(packed_t) == 0,
                  "packed_mul is not supported for this type.");
  }
}
333+
334+
} // namespace vllm

0 commit comments

Comments
 (0)