@@ -43,20 +43,87 @@ Embedding::Embedding(size_t num_embeddings,
4343}
4444
4545Tensor Embedding::forward (const Tensor &indices) const {
46- // Ensure indices are on the same device as weight
47- // This avoids synchronous memcpy in ops layer which would hurt performance
48- Tensor indices_on_device = indices;
49- if (indices-> device () != device_) {
50- indices_on_device = indices->to (device_);
46+ // TODO: Implement on-device embedding for all devices, then remove the condition and the classic approach
47+ auto device_type = device_. getType ();
48+ if (device_type == Device::Type::NVIDIA || device_type == Device::Type::ILUVATAR || device_type == Device::Type::METAX || device_type == Device::Type::MOORE) {
49+ // Use op::embedding which supports device-side input and batch dimension
50+ return op::embedding ( indices->contiguous ()-> to (device_), weight_ );
5151 }
5252
53- // Ensure indices are contiguous for efficient access
54- // op::embedding now supports device-side input for graph recording
55- Tensor indices_contiguous = indices_on_device->is_contiguous () ? indices_on_device : indices_on_device->contiguous ();
53+ // Get the shape of indices
54+ auto indices_shape = indices->shape ();
5655
57- // Use op::embedding which now supports device-side input and batch dimension
58- // This enables full graph recording support without synchronization
59- return op::embedding (indices_contiguous, weight_);
56+ // Output shape: indices_shape + [embedding_dim]
57+ std::vector<size_t > output_shape = indices_shape;
58+ output_shape.push_back (embedding_dim_);
59+
60+ // Create output tensor on the same device as weight
61+ auto out = Tensor::empty (output_shape, weight_->dtype (), weight_->device ());
62+
63+ // Flatten indices for sequential row copies
64+ auto cpu_device = Device (Device::Type::CPU, 0 );
65+ auto indices_cpu = indices->to (cpu_device)->contiguous ();
66+
67+ // Calculate total number of lookups
68+ size_t num_lookups = 1 ;
69+ for (auto dim : indices_shape) {
70+ num_lookups *= dim;
71+ }
72+
73+ const size_t row_bytes = embedding_dim_ * dsize (weight_->dtype ());
74+
75+ // Source and destination base pointers
76+ auto *weight_base = weight_->data ();
77+ auto *out_base = out->data ();
78+
79+ // Helper lambda to read index based on dtype with bounds checking
80+ auto read_index = [&](size_t i) -> int64_t {
81+ auto dtype = indices_cpu->dtype ();
82+ if (dtype == DataType::I32) {
83+ const auto *data = reinterpret_cast <const int32_t *>(indices_cpu->data ());
84+ return static_cast <int64_t >(data[i]);
85+ } else if (dtype == DataType::I64) {
86+ const auto *data = reinterpret_cast <const int64_t *>(indices_cpu->data ());
87+ return data[i];
88+ } else if (dtype == DataType::U32) {
89+ const auto *data = reinterpret_cast <const uint32_t *>(indices_cpu->data ());
90+ return static_cast <int64_t >(data[i]);
91+ } else if (dtype == DataType::U64) {
92+ const auto *data = reinterpret_cast <const uint64_t *>(indices_cpu->data ());
93+ uint64_t val = data[i];
94+ // Check if value can fit in int64_t
95+ if (val > static_cast <uint64_t >(std::numeric_limits<int64_t >::max ())) {
96+ throw std::out_of_range (" Index value out of range for int64_t: " + std::to_string (val));
97+ }
98+ return static_cast <int64_t >(val);
99+ } else {
100+ throw std::runtime_error (" Embedding indices must be integer type, got dtype=" + std::to_string (static_cast <int >(dtype)));
101+ }
102+ };
103+
104+ if (weight_->device ().getType () == Device::Type::CPU) {
105+ // CPU path: memcpy row by row
106+ for (size_t i = 0 ; i < num_lookups; ++i) {
107+ int64_t idx = read_index (i);
108+ if (idx < 0 || idx >= static_cast <int64_t >(num_embeddings_)) {
109+ throw std::out_of_range (
110+ " Index out of range: " + std::to_string (idx) + " (num_embeddings=" + std::to_string (num_embeddings_) + " )" );
111+ }
112+ std::memcpy (out_base + i * row_bytes, weight_base + idx * row_bytes, row_bytes);
113+ }
114+ } else {
115+ // Device path: use stream-ordered D2D copies
116+ for (size_t i = 0 ; i < num_lookups; ++i) {
117+ int64_t idx = read_index (i);
118+ if (idx < 0 || idx >= static_cast <int64_t >(num_embeddings_)) {
119+ throw std::out_of_range (
120+ " Index out of range: " + std::to_string (idx) + " (num_embeddings=" + std::to_string (num_embeddings_) + " )" );
121+ }
122+ context::memcpyD2D (out_base + i * row_bytes, weight_base + idx * row_bytes, row_bytes);
123+ }
124+ }
125+
126+ return out;
60127}
61128
62129std::string Embedding::extra_repr () const {
0 commit comments