
Commit bea5068

issue/900 - support embedding on iluvatar, metax, and moore
1 parent 7b88379 commit bea5068

8 files changed: +655 −55 lines
Lines changed: 116 additions & 0 deletions
@@ -0,0 +1,116 @@
#ifndef __EMBEDDING_CUDA_KERNEL_CUH__
#define __EMBEDDING_CUDA_KERNEL_CUH__

#include <type_traits>

// Helper function to check memory alignment
__forceinline__ __device__ bool is_aligned(const void *ptr, size_t alignment) {
    // Use size_t for pointer arithmetic in device code (more compatible)
    return (reinterpret_cast<size_t>(ptr) % alignment == 0);
}

// Vectorized copy for float type using float4
template <typename IndexType>
__forceinline__ __device__ void copyVectorizedFloat4(
    float *__restrict__ dst,
    const float *__restrict__ src,
    size_t embedding_dim) {
    // Use float4 for vectorized access (16 bytes, 4 floats)
    const float4 *src_vec = reinterpret_cast<const float4 *>(src);
    float4 *dst_vec = reinterpret_cast<float4 *>(dst);
    size_t vec_count = embedding_dim / 4;

    // Vectorized copy using __ldg for read-only weight
    for (size_t i = 0; i < vec_count; ++i) {
        dst_vec[i] = __ldg(&src_vec[i]);
    }

    // Copy remaining elements
    size_t remaining = embedding_dim % 4;
    if (remaining > 0) {
        size_t offset = vec_count * 4;
        for (size_t i = 0; i < remaining; ++i) {
            dst[offset + i] = __ldg(&src[offset + i]);
        }
    }
}

// Vectorized copy for float type using float2 (fallback when not aligned to 16 bytes)
template <typename IndexType>
__forceinline__ __device__ void copyVectorizedFloat2(
    float *__restrict__ dst,
    const float *__restrict__ src,
    size_t embedding_dim) {
    // Use float2 for vectorized access (8 bytes, 2 floats)
    const float2 *src_vec = reinterpret_cast<const float2 *>(src);
    float2 *dst_vec = reinterpret_cast<float2 *>(dst);
    size_t vec_count = embedding_dim / 2;

    // Vectorized copy using __ldg for read-only weight
    for (size_t i = 0; i < vec_count; ++i) {
        dst_vec[i] = __ldg(&src_vec[i]);
    }

    // Copy remaining element if odd
    if (embedding_dim % 2 != 0) {
        dst[embedding_dim - 1] = __ldg(&src[embedding_dim - 1]);
    }
}

// Vectorized copy for half type using half2
template <typename IndexType>
__forceinline__ __device__ void copyVectorizedHalf2(
    half *__restrict__ dst,
    const half *__restrict__ src,
    size_t embedding_dim) {
    // Use half2 for vectorized access (4 bytes, 2 half values)
    const half2 *src_vec = reinterpret_cast<const half2 *>(src);
    half2 *dst_vec = reinterpret_cast<half2 *>(dst);
    size_t vec_count = embedding_dim / 2;

    // Vectorized copy using __ldg for read-only weight
    for (size_t i = 0; i < vec_count; ++i) {
        dst_vec[i] = __ldg(&src_vec[i]);
    }

    // Copy remaining element if odd
    if (embedding_dim % 2 != 0) {
        dst[embedding_dim - 1] = __ldg(&src[embedding_dim - 1]);
    }
}

// Vectorized copy for bfloat16 type using bfloat162
template <typename IndexType>
__forceinline__ __device__ void copyVectorizedBFloat162(
    cuda_bfloat16 *__restrict__ dst,
    const cuda_bfloat16 *__restrict__ src,
    size_t embedding_dim) {
    // Use bfloat162 for vectorized access (4 bytes, 2 bfloat16 values)
    const cuda_bfloat162 *src_vec = reinterpret_cast<const cuda_bfloat162 *>(src);
    cuda_bfloat162 *dst_vec = reinterpret_cast<cuda_bfloat162 *>(dst);
    size_t vec_count = embedding_dim / 2;

    // Vectorized copy using __ldg for read-only weight
    for (size_t i = 0; i < vec_count; ++i) {
        dst_vec[i] = __ldg(&src_vec[i]);
    }

    // Copy remaining element if odd
    if (embedding_dim % 2 != 0) {
        dst[embedding_dim - 1] = __ldg(&src[embedding_dim - 1]);
    }
}

// Scalar copy fallback with __ldg optimization
template <typename T, typename IndexType>
__forceinline__ __device__ void copyScalar(
    T *__restrict__ dst,
    const T *__restrict__ src,
    size_t embedding_dim) {
    // Scalar copy with __ldg for read-only weight
    for (size_t i = 0; i < embedding_dim; ++i) {
        dst[i] = __ldg(&src[i]);
    }
}

#endif // __EMBEDDING_CUDA_KERNEL_CUH__
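
The float helpers above are selected at the call site by a simple rule: float4 when both pointers are 16-byte aligned and embedding_dim is a multiple of 4, float2 when it is only a multiple of 2, and scalar otherwise. The following standalone CUDA sketch (not taken from this commit; the kernel name rowCopyDemo, the single-thread launch, and the sizes are made up purely for illustration) compiles on its own and shows that selection rule copying one row end to end:

#include <cstdio>
#include <cuda_runtime.h>

__device__ bool aligned16(const void *p) { return (reinterpret_cast<size_t>(p) & 15) == 0; }

// One thread copies one row of `dim` floats, mirroring the float4 / float2 / scalar choice
__global__ void rowCopyDemo(float *dst, const float *src, size_t dim) {
    if (blockIdx.x == 0 && threadIdx.x == 0) {
        if (aligned16(dst) && aligned16(src) && dim >= 4 && dim % 4 == 0) {
            // float4 path: dim / 4 vector loads, no tail
            const float4 *s = reinterpret_cast<const float4 *>(src);
            float4 *d = reinterpret_cast<float4 *>(dst);
            for (size_t i = 0; i < dim / 4; ++i) d[i] = __ldg(&s[i]);
        } else if (dim >= 2 && dim % 2 == 0) {
            // float2 path: dim / 2 vector loads
            const float2 *s = reinterpret_cast<const float2 *>(src);
            float2 *d = reinterpret_cast<float2 *>(dst);
            for (size_t i = 0; i < dim / 2; ++i) d[i] = __ldg(&s[i]);
        } else {
            // scalar fallback
            for (size_t i = 0; i < dim; ++i) dst[i] = __ldg(&src[i]);
        }
    }
}

int main() {
    const size_t dim = 6; // 6 % 4 != 0 but 6 % 2 == 0, so the float2 path is taken
    float h_src[6] = {0, 1, 2, 3, 4, 5}, h_dst[6] = {};
    float *d_src, *d_dst;
    cudaMalloc((void **)&d_src, dim * sizeof(float));
    cudaMalloc((void **)&d_dst, dim * sizeof(float));
    cudaMemcpy(d_src, h_src, dim * sizeof(float), cudaMemcpyHostToDevice);
    rowCopyDemo<<<1, 32>>>(d_dst, d_src, dim);
    cudaMemcpy(h_dst, d_dst, dim * sizeof(float), cudaMemcpyDeviceToHost);
    for (size_t i = 0; i < dim; ++i) printf("%g ", h_dst[i]); // expect: 0 1 2 3 4 5
    printf("\n");
    cudaFree(d_src);
    cudaFree(d_dst);
    return 0;
}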
Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
#ifndef __EMBEDDING_METAX_H__
#define __EMBEDDING_METAX_H__

#include "../embedding.h"

DESCRIPTOR(metax)

#endif // __EMBEDDING_METAX_H__
Lines changed: 217 additions & 0 deletions
@@ -0,0 +1,217 @@
#include "../../../../utils.h"
#include "../../../devices/metax/metax_common.h"
#include "../../../devices/metax/metax_kernel_common.h"
#include "../../../tensor.h"
#include "../cuda/embedding_kernel.cuh"
#include "embedding_metax.cuh"

template <typename T, typename IndexType>
INFINIOP_METAX_KERNEL embeddingKernel(
    T *__restrict__ output,
    const IndexType *__restrict__ indices,
    const T *__restrict__ weight,
    size_t num_indices,
    size_t embedding_dim,
    size_t vocab_size) {
    // Calculate global thread index
    size_t idx = blockIdx.x * blockDim.x + threadIdx.x;

    if (idx < num_indices) {
        // Get the index value
        IndexType index_val = __ldg(&indices[idx]);

        // Bounds check - handle negative indices gracefully
        if (index_val >= 0 && static_cast<size_t>(index_val) < vocab_size) {
            // Copy embedding vector from weight to output
            const T *src = weight + static_cast<size_t>(index_val) * embedding_dim;
            T *dst = output + idx * embedding_dim;

            // Choose optimal copy strategy based on type and alignment
            if constexpr (std::is_same_v<T, float>) {
                // Check alignment for float4 (16 bytes)
                bool aligned_16 = is_aligned(src, 16) && is_aligned(dst, 16);
                if (aligned_16 && embedding_dim >= 4 && embedding_dim % 4 == 0) {
                    copyVectorizedFloat4<IndexType>(dst, src, embedding_dim);
                } else if (embedding_dim >= 2 && embedding_dim % 2 == 0) {
                    // Fall back to float2 if not aligned to 16 bytes
                    copyVectorizedFloat2<IndexType>(dst, src, embedding_dim);
                } else {
                    copyScalar<T, IndexType>(dst, src, embedding_dim);
                }
            } else if constexpr (std::is_same_v<T, half>) {
                // Use half2 for vectorized access
                if (embedding_dim >= 2 && embedding_dim % 2 == 0) {
                    copyVectorizedHalf2<IndexType>(dst, src, embedding_dim);
                } else {
                    copyScalar<T, IndexType>(dst, src, embedding_dim);
                }
            } else if constexpr (std::is_same_v<T, cuda_bfloat16>) {
                // Use bfloat162 for vectorized access
                if (embedding_dim >= 2 && embedding_dim % 2 == 0) {
                    copyVectorizedBFloat162<IndexType>(dst, src, embedding_dim);
                } else {
                    copyScalar<T, IndexType>(dst, src, embedding_dim);
                }
            } else {
                // Fallback to scalar copy with __ldg
                copyScalar<T, IndexType>(dst, src, embedding_dim);
            }
        }
    }
}

namespace op::embedding::metax {

struct Descriptor::Opaque {
    std::shared_ptr<device::metax::Handle::Internal> internal;
};

Descriptor::~Descriptor() {
    delete _opaque;
}

infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t output_desc,
    infiniopTensorDescriptor_t input_desc,
    infiniopTensorDescriptor_t weight_desc) {

    auto input_shape = input_desc->shape();
    auto weight_shape = weight_desc->shape();

    // Validate shapes
    CHECK_OR_RETURN(weight_shape.size() == 2, INFINI_STATUS_BAD_TENSOR_SHAPE);
    CHECK_OR_RETURN(output_desc->shape().size() == input_shape.size() + 1, INFINI_STATUS_BAD_TENSOR_SHAPE);

    // Check output shape matches input shape + embedding_dim
    auto output_shape = output_desc->shape();
    size_t embedding_dim = weight_shape[1];
    CHECK_OR_RETURN(output_shape.back() == embedding_dim, INFINI_STATUS_BAD_TENSOR_SHAPE);

    for (size_t i = 0; i < input_shape.size(); ++i) {
        CHECK_OR_RETURN(output_shape[i] == input_shape[i], INFINI_STATUS_BAD_TENSOR_SHAPE);
    }

    // Validate dtypes
    auto input_dtype = input_desc->dtype();
    auto weight_dtype = weight_desc->dtype();
    CHECK_OR_RETURN(input_dtype == INFINI_DTYPE_I32 || input_dtype == INFINI_DTYPE_I64,
                    INFINI_STATUS_BAD_TENSOR_DTYPE);
    CHECK_OR_RETURN(weight_dtype == INFINI_DTYPE_F32 || weight_dtype == INFINI_DTYPE_F16 || weight_dtype == INFINI_DTYPE_BF16,
                    INFINI_STATUS_BAD_TENSOR_DTYPE);
    CHECK_OR_RETURN(output_desc->dtype() == weight_dtype, INFINI_STATUS_BAD_TENSOR_DTYPE);

    // Calculate number of indices (supporting batch dimensions)
    size_t num_indices = 1;
    for (auto dim : input_shape) {
        num_indices *= dim;
    }

    size_t vocab_size = weight_shape[0];

    *desc_ptr = new Descriptor(
        num_indices,
        embedding_dim,
        vocab_size,
        input_dtype,
        weight_dtype,
        new Opaque{reinterpret_cast<device::metax::Handle *>(handle)->internal()},
        handle->device,
        handle->device_id);

    return INFINI_STATUS_SUCCESS;
}

infiniStatus_t Descriptor::calculate(
    void *output,
    const void *input,
    const void *weight,
    void *stream) const {

    if (_num_indices == 0) {
        return INFINI_STATUS_SUCCESS;
    }

    auto hc_stream = reinterpret_cast<hcStream_t>(stream);

    // Dynamic block size optimization based on embedding_dim for the Metax platform
    size_t block_size = 256; // Default block size for Metax
    if (_embedding_dim <= 64) {
        block_size = 512; // Small embedding_dim: use a larger block for better occupancy
    } else if (_embedding_dim >= 1024) {
        block_size = 128; // Large embedding_dim: use a smaller block to reduce register pressure
    }

    size_t grid_size = (_num_indices + block_size - 1) / block_size;

    // Launch kernel based on dtypes for the Metax platform
    if (_input_dtype == INFINI_DTYPE_I32) {
        const int32_t *indices_ptr = reinterpret_cast<const int32_t *>(input);

        if (_weight_dtype == INFINI_DTYPE_F32) {
            embeddingKernel<float, int32_t><<<grid_size, block_size, 0, hc_stream>>>(
                reinterpret_cast<float *>(output),
                indices_ptr,
                reinterpret_cast<const float *>(weight),
                _num_indices,
                _embedding_dim,
                _vocab_size);
        } else if (_weight_dtype == INFINI_DTYPE_F16) {
            embeddingKernel<half, int32_t><<<grid_size, block_size, 0, hc_stream>>>(
                reinterpret_cast<half *>(output),
                indices_ptr,
                reinterpret_cast<const half *>(weight),
                _num_indices,
                _embedding_dim,
                _vocab_size);
        } else if (_weight_dtype == INFINI_DTYPE_BF16) {
            // Use Metax's bfloat16 type
            embeddingKernel<__hpcc_bfloat16, int32_t><<<grid_size, block_size, 0, hc_stream>>>(
                reinterpret_cast<__hpcc_bfloat16 *>(output),
                indices_ptr,
                reinterpret_cast<const __hpcc_bfloat16 *>(weight),
                _num_indices,
                _embedding_dim,
                _vocab_size);
        } else {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
    } else if (_input_dtype == INFINI_DTYPE_I64) {
        const int64_t *indices_ptr = reinterpret_cast<const int64_t *>(input);

        if (_weight_dtype == INFINI_DTYPE_F32) {
            embeddingKernel<float, int64_t><<<grid_size, block_size, 0, hc_stream>>>(
                reinterpret_cast<float *>(output),
                indices_ptr,
                reinterpret_cast<const float *>(weight),
                _num_indices,
                _embedding_dim,
                _vocab_size);
        } else if (_weight_dtype == INFINI_DTYPE_F16) {
            embeddingKernel<half, int64_t><<<grid_size, block_size, 0, hc_stream>>>(
                reinterpret_cast<half *>(output),
                indices_ptr,
                reinterpret_cast<const half *>(weight),
                _num_indices,
                _embedding_dim,
                _vocab_size);
        } else if (_weight_dtype == INFINI_DTYPE_BF16) {
            embeddingKernel<__hpcc_bfloat16, int64_t><<<grid_size, block_size, 0, hc_stream>>>(
                reinterpret_cast<__hpcc_bfloat16 *>(output),
                indices_ptr,
                reinterpret_cast<const __hpcc_bfloat16 *>(weight),
                _num_indices,
                _embedding_dim,
                _vocab_size);
        } else {
            return INFINI_STATUS_BAD_TENSOR_DTYPE;
        }
    } else {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }

    return INFINI_STATUS_SUCCESS;
}

} // namespace op::embedding::metax
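
The launch configuration in Descriptor::calculate() is a ceiling division of the flattened index count by the heuristic block size. A standalone sketch of that arithmetic (not part of the commit; the 3 x 11 index shape and embedding_dim of 2048 are made-up values for illustration), compilable with any C++ or CUDA compiler:

#include <cstdio>

int main() {
    // Illustrative values only: a 3 x 11 index tensor with embedding_dim = 2048
    size_t num_indices = 3 * 11; // 33 flattened indices
    size_t embedding_dim = 2048;

    // Same heuristic as above: 512 for small rows, 128 for large rows, 256 otherwise
    size_t block_size = 256;
    if (embedding_dim <= 64) {
        block_size = 512;
    } else if (embedding_dim >= 1024) {
        block_size = 128;
    }

    // Ceiling division: enough blocks so that grid_size * block_size >= num_indices
    size_t grid_size = (num_indices + block_size - 1) / block_size;

    // Prints: block_size=128 grid_size=1 (128 threads cover 33 indices)
    printf("block_size=%zu grid_size=%zu (%zu threads cover %zu indices)\n",
           block_size, grid_size, grid_size * block_size, num_indices);
    return 0;
}

Threads whose global index falls beyond num_indices simply fail the idx < num_indices guard in the kernel and do no work.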
Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
#ifndef __EMBEDDING_MOORE_H__
#define __EMBEDDING_MOORE_H__

#include "../embedding.h"

DESCRIPTOR(moore)

#endif // __EMBEDDING_MOORE_H__
