
Commit c715cba

Optimize embedding kernel with vectorized memory access and __ldg
- Add vectorized memory access using float4/float2, half2, and bfloat162
- Use __ldg instruction for read-only weight and indices access
- Add memory alignment checks to enable vectorized paths
- Add __restrict__ keywords for better compiler optimization
- Implement dynamic block size selection based on embedding_dim
1 parent 3b0680e · commit c715cba
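To make the access pattern concrete, here is a minimal, self-contained sketch of the float4 + __ldg row copy guarded by an alignment check, the core trick this commit applies in embedding_kernel.cuh below. It is an illustration only, not the committed code: the names copyRowsFloat4Demo and demo_is_aligned and the toy sizes are hypothetical, and the committed header additionally keeps separate helpers for float2, half2, and bfloat162. It should compile with nvcc on any architecture that supports __ldg (sm_35 or newer).

// Illustrative sketch only (hypothetical names): one thread copies one row of
// `dim` floats, using float4 vector loads through __ldg when both pointers are
// 16-byte aligned and dim is a multiple of 4, otherwise a scalar fallback.
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

__forceinline__ __device__ bool demo_is_aligned(const void *ptr, size_t alignment) {
    return (reinterpret_cast<size_t>(ptr) % alignment) == 0;
}

__global__ void copyRowsFloat4Demo(float *__restrict__ dst_rows,
                                   const float *__restrict__ src_rows,
                                   size_t num_rows, size_t dim) {
    size_t row = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= num_rows) {
        return;
    }
    const float *src = src_rows + row * dim;
    float *dst = dst_rows + row * dim;
    if (demo_is_aligned(src, 16) && demo_is_aligned(dst, 16) && dim % 4 == 0) {
        // Fast path: 16-byte vector loads routed through the read-only data cache.
        const float4 *src_vec = reinterpret_cast<const float4 *>(src);
        float4 *dst_vec = reinterpret_cast<float4 *>(dst);
        for (size_t i = 0; i < dim / 4; ++i) {
            dst_vec[i] = __ldg(&src_vec[i]);
        }
    } else {
        // Scalar fallback, still using __ldg for the read-only source.
        for (size_t i = 0; i < dim; ++i) {
            dst[i] = __ldg(&src[i]);
        }
    }
}

int main() {
    const size_t rows = 8, dim = 64;
    std::vector<float> h_src(rows * dim);
    for (size_t i = 0; i < h_src.size(); ++i) {
        h_src[i] = static_cast<float>(i);
    }
    float *d_src = nullptr, *d_dst = nullptr;
    cudaMalloc(&d_src, rows * dim * sizeof(float));
    cudaMalloc(&d_dst, rows * dim * sizeof(float));
    cudaMemcpy(d_src, h_src.data(), rows * dim * sizeof(float), cudaMemcpyHostToDevice);

    copyRowsFloat4Demo<<<1, 32>>>(d_dst, d_src, rows, dim);

    std::vector<float> h_dst(rows * dim);
    cudaMemcpy(h_dst.data(), d_dst, rows * dim * sizeof(float), cudaMemcpyDeviceToHost);
    std::printf("copy %s\n", h_dst == h_src ? "matches" : "differs");
    cudaFree(d_src);
    cudaFree(d_dst);
    return 0;
}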

2 files changed: +163 additions, −25 deletions

2 files changed

+163
-25
lines changed

src/infiniop/ops/embedding/nvidia/embedding_kernel.cuh

Lines changed: 145 additions & 17 deletions
@@ -2,15 +2,127 @@
 #define __EMBEDDING_CUDA_KERNEL_CUH__
 
 #include "../../../devices/nvidia/nvidia_kernel_common.cuh"
+#include <cuda_fp16.h>
 #include <cuda_runtime.h>
+#include <type_traits>
 
 namespace op::embedding::nvidia {
 
+// Helper function to check memory alignment
+__forceinline__ __device__ bool is_aligned(const void *ptr, size_t alignment) {
+    // Use size_t for pointer arithmetic in device code (more compatible)
+    return (reinterpret_cast<size_t>(ptr) % alignment == 0);
+}
+
+// Vectorized copy for float type using float4
+template <typename IndexType>
+__forceinline__ __device__ void copyVectorizedFloat4(
+    float *__restrict__ dst,
+    const float *__restrict__ src,
+    size_t embedding_dim) {
+    // Use float4 for vectorized access (16 bytes, 4 floats)
+    const float4 *src_vec = reinterpret_cast<const float4 *>(src);
+    float4 *dst_vec = reinterpret_cast<float4 *>(dst);
+    size_t vec_count = embedding_dim / 4;
+
+    // Vectorized copy using __ldg for read-only weight
+    for (size_t i = 0; i < vec_count; ++i) {
+        dst_vec[i] = __ldg(&src_vec[i]);
+    }
+
+    // Copy remaining elements
+    size_t remaining = embedding_dim % 4;
+    if (remaining > 0) {
+        size_t offset = vec_count * 4;
+        for (size_t i = 0; i < remaining; ++i) {
+            dst[offset + i] = __ldg(&src[offset + i]);
+        }
+    }
+}
+
+// Vectorized copy for float type using float2 (fallback when not aligned to 16 bytes)
+template <typename IndexType>
+__forceinline__ __device__ void copyVectorizedFloat2(
+    float *__restrict__ dst,
+    const float *__restrict__ src,
+    size_t embedding_dim) {
+    // Use float2 for vectorized access (8 bytes, 2 floats)
+    const float2 *src_vec = reinterpret_cast<const float2 *>(src);
+    float2 *dst_vec = reinterpret_cast<float2 *>(dst);
+    size_t vec_count = embedding_dim / 2;
+
+    // Vectorized copy using __ldg for read-only weight
+    for (size_t i = 0; i < vec_count; ++i) {
+        dst_vec[i] = __ldg(&src_vec[i]);
+    }
+
+    // Copy remaining element if odd
+    if (embedding_dim % 2 != 0) {
+        dst[embedding_dim - 1] = __ldg(&src[embedding_dim - 1]);
+    }
+}
+
+// Vectorized copy for half type using half2
+template <typename IndexType>
+__forceinline__ __device__ void copyVectorizedHalf2(
+    half *__restrict__ dst,
+    const half *__restrict__ src,
+    size_t embedding_dim) {
+    // Use half2 for vectorized access (4 bytes, 2 halfs)
+    const half2 *src_vec = reinterpret_cast<const half2 *>(src);
+    half2 *dst_vec = reinterpret_cast<half2 *>(dst);
+    size_t vec_count = embedding_dim / 2;
+
+    // Vectorized copy using __ldg for read-only weight
+    for (size_t i = 0; i < vec_count; ++i) {
+        dst_vec[i] = __ldg(&src_vec[i]);
+    }
+
+    // Copy remaining element if odd
+    if (embedding_dim % 2 != 0) {
+        dst[embedding_dim - 1] = __ldg(&src[embedding_dim - 1]);
+    }
+}
+
+// Vectorized copy for bfloat16 type using bfloat162
+template <typename IndexType>
+__forceinline__ __device__ void copyVectorizedBFloat162(
+    cuda_bfloat16 *__restrict__ dst,
+    const cuda_bfloat16 *__restrict__ src,
+    size_t embedding_dim) {
+    // Use bfloat162 for vectorized access (4 bytes, 2 bfloat16s)
+    const cuda_bfloat162 *src_vec = reinterpret_cast<const cuda_bfloat162 *>(src);
+    cuda_bfloat162 *dst_vec = reinterpret_cast<cuda_bfloat162 *>(dst);
+    size_t vec_count = embedding_dim / 2;
+
+    // Vectorized copy using __ldg for read-only weight
+    for (size_t i = 0; i < vec_count; ++i) {
+        dst_vec[i] = __ldg(&src_vec[i]);
+    }
+
+    // Copy remaining element if odd
+    if (embedding_dim % 2 != 0) {
+        dst[embedding_dim - 1] = __ldg(&src[embedding_dim - 1]);
+    }
+}
+
+// Scalar copy fallback with __ldg optimization
+template <typename T, typename IndexType>
+__forceinline__ __device__ void copyScalar(
+    T *__restrict__ dst,
+    const T *__restrict__ src,
+    size_t embedding_dim) {
+    // Scalar copy with __ldg for read-only weight
+    for (size_t i = 0; i < embedding_dim; ++i) {
+        dst[i] = __ldg(&src[i]);
+    }
+}
+
 template <typename T, typename IndexType>
 INFINIOP_CUDA_KERNEL embeddingKernel(
-    T *output,
-    const IndexType *indices,
-    const T *weight,
+    T *__restrict__ output,
+    const IndexType *__restrict__ indices,
+    const T *__restrict__ weight,
     size_t num_indices,
     size_t embedding_dim,
     size_t vocab_size) {
@@ -19,27 +131,43 @@ INFINIOP_CUDA_KERNEL embeddingKernel(
 
     if (idx < num_indices) {
         // Get the index value
-        IndexType index_val = indices[idx];
+        IndexType index_val = __ldg(&indices[idx]);
 
         // Bounds check - handle negative indices gracefully
         if (index_val >= 0 && static_cast<size_t>(index_val) < vocab_size) {
             // Copy embedding vector from weight to output
             const T *src = weight + static_cast<size_t>(index_val) * embedding_dim;
             T *dst = output + idx * embedding_dim;
 
-            // Copy embedding_dim elements
-            // Use vectorized copy for better performance when possible
-            size_t i = 0;
-            // Copy in chunks of 4 for better memory bandwidth utilization
-            for (; i + 4 <= embedding_dim; i += 4) {
-                dst[i] = src[i];
-                dst[i + 1] = src[i + 1];
-                dst[i + 2] = src[i + 2];
-                dst[i + 3] = src[i + 3];
-            }
-            // Copy remaining elements
-            for (; i < embedding_dim; ++i) {
-                dst[i] = src[i];
+            // Choose optimal copy strategy based on type and alignment
+            if constexpr (std::is_same_v<T, float>) {
+                // Check alignment for float4 (16 bytes)
+                bool aligned_16 = is_aligned(src, 16) && is_aligned(dst, 16);
+                if (aligned_16 && embedding_dim >= 4 && embedding_dim % 4 == 0) {
+                    copyVectorizedFloat4<IndexType>(dst, src, embedding_dim);
+                } else if (embedding_dim >= 2 && embedding_dim % 2 == 0) {
+                    // Try float2 if not aligned to 16 bytes
+                    copyVectorizedFloat2<IndexType>(dst, src, embedding_dim);
+                } else {
+                    copyScalar<T, IndexType>(dst, src, embedding_dim);
+                }
+            } else if constexpr (std::is_same_v<T, half>) {
+                // Use half2 for vectorized access
+                if (embedding_dim >= 2 && embedding_dim % 2 == 0) {
+                    copyVectorizedHalf2<IndexType>(dst, src, embedding_dim);
+                } else {
+                    copyScalar<T, IndexType>(dst, src, embedding_dim);
+                }
+            } else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
+                // Use bfloat162 for vectorized access
+                if (embedding_dim >= 2 && embedding_dim % 2 == 0) {
+                    copyVectorizedBFloat162<IndexType>(dst, src, embedding_dim);
+                } else {
+                    copyScalar<T, IndexType>(dst, src, embedding_dim);
+                }
+            } else {
+                // Fallback to scalar copy with __ldg
+                copyScalar<T, IndexType>(dst, src, embedding_dim);
             }
         }
     }
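The kernel's contract is unchanged by this commit: one thread handles one index, and negative or out-of-range indices simply skip the copy, leaving that output row untouched. The program below is a hypothetical, simplified smoke test that mirrors only the scalar path of that contract; embeddingLookupDemo is an illustrative stand-in, not the repository's embeddingKernel.

// Hypothetical smoke test for the embedding-lookup contract: indices select
// rows of `weight` that are copied into `output`; out-of-range indices are skipped.
#include <cuda_runtime.h>
#include <cstdint>
#include <cstdio>
#include <vector>

__global__ void embeddingLookupDemo(float *__restrict__ output,
                                    const int32_t *__restrict__ indices,
                                    const float *__restrict__ weight,
                                    size_t num_indices, size_t embedding_dim,
                                    size_t vocab_size) {
    size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < num_indices) {
        int32_t index_val = __ldg(&indices[idx]);
        if (index_val >= 0 && static_cast<size_t>(index_val) < vocab_size) {
            const float *src = weight + static_cast<size_t>(index_val) * embedding_dim;
            float *dst = output + idx * embedding_dim;
            for (size_t i = 0; i < embedding_dim; ++i) {
                dst[i] = __ldg(&src[i]);
            }
        }
    }
}

int main() {
    const size_t vocab = 16, dim = 8, n = 4;
    std::vector<float> h_weight(vocab * dim);
    for (size_t i = 0; i < h_weight.size(); ++i) {
        h_weight[i] = static_cast<float>(i);
    }
    std::vector<int32_t> h_idx = {3, 0, 15, -1}; // -1 is out of range: row stays zero

    float *d_weight = nullptr, *d_out = nullptr;
    int32_t *d_idx = nullptr;
    cudaMalloc(&d_weight, h_weight.size() * sizeof(float));
    cudaMalloc(&d_out, n * dim * sizeof(float));
    cudaMalloc(&d_idx, n * sizeof(int32_t));
    cudaMemcpy(d_weight, h_weight.data(), h_weight.size() * sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(d_idx, h_idx.data(), n * sizeof(int32_t), cudaMemcpyHostToDevice);
    cudaMemset(d_out, 0, n * dim * sizeof(float));

    embeddingLookupDemo<<<1, 32>>>(d_out, d_idx, d_weight, n, dim, vocab);

    std::vector<float> h_out(n * dim);
    cudaMemcpy(h_out.data(), d_out, n * dim * sizeof(float), cudaMemcpyDeviceToHost);
    for (size_t r = 0; r < n; ++r) {
        std::printf("output row %zu starts with %.1f\n", r, h_out[r * dim]);
    }
    cudaFree(d_weight);
    cudaFree(d_out);
    cudaFree(d_idx);
    return 0;
}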

src/infiniop/ops/embedding/nvidia/embedding_nvidia.cu

Lines changed: 18 additions & 8 deletions
@@ -79,31 +79,41 @@ infiniStatus_t Descriptor::calculate(
     }
 
     auto cuda_stream = reinterpret_cast<cudaStream_t>(stream);
-    constexpr size_t BLOCK_SIZE = 256;
-    size_t grid_size = (_num_indices + BLOCK_SIZE - 1) / BLOCK_SIZE;
+
+    // Dynamic block size optimization based on embedding_dim
+    // Smaller embedding_dim benefits from larger block size (better occupancy)
+    // Larger embedding_dim benefits from smaller block size (more registers per thread)
+    size_t block_size = 256; // Default
+    if (_embedding_dim <= 64) {
+        block_size = 512; // Small embedding_dim: use larger block for better occupancy
+    } else if (_embedding_dim >= 1024) {
+        block_size = 128; // Large embedding_dim: use smaller block to reduce register pressure
+    }
+
+    size_t grid_size = (_num_indices + block_size - 1) / block_size;
 
     // Launch kernel based on dtypes
    if (_input_dtype == INFINI_DTYPE_I32) {
        const int32_t *indices_ptr = reinterpret_cast<const int32_t *>(input);
 
        if (_weight_dtype == INFINI_DTYPE_F32) {
-            embeddingKernel<float, int32_t><<<grid_size, BLOCK_SIZE, 0, cuda_stream>>>(
+            embeddingKernel<float, int32_t><<<grid_size, block_size, 0, cuda_stream>>>(
                reinterpret_cast<float *>(output),
                indices_ptr,
                reinterpret_cast<const float *>(weight),
                _num_indices,
                _embedding_dim,
                _vocab_size);
        } else if (_weight_dtype == INFINI_DTYPE_F16) {
-            embeddingKernel<half, int32_t><<<grid_size, BLOCK_SIZE, 0, cuda_stream>>>(
+            embeddingKernel<half, int32_t><<<grid_size, block_size, 0, cuda_stream>>>(
                reinterpret_cast<half *>(output),
                indices_ptr,
                reinterpret_cast<const half *>(weight),
                _num_indices,
                _embedding_dim,
                _vocab_size);
        } else if (_weight_dtype == INFINI_DTYPE_BF16) {
-            embeddingKernel<cuda_bfloat16, int32_t><<<grid_size, BLOCK_SIZE, 0, cuda_stream>>>(
+            embeddingKernel<cuda_bfloat16, int32_t><<<grid_size, block_size, 0, cuda_stream>>>(
                reinterpret_cast<cuda_bfloat16 *>(output),
                indices_ptr,
                reinterpret_cast<const cuda_bfloat16 *>(weight),
@@ -117,23 +127,23 @@ infiniStatus_t Descriptor::calculate(
        const int64_t *indices_ptr = reinterpret_cast<const int64_t *>(input);
 
        if (_weight_dtype == INFINI_DTYPE_F32) {
-            embeddingKernel<float, int64_t><<<grid_size, BLOCK_SIZE, 0, cuda_stream>>>(
+            embeddingKernel<float, int64_t><<<grid_size, block_size, 0, cuda_stream>>>(
                reinterpret_cast<float *>(output),
                indices_ptr,
                reinterpret_cast<const float *>(weight),
                _num_indices,
                _embedding_dim,
                _vocab_size);
        } else if (_weight_dtype == INFINI_DTYPE_F16) {
-            embeddingKernel<half, int64_t><<<grid_size, BLOCK_SIZE, 0, cuda_stream>>>(
+            embeddingKernel<half, int64_t><<<grid_size, block_size, 0, cuda_stream>>>(
                reinterpret_cast<half *>(output),
                indices_ptr,
                reinterpret_cast<const half *>(weight),
                _num_indices,
                _embedding_dim,
                _vocab_size);
        } else if (_weight_dtype == INFINI_DTYPE_BF16) {
-            embeddingKernel<cuda_bfloat16, int64_t><<<grid_size, BLOCK_SIZE, 0, cuda_stream>>>(
+            embeddingKernel<cuda_bfloat16, int64_t><<<grid_size, block_size, 0, cuda_stream>>>(
                reinterpret_cast<cuda_bfloat16 *>(output),
                indices_ptr,
                reinterpret_cast<const cuda_bfloat16 *>(weight),
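To see what launch configuration the new heuristic produces, here is a small host-only sketch that mirrors the block-size selection and the ceil-division grid computation from the diff above; pickBlockSize and gridSize are hypothetical helper names, not part of the repository.

// Hypothetical standalone mirror of the block-size heuristic in embedding_nvidia.cu.
#include <cstddef>
#include <cstdio>

static size_t pickBlockSize(size_t embedding_dim) {
    size_t block_size = 256;  // default
    if (embedding_dim <= 64) {
        block_size = 512;     // small rows: favor occupancy
    } else if (embedding_dim >= 1024) {
        block_size = 128;     // large rows: ease register pressure
    }
    return block_size;
}

static size_t gridSize(size_t num_indices, size_t block_size) {
    return (num_indices + block_size - 1) / block_size;  // ceil division
}

int main() {
    const size_t dims[] = {32, 64, 128, 512, 1024, 4096};
    const size_t num_indices = 10000;
    for (size_t dim : dims) {
        size_t block = pickBlockSize(dim);
        std::printf("embedding_dim=%zu -> block=%zu, grid=%zu\n",
                    dim, block, gridSize(num_indices, block));
    }
    return 0;
}

For example, with 10,000 indices and embedding_dim = 32 this selects 512 threads per block and 20 blocks, while embedding_dim = 4096 drops to 128 threads per block and 79 blocks.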
