Commit bf120b2

issue/900 - maintains classic embedding for devices yet to be worked on
1 parent d3bae33 commit bf120b2

File tree

2 files changed: +79 -14 lines changed


src/infinicore/nn/embedding.cc

Lines changed: 78 additions & 11 deletions
@@ -43,20 +43,87 @@ Embedding::Embedding(size_t num_embeddings,
 }
 
 Tensor Embedding::forward(const Tensor &indices) const {
-    // Ensure indices are on the same device as weight
-    // This avoids synchronous memcpy in ops layer which would hurt performance
-    Tensor indices_on_device = indices;
-    if (indices->device() != device_) {
-        indices_on_device = indices->to(device_);
+    // TODO: Implement on-device embedding for all devices, then remove the condition and the classic approach
+    auto device_type = device_.getType();
+    if (device_type == Device::Type::NVIDIA || device_type == Device::Type::ILUVATAR || device_type == Device::Type::METAX || device_type == Device::Type::MOORE) {
+        // Use op::embedding which supports device-side input and batch dimension
+        return op::embedding(indices->contiguous()->to(device_), weight_);
     }
 
-    // Ensure indices are contiguous for efficient access
-    // op::embedding now supports device-side input for graph recording
-    Tensor indices_contiguous = indices_on_device->is_contiguous() ? indices_on_device : indices_on_device->contiguous();
+    // Get the shape of indices
+    auto indices_shape = indices->shape();
 
-    // Use op::embedding which now supports device-side input and batch dimension
-    // This enables full graph recording support without synchronization
-    return op::embedding(indices_contiguous, weight_);
+    // Output shape: indices_shape + [embedding_dim]
+    std::vector<size_t> output_shape = indices_shape;
+    output_shape.push_back(embedding_dim_);
+
+    // Create output tensor on the same device as weight
+    auto out = Tensor::empty(output_shape, weight_->dtype(), weight_->device());
+
+    // Flatten indices for sequential row copies
+    auto cpu_device = Device(Device::Type::CPU, 0);
+    auto indices_cpu = indices->to(cpu_device)->contiguous();
+
+    // Calculate total number of lookups
+    size_t num_lookups = 1;
+    for (auto dim : indices_shape) {
+        num_lookups *= dim;
+    }
+
+    const size_t row_bytes = embedding_dim_ * dsize(weight_->dtype());
+
+    // Source and destination base pointers
+    auto *weight_base = weight_->data();
+    auto *out_base = out->data();
+
+    // Helper lambda to read index based on dtype with bounds checking
+    auto read_index = [&](size_t i) -> int64_t {
+        auto dtype = indices_cpu->dtype();
+        if (dtype == DataType::I32) {
+            const auto *data = reinterpret_cast<const int32_t *>(indices_cpu->data());
+            return static_cast<int64_t>(data[i]);
+        } else if (dtype == DataType::I64) {
+            const auto *data = reinterpret_cast<const int64_t *>(indices_cpu->data());
+            return data[i];
+        } else if (dtype == DataType::U32) {
+            const auto *data = reinterpret_cast<const uint32_t *>(indices_cpu->data());
+            return static_cast<int64_t>(data[i]);
+        } else if (dtype == DataType::U64) {
+            const auto *data = reinterpret_cast<const uint64_t *>(indices_cpu->data());
+            uint64_t val = data[i];
+            // Check if value can fit in int64_t
+            if (val > static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) {
+                throw std::out_of_range("Index value out of range for int64_t: " + std::to_string(val));
+            }
+            return static_cast<int64_t>(val);
+        } else {
+            throw std::runtime_error("Embedding indices must be integer type, got dtype=" + std::to_string(static_cast<int>(dtype)));
+        }
+    };
+
+    if (weight_->device().getType() == Device::Type::CPU) {
+        // CPU path: memcpy row by row
+        for (size_t i = 0; i < num_lookups; ++i) {
+            int64_t idx = read_index(i);
+            if (idx < 0 || idx >= static_cast<int64_t>(num_embeddings_)) {
+                throw std::out_of_range(
+                    "Index out of range: " + std::to_string(idx) + " (num_embeddings=" + std::to_string(num_embeddings_) + ")");
+            }
+            std::memcpy(out_base + i * row_bytes, weight_base + idx * row_bytes, row_bytes);
+        }
+    } else {
+        // Device path: use stream-ordered D2D copies
+        for (size_t i = 0; i < num_lookups; ++i) {
+            int64_t idx = read_index(i);
+            if (idx < 0 || idx >= static_cast<int64_t>(num_embeddings_)) {
+                throw std::out_of_range(
+                    "Index out of range: " + std::to_string(idx) + " (num_embeddings=" + std::to_string(num_embeddings_) + ")");
+            }
+            context::memcpyD2D(out_base + i * row_bytes, weight_base + idx * row_bytes, row_bytes);
+        }
+    }
+
+    return out;
 }
 
 std::string Embedding::extra_repr() const {
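
The "classic" fallback kept by this commit is, at its core, a row gather: each index selects one row of the weight matrix, and that row is copied verbatim into the output. Doing the copy in bytes (row_bytes = embedding_dim_ * dsize(dtype)) keeps it dtype-agnostic. Below is a minimal host-only sketch of the same technique, using a hypothetical helper name and float weights rather than the infinicore types:

#include <cstddef>
#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <string>
#include <vector>

// Classic embedding lookup as a row gather: out[i] = weight[indices[i]].
// embedding_gather is a hypothetical name for illustration; the committed
// code does the same copy in raw bytes so it works for any element dtype.
std::vector<float> embedding_gather(const std::vector<float> &weight, // num_embeddings x embedding_dim, row-major
                                    const std::vector<int64_t> &indices,
                                    size_t num_embeddings,
                                    size_t embedding_dim) {
    std::vector<float> out(indices.size() * embedding_dim);
    const size_t row_bytes = embedding_dim * sizeof(float);
    for (size_t i = 0; i < indices.size(); ++i) {
        const int64_t idx = indices[i];
        if (idx < 0 || idx >= static_cast<int64_t>(num_embeddings)) {
            throw std::out_of_range("Index out of range: " + std::to_string(idx));
        }
        // Copy one embedding row into the output slot for lookup i.
        std::memcpy(out.data() + i * embedding_dim,
                    weight.data() + idx * embedding_dim,
                    row_bytes);
    }
    return out;
}

The committed version differs mainly in that indices may arrive on an accelerator (so they are first copied to the CPU and read through the dtype-checked read_index lambda), and the per-row copy becomes context::memcpyD2D when the weight itself lives on a device.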

src/infinicore/ops/embedding/embedding.cc

Lines changed: 1 addition & 3 deletions
@@ -1,8 +1,6 @@
 #include "infinicore/ops/embedding.hpp"
+
 #include "../../utils.hpp"
-#include "infinicore/context/context.hpp"
-#include <cstring>
-#include <stdexcept>
 
 namespace infinicore::op {
 INFINICORE_GRAPH_OP_DISPATCHERS_IMPL(Embedding);
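
The removed headers were only needed by the fallback that previously lived in this ops file; with the classic path moved into nn/embedding.cc, the op file now just registers its graph dispatchers. The device gate that picks between the two paths reduces to a small predicate; a sketch with hypothetical stand-in types, not the real infinicore Device API:

// Hypothetical stand-ins for the Device::Type values named in the diff.
enum class DeviceType { CPU, NVIDIA, ILUVATAR, METAX, MOORE };

// Devices this commit routes to the on-device op::embedding fast path;
// all others fall back to the classic row gather sketched above. Per the
// TODO in the diff, the gate goes away once every backend has the kernel.
bool has_on_device_embedding(DeviceType t) {
    return t == DeviceType::NVIDIA || t == DeviceType::ILUVATAR
        || t == DeviceType::METAX || t == DeviceType::MOORE;
}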
