11#include " infinicore/ops/embedding.hpp"
22#include " infinicore/context/context.hpp"
3+ #include " ../../utils.hpp"
34#include < cstring>
5+ #include < stdexcept>
46
57namespace infinicore ::op {
68
9+ common::OpDispatcher<Embedding::schema> &Embedding::dispatcher () {
10+ static common::OpDispatcher<Embedding::schema> dispatcher_;
11+ return dispatcher_;
12+ }
13+
14+ void Embedding::execute (Tensor out, Tensor input, Tensor weight) {
15+ // Check that output and weight are on the same device
16+ INFINICORE_ASSERT_TENSORS_SAME_DEVICE (out, weight);
17+
18+ // Set device context
19+ infinicore::context::setDevice (out->device ());
20+
21+ // Use dispatcher to lookup kernel (infiniop implementation)
22+ dispatcher ().lookup (out->device ().getType ())(out, input, weight);
23+ }
24+
725Tensor embedding (Tensor input, // LongTensor of arbitrary shape containing the indices to extract
826 Tensor weight // Weight: Embedding matrix of floating point type with shape (V, embedding_dim), where V = maximum index + 1
927) {
1028 auto input_shape = input->shape ();
1129 auto weight_shape = weight->shape ();
12- // auto vocab_size = weight_shape[0];
1330 auto embedding_dim = weight_shape[1 ];
1431
1532 // Assign memory to out variables
@@ -22,68 +39,7 @@ Tensor embedding(Tensor input, // LongTensor of arbitrary shape containing the i
2239}
2340
2441void embedding_ (Tensor out, Tensor input, Tensor weight) {
25- assert (infinicore::DataType::I64 == input->dtype () || (infinicore::DataType::I32 == input->dtype ()));
26- assert (infinicore::Device::Type::CPU == input->device ().getType ());
27-
28- auto input_shape = input->shape ();
29- auto weight_shape = weight->shape ();
30- auto embedding_dim = weight_shape[1 ];
31-
32- // Calculate the number of token
33- Size counts = 1 ;
34- for (auto &v : input_shape) {
35- counts *= v;
36- }
37-
38- // the bytes of one token
39- const Size bytes = dsize (weight->dtype ()) * embedding_dim;
40- auto *weight_ptr = weight->data ();
41- auto *out_ptr = out->data ();
42-
43- // copies
44- if (weight->device ().getType () == Device::Type::CPU) {
45- if (infinicore::DataType::I64 == input->dtype ()) {
46- const int64_t *input_arr = reinterpret_cast <const int64_t *>(input->data ());
47- for (Size i = 0 ; i < counts; ++i) {
48- int64_t idx = input_arr[i];
49- assert ((idx >= 0 ) && (idx < weight_shape[0 ]));
50- std::memcpy (out_ptr + i * bytes,
51- weight_ptr + idx * bytes,
52- bytes);
53- }
54- } else if (infinicore::DataType::I32 == input->dtype ()) {
55- const int32_t *input_arr = reinterpret_cast <const int32_t *>(input->data ());
56-
57- for (Size i = 0 ; i < counts; ++i) {
58- int32_t idx = input_arr[i];
59- assert ((idx >= 0 ) && (idx < weight_shape[0 ]));
60- std::memcpy (out_ptr + i * bytes,
61- weight_ptr + idx * bytes,
62- bytes);
63- }
64- }
65-
66- } else {
67- if (infinicore::DataType::I64 == input->dtype ()) {
68- const int64_t *input_arr = reinterpret_cast <const int64_t *>(input->data ());
69- for (Size i = 0 ; i < counts; ++i) {
70- int64_t idx = input_arr[i];
71- assert ((idx >= 0 ) && (idx < weight_shape[0 ]));
72- context::memcpyD2D (out_ptr + i * bytes,
73- weight_ptr + idx * bytes,
74- bytes);
75- }
76- } else if (infinicore::DataType::I32 == input->dtype ()) {
77- const int32_t *input_arr = reinterpret_cast <const int32_t *>(input->data ());
78- for (Size i = 0 ; i < counts; ++i) {
79- int32_t idx = input_arr[i];
80- assert ((idx >= 0 ) && (idx < weight_shape[0 ]));
81- context::memcpyD2D (out_ptr + i * bytes,
82- weight_ptr + idx * bytes,
83- bytes);
84- }
85- }
86- }
42+ Embedding::execute (out, input, weight);
8743}
8844
8945} // namespace infinicore::op
0 commit comments