Skip to content

Commit dce1995

Browse files
committed
Issue/846 - Refactor embedding to support device-side input and CUDA graph recording
1 parent 0ead67f commit dce1995

File tree

20 files changed

+1387
-165
lines changed

20 files changed

+1387
-165
lines changed

include/infinicore/ops.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "ops/add.hpp"
44
#include "ops/attention.hpp"
55
#include "ops/causal_softmax.hpp"
6+
#include "ops/embedding.hpp"
67
#include "ops/matmul.hpp"
78
#include "ops/ones.hpp"
89
#include "ops/rearrange.hpp"

include/infinicore/ops/embedding.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,13 @@
44

55
namespace infinicore::op {
66

7+
class Embedding {
8+
public:
9+
using schema = void (*)(Tensor, Tensor, Tensor);
10+
static void execute(Tensor out, Tensor input, Tensor weight);
11+
static common::OpDispatcher<schema> &dispatcher();
12+
};
13+
714
Tensor embedding(Tensor input, Tensor weight);
815
void embedding_(Tensor out, Tensor input, Tensor weight);
916
} // namespace infinicore::op

include/infiniop.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "infiniop/ops/clip.h"
99
#include "infiniop/ops/conv.h"
1010
#include "infiniop/ops/dequantize_awq.h"
11+
#include "infiniop/ops/embedding.h"
1112
#include "infiniop/ops/gelu.h"
1213
#include "infiniop/ops/gemm.h"
1314
#include "infiniop/ops/layer_norm.h"

include/infiniop/ops/embedding.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#ifndef __INFINIOP_EMBEDDING_API_H__
2+
#define __INFINIOP_EMBEDDING_API_H__
3+
4+
#include "../operator_descriptor.h"
5+
6+
typedef struct InfiniopDescriptor *infiniopEmbeddingDescriptor_t;
7+
8+
__C __export infiniStatus_t infiniopCreateEmbeddingDescriptor(
9+
infiniopHandle_t handle,
10+
infiniopEmbeddingDescriptor_t *desc_ptr,
11+
infiniopTensorDescriptor_t output_desc,
12+
infiniopTensorDescriptor_t input_desc,
13+
infiniopTensorDescriptor_t weight_desc);
14+
15+
__C __export infiniStatus_t infiniopEmbedding(
16+
infiniopEmbeddingDescriptor_t desc,
17+
void *output,
18+
const void *input,
19+
const void *weight,
20+
void *stream);
21+
22+
__C __export infiniStatus_t infiniopDestroyEmbeddingDescriptor(
23+
infiniopEmbeddingDescriptor_t desc);
24+
25+
#endif
26+

python/infinicore/nn/functional/embedding.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,8 @@ def embedding(
2222
and (sparse is False)
2323
), "Unsupported parameters."
2424

25-
assert "cpu" == input.device.type, (
26-
"The device of 'input' variable must be on the CPU."
27-
)
25+
# Note: embedding now supports device-side input for graph recording
26+
# The C++ implementation handles both CPU and device-side inputs
2827

2928
if out is None:
3029
return Tensor(_infinicore.embedding(input._underlying, weight._underlying))

src/infinicore/nn/embedding.cc

Lines changed: 7 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -43,80 +43,13 @@ Embedding::Embedding(size_t num_embeddings,
4343
}
4444

4545
Tensor Embedding::forward(const Tensor &indices) const {
46-
// Get the shape of indices
47-
auto indices_shape = indices->shape();
48-
49-
// Output shape: indices_shape + [embedding_dim]
50-
std::vector<size_t> output_shape = indices_shape;
51-
output_shape.push_back(embedding_dim_);
52-
53-
// Create output tensor on the same device as weight
54-
auto out = Tensor::empty(output_shape, weight_->dtype(), weight_->device());
55-
56-
// Flatten indices for sequential row copies
57-
auto cpu_device = Device(Device::Type::CPU, 0);
58-
auto indices_cpu = indices->to(cpu_device)->contiguous();
59-
60-
// Calculate total number of lookups
61-
size_t num_lookups = 1;
62-
for (auto dim : indices_shape) {
63-
num_lookups *= dim;
64-
}
65-
66-
const size_t row_bytes = embedding_dim_ * dsize(weight_->dtype());
67-
68-
// Source and destination base pointers
69-
auto *weight_base = weight_->data();
70-
auto *out_base = out->data();
71-
72-
// Helper lambda to read index based on dtype with bounds checking
73-
auto read_index = [&](size_t i) -> int64_t {
74-
auto dtype = indices_cpu->dtype();
75-
if (dtype == DataType::I32) {
76-
const auto *data = reinterpret_cast<const int32_t *>(indices_cpu->data());
77-
return static_cast<int64_t>(data[i]);
78-
} else if (dtype == DataType::I64) {
79-
const auto *data = reinterpret_cast<const int64_t *>(indices_cpu->data());
80-
return data[i];
81-
} else if (dtype == DataType::U32) {
82-
const auto *data = reinterpret_cast<const uint32_t *>(indices_cpu->data());
83-
return static_cast<int64_t>(data[i]);
84-
} else if (dtype == DataType::U64) {
85-
const auto *data = reinterpret_cast<const uint64_t *>(indices_cpu->data());
86-
uint64_t val = data[i];
87-
// Check if value can fit in int64_t
88-
if (val > static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) {
89-
throw std::out_of_range("Index value out of range for int64_t: " + std::to_string(val));
90-
}
91-
return static_cast<int64_t>(val);
92-
} else {
93-
throw std::runtime_error("Embedding indices must be integer type, got dtype=" + std::to_string(static_cast<int>(dtype)));
94-
}
95-
};
96-
97-
if (weight_->device().getType() == Device::Type::CPU) {
98-
// CPU path: memcpy row by row
99-
for (size_t i = 0; i < num_lookups; ++i) {
100-
int64_t idx = read_index(i);
101-
if (idx < 0 || idx >= static_cast<int64_t>(num_embeddings_)) {
102-
throw std::out_of_range(
103-
"Index out of range: " + std::to_string(idx) + " (num_embeddings=" + std::to_string(num_embeddings_) + ")");
104-
}
105-
std::memcpy(out_base + i * row_bytes, weight_base + idx * row_bytes, row_bytes);
106-
}
107-
} else {
108-
// Device path: use stream-ordered D2D copies
109-
for (size_t i = 0; i < num_lookups; ++i) {
110-
int64_t idx = read_index(i);
111-
if (idx < 0 || idx >= static_cast<int64_t>(num_embeddings_)) {
112-
throw std::out_of_range(
113-
"Index out of range: " + std::to_string(idx) + " (num_embeddings=" + std::to_string(num_embeddings_) + ")");
114-
}
115-
context::memcpyD2D(out_base + i * row_bytes, weight_base + idx * row_bytes, row_bytes);
116-
}
117-
}
118-
119-
return out;
46+
// Ensure indices are contiguous for efficient access
47+
// op::embedding now supports device-side input for graph recording
48+
Tensor indices_contiguous = indices->is_contiguous() ? indices : indices->contiguous();
49+
50+
// Use op::embedding which now supports device-side input and batch dimension
51+
// This enables full graph recording support without synchronization
52+
return op::embedding(indices_contiguous, weight_);
12053
}
12154

12255
std::string Embedding::extra_repr() const {
Lines changed: 19 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,32 @@
11
#include "infinicore/ops/embedding.hpp"
22
#include "infinicore/context/context.hpp"
3+
#include "../../utils.hpp"
34
#include <cstring>
5+
#include <stdexcept>
46

57
namespace infinicore::op {
68

9+
common::OpDispatcher<Embedding::schema> &Embedding::dispatcher() {
10+
static common::OpDispatcher<Embedding::schema> dispatcher_;
11+
return dispatcher_;
12+
}
13+
14+
void Embedding::execute(Tensor out, Tensor input, Tensor weight) {
15+
// Check that output and weight are on the same device
16+
INFINICORE_ASSERT_TENSORS_SAME_DEVICE(out, weight);
17+
18+
// Set device context
19+
infinicore::context::setDevice(out->device());
20+
21+
// Use dispatcher to lookup kernel (infiniop implementation)
22+
dispatcher().lookup(out->device().getType())(out, input, weight);
23+
}
24+
725
Tensor embedding(Tensor input, // LongTensor of arbitrary shape containing the indices to extract
826
Tensor weight // Weight: Embedding matrix of floating point type with shape (V, embedding_dim), where V = maximum index + 1
927
) {
1028
auto input_shape = input->shape();
1129
auto weight_shape = weight->shape();
12-
// auto vocab_size = weight_shape[0];
1330
auto embedding_dim = weight_shape[1];
1431

1532
// Assign memory to out variables
@@ -22,68 +39,7 @@ Tensor embedding(Tensor input, // LongTensor of arbitrary shape containing the i
2239
}
2340

2441
void embedding_(Tensor out, Tensor input, Tensor weight) {
25-
assert(infinicore::DataType::I64 == input->dtype() || (infinicore::DataType::I32 == input->dtype()));
26-
assert(infinicore::Device::Type::CPU == input->device().getType());
27-
28-
auto input_shape = input->shape();
29-
auto weight_shape = weight->shape();
30-
auto embedding_dim = weight_shape[1];
31-
32-
// Calculate the number of token
33-
Size counts = 1;
34-
for (auto &v : input_shape) {
35-
counts *= v;
36-
}
37-
38-
// the bytes of one token
39-
const Size bytes = dsize(weight->dtype()) * embedding_dim;
40-
auto *weight_ptr = weight->data();
41-
auto *out_ptr = out->data();
42-
43-
// copies
44-
if (weight->device().getType() == Device::Type::CPU) {
45-
if (infinicore::DataType::I64 == input->dtype()) {
46-
const int64_t *input_arr = reinterpret_cast<const int64_t *>(input->data());
47-
for (Size i = 0; i < counts; ++i) {
48-
int64_t idx = input_arr[i];
49-
assert((idx >= 0) && (idx < weight_shape[0]));
50-
std::memcpy(out_ptr + i * bytes,
51-
weight_ptr + idx * bytes,
52-
bytes);
53-
}
54-
} else if (infinicore::DataType::I32 == input->dtype()) {
55-
const int32_t *input_arr = reinterpret_cast<const int32_t *>(input->data());
56-
57-
for (Size i = 0; i < counts; ++i) {
58-
int32_t idx = input_arr[i];
59-
assert((idx >= 0) && (idx < weight_shape[0]));
60-
std::memcpy(out_ptr + i * bytes,
61-
weight_ptr + idx * bytes,
62-
bytes);
63-
}
64-
}
65-
66-
} else {
67-
if (infinicore::DataType::I64 == input->dtype()) {
68-
const int64_t *input_arr = reinterpret_cast<const int64_t *>(input->data());
69-
for (Size i = 0; i < counts; ++i) {
70-
int64_t idx = input_arr[i];
71-
assert((idx >= 0) && (idx < weight_shape[0]));
72-
context::memcpyD2D(out_ptr + i * bytes,
73-
weight_ptr + idx * bytes,
74-
bytes);
75-
}
76-
} else if (infinicore::DataType::I32 == input->dtype()) {
77-
const int32_t *input_arr = reinterpret_cast<const int32_t *>(input->data());
78-
for (Size i = 0; i < counts; ++i) {
79-
int32_t idx = input_arr[i];
80-
assert((idx >= 0) && (idx < weight_shape[0]));
81-
context::memcpyD2D(out_ptr + i * bytes,
82-
weight_ptr + idx * bytes,
83-
bytes);
84-
}
85-
}
86-
}
42+
Embedding::execute(out, input, weight);
8743
}
8844

8945
} // namespace infinicore::op
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#include "../../utils.hpp"
2+
#include "infinicore/common/hash.hpp"
3+
#include "infinicore/ops/embedding.hpp"
4+
#include "infinicore/ops/common/cache.hpp"
5+
#include <infiniop.h>
6+
7+
namespace infinicore::op::embedding_impl::infiniop {
8+
9+
thread_local common::OpCache<size_t, infiniopEmbeddingDescriptor_t> caches(
10+
100, // capacity
11+
[](infiniopEmbeddingDescriptor_t &desc) {
12+
if (desc != nullptr) {
13+
INFINICORE_CHECK_ERROR(infiniopDestroyEmbeddingDescriptor(desc));
14+
desc = nullptr;
15+
}
16+
});
17+
18+
void calculate(Tensor out, Tensor input, Tensor weight) {
19+
size_t seed = hash_combine(out, input, weight);
20+
21+
auto device = context::getDevice();
22+
auto &cache = caches.getCache(device);
23+
24+
auto desc_opt = cache.get(seed);
25+
infiniopEmbeddingDescriptor_t desc = nullptr;
26+
27+
if (!desc_opt) {
28+
INFINICORE_CHECK_ERROR(infiniopCreateEmbeddingDescriptor(
29+
context::getInfiniopHandle(device), &desc,
30+
out->desc(), input->desc(), weight->desc()));
31+
cache.put(seed, desc);
32+
} else {
33+
desc = *desc_opt;
34+
}
35+
36+
INFINICORE_CHECK_ERROR(infiniopEmbedding(
37+
desc,
38+
out->data(),
39+
input->data(),
40+
weight->data(),
41+
context::getStream()));
42+
}
43+
44+
static bool registered = []() {
45+
Embedding::dispatcher().registerAll(&calculate, false);
46+
return true;
47+
}();
48+
49+
} // namespace infinicore::op::embedding_impl::infiniop

0 commit comments

Comments
 (0)