#define __EMBEDDING_CUDA_KERNEL_CUH__

#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
+ #include <cuda_fp16.h>
#include <cuda_runtime.h>
+ #include <type_traits>

namespace op::embedding::nvidia {

+ // Helper function to check memory alignment
+ __forceinline__ __device__ bool is_aligned(const void *ptr, size_t alignment) {
+     // Use size_t for pointer arithmetic in device code (more compatible)
+     return (reinterpret_cast<size_t>(ptr) % alignment == 0);
+ }
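+ // Note: cudaMalloc allocations are aligned to at least 256 bytes, so this check
+ // matters mainly for pointers that are offsets into a larger allocation.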
+
+ // Vectorized copy for float type using float4
+ template <typename IndexType>
+ __forceinline__ __device__ void copyVectorizedFloat4(
+     float *__restrict__ dst,
+     const float *__restrict__ src,
+     size_t embedding_dim) {
+     // Use float4 for vectorized access (16 bytes, 4 floats)
+     const float4 *src_vec = reinterpret_cast<const float4 *>(src);
+     float4 *dst_vec = reinterpret_cast<float4 *>(dst);
+     size_t vec_count = embedding_dim / 4;
+
+     // Vectorized copy using __ldg for read-only weight
+     for (size_t i = 0; i < vec_count; ++i) {
+         dst_vec[i] = __ldg(&src_vec[i]);
+     }
+
+     // Copy remaining elements
+     size_t remaining = embedding_dim % 4;
+     if (remaining > 0) {
+         size_t offset = vec_count * 4;
+         for (size_t i = 0; i < remaining; ++i) {
+             dst[offset + i] = __ldg(&src[offset + i]);
+         }
+     }
+ }
+
+ // Vectorized copy for float type using float2 (fallback when not aligned to 16 bytes)
+ template <typename IndexType>
+ __forceinline__ __device__ void copyVectorizedFloat2(
+     float *__restrict__ dst,
+     const float *__restrict__ src,
+     size_t embedding_dim) {
+     // Use float2 for vectorized access (8 bytes, 2 floats)
+     const float2 *src_vec = reinterpret_cast<const float2 *>(src);
+     float2 *dst_vec = reinterpret_cast<float2 *>(dst);
+     size_t vec_count = embedding_dim / 2;
+
+     // Vectorized copy using __ldg for read-only weight
+     for (size_t i = 0; i < vec_count; ++i) {
+         dst_vec[i] = __ldg(&src_vec[i]);
+     }
+
+     // Copy remaining element if odd
+     if (embedding_dim % 2 != 0) {
+         dst[embedding_dim - 1] = __ldg(&src[embedding_dim - 1]);
+     }
+ }
+
+ // Vectorized copy for half type using half2
+ template <typename IndexType>
+ __forceinline__ __device__ void copyVectorizedHalf2(
+     half *__restrict__ dst,
+     const half *__restrict__ src,
+     size_t embedding_dim) {
+     // Use half2 for vectorized access (4 bytes, 2 halves)
+     const half2 *src_vec = reinterpret_cast<const half2 *>(src);
+     half2 *dst_vec = reinterpret_cast<half2 *>(dst);
+     size_t vec_count = embedding_dim / 2;
+
+     // Vectorized copy using __ldg for read-only weight
+     for (size_t i = 0; i < vec_count; ++i) {
+         dst_vec[i] = __ldg(&src_vec[i]);
+     }
+
+     // Copy remaining element if odd
+     if (embedding_dim % 2 != 0) {
+         dst[embedding_dim - 1] = __ldg(&src[embedding_dim - 1]);
+     }
+ }
+
+ // Vectorized copy for bfloat16 type using bfloat162
+ template <typename IndexType>
+ __forceinline__ __device__ void copyVectorizedBFloat162(
+     cuda_bfloat16 *__restrict__ dst,
+     const cuda_bfloat16 *__restrict__ src,
+     size_t embedding_dim) {
+     // Use bfloat162 for vectorized access (4 bytes, 2 bfloat16s)
+     const cuda_bfloat162 *src_vec = reinterpret_cast<const cuda_bfloat162 *>(src);
+     cuda_bfloat162 *dst_vec = reinterpret_cast<cuda_bfloat162 *>(dst);
+     size_t vec_count = embedding_dim / 2;
+
+     // Vectorized copy using __ldg for read-only weight
+     for (size_t i = 0; i < vec_count; ++i) {
+         dst_vec[i] = __ldg(&src_vec[i]);
+     }
+
+     // Copy remaining element if odd
+     if (embedding_dim % 2 != 0) {
+         dst[embedding_dim - 1] = __ldg(&src[embedding_dim - 1]);
+     }
+ }
+
+ // Scalar copy fallback with __ldg optimization
+ template <typename T, typename IndexType>
+ __forceinline__ __device__ void copyScalar(
+     T *__restrict__ dst,
+     const T *__restrict__ src,
+     size_t embedding_dim) {
+     // Scalar copy with __ldg for read-only weight
+     for (size_t i = 0; i < embedding_dim; ++i) {
+         dst[i] = __ldg(&src[i]);
+     }
+ }
+
template <typename T, typename IndexType>
INFINIOP_CUDA_KERNEL embeddingKernel(
-     T *output,
-     const IndexType *indices,
-     const T *weight,
+     T *__restrict__ output,
+     const IndexType *__restrict__ indices,
+     const T *__restrict__ weight,
    size_t num_indices,
    size_t embedding_dim,
    size_t vocab_size) {
@@ -19,27 +131,43 @@ INFINIOP_CUDA_KERNEL embeddingKernel(

    if (idx < num_indices) {
        // Get the index value
-         IndexType index_val = indices[idx];
+         IndexType index_val = __ldg(&indices[idx]);

        // Bounds check - handle negative indices gracefully
        if (index_val >= 0 && static_cast<size_t>(index_val) < vocab_size) {
            // Copy embedding vector from weight to output
            const T *src = weight + static_cast<size_t>(index_val) * embedding_dim;
            T *dst = output + idx * embedding_dim;

-             // Copy embedding_dim elements
-             // Use vectorized copy for better performance when possible
-             size_t i = 0;
-             // Copy in chunks of 4 for better memory bandwidth utilization
-             for (; i + 4 <= embedding_dim; i += 4) {
-                 dst[i] = src[i];
-                 dst[i + 1] = src[i + 1];
-                 dst[i + 2] = src[i + 2];
-                 dst[i + 3] = src[i + 3];
-             }
-             // Copy remaining elements
-             for (; i < embedding_dim; ++i) {
-                 dst[i] = src[i];
+             // Choose optimal copy strategy based on type and alignment
+             if constexpr (std::is_same_v<T, float>) {
+                 // Check alignment for float4 (16 bytes)
+                 bool aligned_16 = is_aligned(src, 16) && is_aligned(dst, 16);
+                 if (aligned_16 && embedding_dim >= 4 && embedding_dim % 4 == 0) {
+                     copyVectorizedFloat4<IndexType>(dst, src, embedding_dim);
+                 } else if (embedding_dim >= 2 && embedding_dim % 2 == 0) {
+                     // Try float2 if not aligned to 16 bytes
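+                     // (float2 only needs 8-byte alignment, which holds here assuming the
+                     //  base pointers come from cudaMalloc and embedding_dim is even)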
+                     copyVectorizedFloat2<IndexType>(dst, src, embedding_dim);
+                 } else {
+                     copyScalar<T, IndexType>(dst, src, embedding_dim);
+                 }
+             } else if constexpr (std::is_same_v<T, half>) {
+                 // Use half2 for vectorized access
+                 if (embedding_dim >= 2 && embedding_dim % 2 == 0) {
+                     copyVectorizedHalf2<IndexType>(dst, src, embedding_dim);
+                 } else {
+                     copyScalar<T, IndexType>(dst, src, embedding_dim);
+                 }
+             } else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
+                 // Use bfloat162 for vectorized access
+                 if (embedding_dim >= 2 && embedding_dim % 2 == 0) {
+                     copyVectorizedBFloat162<IndexType>(dst, src, embedding_dim);
+                 } else {
+                     copyScalar<T, IndexType>(dst, src, embedding_dim);
+                 }
+             } else {
+                 // Fallback to scalar copy with __ldg
+                 copyScalar<T, IndexType>(dst, src, embedding_dim);
            }
        }
    }
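
A minimal host-side launch sketch for the kernel above, for illustration only: the wrapper name launchEmbedding, the 256-thread block size, and the assumption that INFINIOP_CUDA_KERNEL expands to a __global__ kernel qualifier are all assumptions, not part of this commit.

template <typename T, typename IndexType>
void launchEmbedding(T *output, const IndexType *indices, const T *weight,
                     size_t num_indices, size_t embedding_dim, size_t vocab_size,
                     cudaStream_t stream) {
    if (num_indices == 0) {
        return;
    }
    // One thread per index; each thread copies a full row of embedding_dim elements.
    constexpr unsigned block_size = 256;
    unsigned grid_size = static_cast<unsigned>((num_indices + block_size - 1) / block_size);
    op::embedding::nvidia::embeddingKernel<T, IndexType>
        <<<grid_size, block_size, 0, stream>>>(
            output, indices, weight, num_indices, embedding_dim, vocab_size);
}

All device pointers are assumed to already reside in GPU memory; error checking is omitted.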