#include "infinicore/ops/embedding.hpp"
#include "infinicore/context/context.hpp"

#include <cassert>
#include <cstring>

namespace infinicore::op {

Tensor embedding(Tensor input, // index tensor (I32 or I64) of arbitrary shape containing the indices to extract
                 Tensor weight // embedding matrix of floating-point type with shape (V, embedding_dim), where V must be at least maximum index + 1
) {
    auto input_shape = input->shape();
    auto weight_shape = weight->shape();
    auto embedding_dim = weight_shape[1];

    // Allocate the output: the index shape with one extra trailing
    // dimension of size embedding_dim, on the weight's device and dtype.
    auto output_shape = input_shape;
    output_shape.push_back(embedding_dim);
    Tensor inputs_embeds = Tensor::empty(output_shape, weight->dtype(), weight->device());

    embedding_(inputs_embeds, input, weight);
    return inputs_embeds;
}

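// In-place variant: gathers row weight[input[i]] into `out` for every index.
// `out` must already have shape input.shape() + {embedding_dim} and the same
// dtype and device as `weight` (as arranged by embedding() above).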
void embedding_(Tensor out, Tensor input, Tensor weight) {
    assert(infinicore::DataType::I64 == input->dtype() || infinicore::DataType::I32 == input->dtype());
    assert(Device::Type::CPU == input->device().getType());
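    // The index values are dereferenced on the host below, so the index
    // tensor must live in CPU memory even when `weight` and `out` reside
    // on an accelerator.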

    auto input_shape = input->shape();
    auto weight_shape = weight->shape();
    auto vocab_size = weight_shape[0];
    auto embedding_dim = weight_shape[1];

    // Total number of tokens (indices) to look up.
    Size counts = 1;
    for (auto &v : input_shape) {
        counts *= v;
    }
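    // The indices are treated as a flat array: element i selects one row of
    // `weight` and fills row i of the flattened output.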

    // Number of bytes in one embedding row (one token's vector).
    const Size bytes = dsize(weight->dtype()) * embedding_dim;
    auto *weight_ptr = weight->data();
    auto *out_ptr = out->data();

    // Gather one row per index. With the table in host memory, each row is
    // copied with std::memcpy.
    if (weight->device().getType() == Device::Type::CPU) {
        if (infinicore::DataType::I64 == input->dtype()) {
            const int64_t *input_arr = reinterpret_cast<const int64_t *>(input->data());
            for (Size i = 0; i < counts; ++i) {
                int64_t idx = input_arr[i];
                assert(idx >= 0 && static_cast<Size>(idx) < vocab_size);
                std::memcpy(out_ptr + i * bytes,
                            weight_ptr + idx * bytes,
                            bytes);
            }
        } else if (infinicore::DataType::I32 == input->dtype()) {
            const int32_t *input_arr = reinterpret_cast<const int32_t *>(input->data());
            for (Size i = 0; i < counts; ++i) {
                int32_t idx = input_arr[i];
                assert(idx >= 0 && static_cast<Size>(idx) < vocab_size);
                std::memcpy(out_ptr + i * bytes,
                            weight_ptr + idx * bytes,
                            bytes);
            }
        }
    } else {
        // Table lives on a device: copy each row through the runtime.
        // context::memcpyD2D expects both pointers to be device-resident.
        if (infinicore::DataType::I64 == input->dtype()) {
            const int64_t *input_arr = reinterpret_cast<const int64_t *>(input->data());
            for (Size i = 0; i < counts; ++i) {
                int64_t idx = input_arr[i];
                assert(idx >= 0 && static_cast<Size>(idx) < vocab_size);
                context::memcpyD2D(out_ptr + i * bytes,
                                   weight_ptr + idx * bytes,
                                   bytes);
            }
        } else if (infinicore::DataType::I32 == input->dtype()) {
            const int32_t *input_arr = reinterpret_cast<const int32_t *>(input->data());
            for (Size i = 0; i < counts; ++i) {
                int32_t idx = input_arr[i];
                assert(idx >= 0 && static_cast<Size>(idx) < vocab_size);
                context::memcpyD2D(out_ptr + i * bytes,
                                   weight_ptr + idx * bytes,
                                   bytes);
            }
        }
    }
}

} // namespace infinicore::op
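
// Usage sketch (illustrative only, not part of this translation unit). It
// assumes host-resident tensors; `make_cpu_device` and `token_ids` are
// hypothetical names used purely for illustration.
//
//   using namespace infinicore;
//   Tensor weight    = Tensor::empty({vocab_size, dim}, DataType::F32, make_cpu_device());
//   Tensor token_ids = ...;                        // I64 tensor of indices, on the CPU
//   Tensor embeds    = op::embedding(token_ids, weight);
//   // embeds has shape token_ids.shape() + {dim} and weight's dtype/device.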