
Commit 246fd9c

Optimize embedding kernel with vectorized memory access and __ldg
- Add vectorized memory access using float4/float2, half2, and bfloat162
- Use __ldg instruction for read-only weight and indices access
- Add memory alignment checks to enable vectorized paths
- Add __restrict__ keywords for better compiler optimization
- Implement dynamic block size selection based on embedding_dim
1 parent 3b0680e commit 246fd9c
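
The float4 fast path below only engages when both the source and destination rows are 16-byte aligned and embedding_dim is a multiple of 4. A minimal host-side sketch of that dispatch rule (hypothetical helper names, not part of this commit) shows which copy path a given float row would take:

// Hypothetical standalone check mirroring the kernel's dispatch logic for float.
#include <cstdint>
#include <cstdio>
#include <cstddef>

static bool is_aligned_host(const void *ptr, size_t alignment) {
    return reinterpret_cast<uintptr_t>(ptr) % alignment == 0;
}

static const char *float_copy_path(const float *row, size_t embedding_dim) {
    // Same order as the kernel: float4 if 16-byte aligned and dim % 4 == 0,
    // otherwise float2 if dim % 2 == 0, otherwise scalar.
    if (is_aligned_host(row, 16) && embedding_dim >= 4 && embedding_dim % 4 == 0) {
        return "float4";
    }
    if (embedding_dim >= 2 && embedding_dim % 2 == 0) {
        return "float2";
    }
    return "scalar";
}

int main() {
    // 8 rows with embedding_dim = 300: the row stride is 1200 bytes, a multiple
    // of 16, so every row stays 16-byte aligned and the float4 path is taken.
    alignas(16) static float weight[8 * 300];
    std::printf("row 0: %s\n", float_copy_path(weight, 300));
    std::printf("row 1: %s\n", float_copy_path(weight + 300, 300));
    return 0;
}

In practice device allocations from cudaMalloc are aligned well beyond 16 bytes, so for float weights the deciding factor is whether embedding_dim * sizeof(float) keeps every row on a 16-byte boundary.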

2 files changed: +184, -45 lines


src/infiniop/ops/embedding/nvidia/embedding_kernel.cuh

Lines changed: 148 additions & 20 deletions
@@ -3,43 +3,171 @@
 
 #include "../../../devices/nvidia/nvidia_kernel_common.cuh"
 #include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <type_traits>
 
 namespace op::embedding::nvidia {
 
+// Helper function to check memory alignment
+__forceinline__ __device__ bool is_aligned(const void *ptr, size_t alignment) {
+    // Use size_t for pointer arithmetic in device code (more compatible)
+    return (reinterpret_cast<size_t>(ptr) % alignment == 0);
+}
+
+// Vectorized copy for float type using float4
+template <typename IndexType>
+__forceinline__ __device__ void copyVectorizedFloat4(
+    float *__restrict__ dst,
+    const float *__restrict__ src,
+    size_t embedding_dim) {
+    // Use float4 for vectorized access (16 bytes, 4 floats)
+    const float4 *src_vec = reinterpret_cast<const float4 *>(src);
+    float4 *dst_vec = reinterpret_cast<float4 *>(dst);
+    size_t vec_count = embedding_dim / 4;
+
+    // Vectorized copy using __ldg for read-only weight
+    for (size_t i = 0; i < vec_count; ++i) {
+        dst_vec[i] = __ldg(&src_vec[i]);
+    }
+
+    // Copy remaining elements
+    size_t remaining = embedding_dim % 4;
+    if (remaining > 0) {
+        size_t offset = vec_count * 4;
+        for (size_t i = 0; i < remaining; ++i) {
+            dst[offset + i] = __ldg(&src[offset + i]);
+        }
+    }
+}
+
+// Vectorized copy for float type using float2 (fallback when not aligned to 16 bytes)
+template <typename IndexType>
+__forceinline__ __device__ void copyVectorizedFloat2(
+    float *__restrict__ dst,
+    const float *__restrict__ src,
+    size_t embedding_dim) {
+    // Use float2 for vectorized access (8 bytes, 2 floats)
+    const float2 *src_vec = reinterpret_cast<const float2 *>(src);
+    float2 *dst_vec = reinterpret_cast<float2 *>(dst);
+    size_t vec_count = embedding_dim / 2;
+
+    // Vectorized copy using __ldg for read-only weight
+    for (size_t i = 0; i < vec_count; ++i) {
+        dst_vec[i] = __ldg(&src_vec[i]);
+    }
+
+    // Copy remaining element if odd
+    if (embedding_dim % 2 != 0) {
+        dst[embedding_dim - 1] = __ldg(&src[embedding_dim - 1]);
+    }
+}
+
+// Vectorized copy for half type using half2
+template <typename IndexType>
+__forceinline__ __device__ void copyVectorizedHalf2(
+    half *__restrict__ dst,
+    const half *__restrict__ src,
+    size_t embedding_dim) {
+    // Use half2 for vectorized access (4 bytes, 2 halves)
+    const half2 *src_vec = reinterpret_cast<const half2 *>(src);
+    half2 *dst_vec = reinterpret_cast<half2 *>(dst);
+    size_t vec_count = embedding_dim / 2;
+
+    // Vectorized copy using __ldg for read-only weight
+    for (size_t i = 0; i < vec_count; ++i) {
+        dst_vec[i] = __ldg(&src_vec[i]);
+    }
+
+    // Copy remaining element if odd
+    if (embedding_dim % 2 != 0) {
+        dst[embedding_dim - 1] = __ldg(&src[embedding_dim - 1]);
+    }
+}
+
+// Vectorized copy for bfloat16 type using bfloat162
+template <typename IndexType>
+__forceinline__ __device__ void copyVectorizedBFloat162(
+    cuda_bfloat16 *__restrict__ dst,
+    const cuda_bfloat16 *__restrict__ src,
+    size_t embedding_dim) {
+    // Use bfloat162 for vectorized access (4 bytes, 2 bfloat16s)
+    const cuda_bfloat162 *src_vec = reinterpret_cast<const cuda_bfloat162 *>(src);
+    cuda_bfloat162 *dst_vec = reinterpret_cast<cuda_bfloat162 *>(dst);
+    size_t vec_count = embedding_dim / 2;
+
+    // Vectorized copy using __ldg for read-only weight
+    for (size_t i = 0; i < vec_count; ++i) {
+        dst_vec[i] = __ldg(&src_vec[i]);
+    }
+
+    // Copy remaining element if odd
+    if (embedding_dim % 2 != 0) {
+        dst[embedding_dim - 1] = __ldg(&src[embedding_dim - 1]);
+    }
+}
+
+// Scalar copy fallback with __ldg optimization
+template <typename T, typename IndexType>
+__forceinline__ __device__ void copyScalar(
+    T *__restrict__ dst,
+    const T *__restrict__ src,
+    size_t embedding_dim) {
+    // Scalar copy with __ldg for read-only weight
+    for (size_t i = 0; i < embedding_dim; ++i) {
+        dst[i] = __ldg(&src[i]);
+    }
+}
+
 template <typename T, typename IndexType>
 INFINIOP_CUDA_KERNEL embeddingKernel(
-    T *output,
-    const IndexType *indices,
-    const T *weight,
+    T *__restrict__ output,
+    const IndexType *__restrict__ indices,
+    const T *__restrict__ weight,
     size_t num_indices,
     size_t embedding_dim,
     size_t vocab_size) {
     // Calculate global thread index
     size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
-
+
     if (idx < num_indices) {
         // Get the index value
-        IndexType index_val = indices[idx];
-
+        IndexType index_val = __ldg(&indices[idx]);
+
         // Bounds check - handle negative indices gracefully
         if (index_val >= 0 && static_cast<size_t>(index_val) < vocab_size) {
             // Copy embedding vector from weight to output
             const T *src = weight + static_cast<size_t>(index_val) * embedding_dim;
             T *dst = output + idx * embedding_dim;
-
-            // Copy embedding_dim elements
-            // Use vectorized copy for better performance when possible
-            size_t i = 0;
-            // Copy in chunks of 4 for better memory bandwidth utilization
-            for (; i + 4 <= embedding_dim; i += 4) {
-                dst[i] = src[i];
-                dst[i + 1] = src[i + 1];
-                dst[i + 2] = src[i + 2];
-                dst[i + 3] = src[i + 3];
-            }
-            // Copy remaining elements
-            for (; i < embedding_dim; ++i) {
-                dst[i] = src[i];
+
+            // Choose optimal copy strategy based on type and alignment
+            if constexpr (std::is_same_v<T, float>) {
+                // Check alignment for float4 (16 bytes)
+                bool aligned_16 = is_aligned(src, 16) && is_aligned(dst, 16);
+                if (aligned_16 && embedding_dim >= 4 && embedding_dim % 4 == 0) {
+                    copyVectorizedFloat4<IndexType>(dst, src, embedding_dim);
+                } else if (embedding_dim >= 2 && embedding_dim % 2 == 0) {
+                    // Try float2 if not aligned to 16 bytes
+                    copyVectorizedFloat2<IndexType>(dst, src, embedding_dim);
+                } else {
+                    copyScalar<T, IndexType>(dst, src, embedding_dim);
+                }
+            } else if constexpr (std::is_same_v<T, half>) {
+                // Use half2 for vectorized access
+                if (embedding_dim >= 2 && embedding_dim % 2 == 0) {
+                    copyVectorizedHalf2<IndexType>(dst, src, embedding_dim);
+                } else {
+                    copyScalar<T, IndexType>(dst, src, embedding_dim);
+                }
+            } else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
+                // Use bfloat162 for vectorized access
+                if (embedding_dim >= 2 && embedding_dim % 2 == 0) {
+                    copyVectorizedBFloat162<IndexType>(dst, src, embedding_dim);
+                } else {
+                    copyScalar<T, IndexType>(dst, src, embedding_dim);
+                }
+            } else {
+                // Fallback to scalar copy with __ldg
+                copyScalar<T, IndexType>(dst, src, embedding_dim);
             }
         }
     }
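
For changes like this, a quick way to validate the kernel is to compare its output against a plain host-side gather. A minimal reference sketch (hypothetical helper, not part of this commit) that mirrors the kernel's bounds-check behaviour:

#include <cstddef>
#include <vector>

// Host-side reference gather: output row i = weight row indices[i].
// Out-of-range or negative indices leave the output row untouched,
// matching the kernel's bounds check.
template <typename T, typename IndexType>
void embeddingReference(std::vector<T> &output,
                        const std::vector<IndexType> &indices,
                        const std::vector<T> &weight,
                        size_t embedding_dim,
                        size_t vocab_size) {
    for (size_t i = 0; i < indices.size(); ++i) {
        IndexType idx = indices[i];
        if (idx >= 0 && static_cast<size_t>(idx) < vocab_size) {
            for (size_t d = 0; d < embedding_dim; ++d) {
                output[i * embedding_dim + d] =
                    weight[static_cast<size_t>(idx) * embedding_dim + d];
            }
        }
    }
}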

src/infiniop/ops/embedding/nvidia/embedding_nvidia.cu

Lines changed: 36 additions & 25 deletions
@@ -1,7 +1,7 @@
-#include "../../../../utils.h"
 #include "../../../devices/nvidia/nvidia_common.cuh"
 #include "../../../devices/nvidia/nvidia_kernel_common.cuh"
 #include "../../../tensor.h"
+#include "../../../../utils.h"
 #include "embedding_kernel.cuh"
 #include "embedding_nvidia.cuh"
 #include <cuda_runtime.h>
@@ -25,36 +25,37 @@ infiniStatus_t Descriptor::create(
 
     auto input_shape = input_desc->shape();
     auto weight_shape = weight_desc->shape();
-
+
     // Validate shapes
     CHECK_OR_RETURN(weight_shape.size() == 2, INFINI_STATUS_BAD_TENSOR_SHAPE);
     CHECK_OR_RETURN(output_desc->shape().size() == input_shape.size() + 1, INFINI_STATUS_BAD_TENSOR_SHAPE);
-
+
     // Check output shape matches input shape + embedding_dim
     auto output_shape = output_desc->shape();
     size_t embedding_dim = weight_shape[1];
     CHECK_OR_RETURN(output_shape.back() == embedding_dim, INFINI_STATUS_BAD_TENSOR_SHAPE);
-
+
     for (size_t i = 0; i < input_shape.size(); ++i) {
         CHECK_OR_RETURN(output_shape[i] == input_shape[i], INFINI_STATUS_BAD_TENSOR_SHAPE);
     }
-
+
     // Validate dtypes
     auto input_dtype = input_desc->dtype();
     auto weight_dtype = weight_desc->dtype();
     CHECK_OR_RETURN(input_dtype == INFINI_DTYPE_I32 || input_dtype == INFINI_DTYPE_I64,
                     INFINI_STATUS_BAD_TENSOR_DTYPE);
-    CHECK_OR_RETURN(weight_dtype == INFINI_DTYPE_F32 || weight_dtype == INFINI_DTYPE_F16 || weight_dtype == INFINI_DTYPE_BF16, INFINI_STATUS_BAD_TENSOR_DTYPE);
+    CHECK_OR_RETURN(weight_dtype == INFINI_DTYPE_F32 || weight_dtype == INFINI_DTYPE_F16 ||
+                        weight_dtype == INFINI_DTYPE_BF16, INFINI_STATUS_BAD_TENSOR_DTYPE);
     CHECK_OR_RETURN(output_desc->dtype() == weight_dtype, INFINI_STATUS_BAD_TENSOR_DTYPE);
-
+
     // Calculate number of indices (supporting batch dimension)
     size_t num_indices = 1;
     for (auto dim : input_shape) {
         num_indices *= dim;
     }
-
+
     size_t vocab_size = weight_shape[0];
-
+
     *desc_ptr = new Descriptor(
         num_indices,
         embedding_dim,
@@ -64,7 +65,7 @@ infiniStatus_t Descriptor::create(
         new Opaque{reinterpret_cast<device::nvidia::Handle *>(handle)->internal()},
         handle->device,
         handle->device_id);
-
+
     return INFINI_STATUS_SUCCESS;
 }
 
@@ -73,37 +74,47 @@ infiniStatus_t Descriptor::calculate(
     const void *input,
     const void *weight,
     void *stream) const {
-
+
     if (_num_indices == 0) {
         return INFINI_STATUS_SUCCESS;
     }
-
+
     auto cuda_stream = reinterpret_cast<cudaStream_t>(stream);
-    constexpr size_t BLOCK_SIZE = 256;
-    size_t grid_size = (_num_indices + BLOCK_SIZE - 1) / BLOCK_SIZE;
-
+
+    // Dynamic block size optimization based on embedding_dim
+    // Smaller embedding_dim benefits from larger block size (better occupancy)
+    // Larger embedding_dim benefits from smaller block size (more registers per thread)
+    size_t block_size = 256; // Default
+    if (_embedding_dim <= 64) {
+        block_size = 512; // Small embedding_dim: use larger block for better occupancy
+    } else if (_embedding_dim >= 1024) {
+        block_size = 128; // Large embedding_dim: use smaller block to reduce register pressure
+    }
+
+    size_t grid_size = (_num_indices + block_size - 1) / block_size;
+
     // Launch kernel based on dtypes
     if (_input_dtype == INFINI_DTYPE_I32) {
        const int32_t *indices_ptr = reinterpret_cast<const int32_t *>(input);
-
+
        if (_weight_dtype == INFINI_DTYPE_F32) {
-            embeddingKernel<float, int32_t><<<grid_size, BLOCK_SIZE, 0, cuda_stream>>>(
+            embeddingKernel<float, int32_t><<<grid_size, block_size, 0, cuda_stream>>>(
                reinterpret_cast<float *>(output),
                indices_ptr,
                reinterpret_cast<const float *>(weight),
                _num_indices,
                _embedding_dim,
                _vocab_size);
        } else if (_weight_dtype == INFINI_DTYPE_F16) {
-            embeddingKernel<half, int32_t><<<grid_size, BLOCK_SIZE, 0, cuda_stream>>>(
+            embeddingKernel<half, int32_t><<<grid_size, block_size, 0, cuda_stream>>>(
                reinterpret_cast<half *>(output),
                indices_ptr,
                reinterpret_cast<const half *>(weight),
                _num_indices,
                _embedding_dim,
                _vocab_size);
        } else if (_weight_dtype == INFINI_DTYPE_BF16) {
-            embeddingKernel<cuda_bfloat16, int32_t><<<grid_size, BLOCK_SIZE, 0, cuda_stream>>>(
+            embeddingKernel<cuda_bfloat16, int32_t><<<grid_size, block_size, 0, cuda_stream>>>(
                reinterpret_cast<cuda_bfloat16 *>(output),
                indices_ptr,
                reinterpret_cast<const cuda_bfloat16 *>(weight),
@@ -115,25 +126,25 @@ infiniStatus_t Descriptor::calculate(
        }
     } else if (_input_dtype == INFINI_DTYPE_I64) {
        const int64_t *indices_ptr = reinterpret_cast<const int64_t *>(input);
-
+
        if (_weight_dtype == INFINI_DTYPE_F32) {
-            embeddingKernel<float, int64_t><<<grid_size, BLOCK_SIZE, 0, cuda_stream>>>(
+            embeddingKernel<float, int64_t><<<grid_size, block_size, 0, cuda_stream>>>(
                reinterpret_cast<float *>(output),
                indices_ptr,
                reinterpret_cast<const float *>(weight),
                _num_indices,
                _embedding_dim,
                _vocab_size);
        } else if (_weight_dtype == INFINI_DTYPE_F16) {
-            embeddingKernel<half, int64_t><<<grid_size, BLOCK_SIZE, 0, cuda_stream>>>(
+            embeddingKernel<half, int64_t><<<grid_size, block_size, 0, cuda_stream>>>(
                reinterpret_cast<half *>(output),
                indices_ptr,
                reinterpret_cast<const half *>(weight),
                _num_indices,
                _embedding_dim,
                _vocab_size);
        } else if (_weight_dtype == INFINI_DTYPE_BF16) {
-            embeddingKernel<cuda_bfloat16, int64_t><<<grid_size, BLOCK_SIZE, 0, cuda_stream>>>(
+            embeddingKernel<cuda_bfloat16, int64_t><<<grid_size, block_size, 0, cuda_stream>>>(
                reinterpret_cast<cuda_bfloat16 *>(output),
                indices_ptr,
                reinterpret_cast<const cuda_bfloat16 *>(weight),
@@ -146,13 +157,13 @@ infiniStatus_t Descriptor::calculate(
     } else {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
     }
-
+
     // Check for kernel launch errors
     cudaError_t err = cudaGetLastError();
     if (err != cudaSuccess) {
        return INFINI_STATUS_INTERNAL_ERROR;
     }
-
+
     return INFINI_STATUS_SUCCESS;
 }
 
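
The launch configuration introduced here reduces to a small threshold heuristic plus the usual ceiling division for the grid size. A standalone sketch of that calculation (hypothetical helper names); a measured alternative to the fixed thresholds would be an occupancy query such as cudaOccupancyMaxPotentialBlockSize:

#include <cstddef>

// Mirrors the heuristic in Descriptor::calculate: narrow embedding rows favor
// wider blocks (occupancy), very wide rows favor narrower blocks (register
// pressure); everything in between keeps the previous default of 256.
inline size_t pickBlockSize(size_t embedding_dim) {
    if (embedding_dim <= 64) {
        return 512;
    }
    if (embedding_dim >= 1024) {
        return 128;
    }
    return 256;
}

// Ceiling division: one thread per index, rounded up to whole blocks.
inline size_t pickGridSize(size_t num_indices, size_t block_size) {
    return (num_indices + block_size - 1) / block_size;
}

Since each thread copies an entire embedding row, the block size only affects how indices are grouped, not how a row is copied, so the thresholds are a tuning knob rather than a correctness concern.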
