
Commit e6fbcc8

Merge pull request #585 from pengcheng888/issue/584
issue/584 - Add Python tests for rope, plus the embedding implementation and its tests
2 parents 28b1a1b + 5c4747c commit e6fbcc8

13 files changed, +582 -11 lines changed
include/infinicore/ops/embedding.hpp

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+#pragma once
+
+#include "common/op.hpp"
+
+namespace infinicore::op {
+
+Tensor embedding(Tensor input, Tensor weight);
+void embedding_(Tensor out, Tensor input, Tensor weight);
+} // namespace infinicore::op

include/infinicore/ops/linear.hpp

Lines changed: 1 addition & 0 deletions
@@ -1,6 +1,7 @@
 #pragma once

 #include "common/op.hpp"
+#include <optional>

 namespace infinicore::op {

include/infinicore/ops/rope.hpp

Lines changed: 4 additions & 4 deletions
@@ -1,21 +1,21 @@
 #pragma once

 #include "../device.hpp"
-#include "../tensor.hpp"
 #include "../nn/rope.hpp"
+#include "../tensor.hpp"
 #include "common/op.hpp"

 namespace infinicore::op {
 class RoPE {
 public:
     using schema = void (*)(Tensor, const Tensor &, const Tensor &, const Tensor &, const Tensor &, infinicore::nn::RoPE::Algo);
-    static void execute(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_cache, const Tensor &cos_cache, infinicore::nn::RoPE::Algo algo);
+    static void execute(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_table, const Tensor &cos_cache, infinicore::nn::RoPE::Algo algo);
     static common::OpDispatcher<schema> &dispatcher();
 };

 // Internal function
-void rope_(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_cache, const Tensor &cos_cache, infinicore::nn::RoPE::Algo algo);
+void rope_(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_table, const Tensor &cos_table, infinicore::nn::RoPE::Algo algo);

 // Public API that uses infinicore::nn::RoPE::Algo
-Tensor rope(const Tensor &x, const Tensor &pos, const Tensor &sin_cache, const Tensor &cos_cache, infinicore::nn::RoPE::Algo algo);
+Tensor rope(const Tensor &x, const Tensor &pos, const Tensor &sin_table, const Tensor &cos_table, infinicore::nn::RoPE::Algo algo);
 } // namespace infinicore::op
Lines changed: 13 additions & 1 deletion
@@ -1,8 +1,20 @@
 from .causal_softmax import causal_softmax
+from .embedding import embedding
 from .linear import linear
 from .random_sample import random_sample
 from .rms_norm import rms_norm
+from .rope import RopeAlgo, rope
 from .silu import silu
 from .swiglu import swiglu

-__all__ = ["causal_softmax", "random_sample", "rms_norm", "silu", "swiglu", "linear"]
+__all__ = [
+    "causal_softmax",
+    "random_sample",
+    "rms_norm",
+    "silu",
+    "swiglu",
+    "linear",
+    "embedding",
+    "rope",
+    "RopeAlgo",
+]
Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+from infinicore.lib import _infinicore
+from infinicore.tensor import Tensor
+
+__all__ = ["embedding"]
+
+
+def embedding(
+    input: Tensor,
+    weight: Tensor,
+    padding_idx=None,
+    max_norm=None,
+    norm_type=2.0,
+    scale_grad_by_freq=False,
+    sparse=False,
+    *,
+    out=None,
+) -> Tensor:
+    r"""Generate a simple lookup table that looks up embeddings in a fixed dictionary and size."""
+
+    assert (
+        (padding_idx is None)
+        and (max_norm is None)
+        and (scale_grad_by_freq is False)
+        and (sparse is False)
+    ), "Unsupported parameters."
+
+    assert "cpu" == input.device.type, (
+        "The device of 'input' variable must be on the CPU."
+    )
+
+    if out is None:
+        return Tensor(_infinicore.embedding(input._underlying, weight._underlying))
+
+    _infinicore.embedding_(out._underlying, input._underlying, weight._underlying)
+    return out
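Semantically, the wrapper above is a plain lookup: each integer index selects one row of `weight`, and the output shape is `input.shape + (embedding_dim,)`. A minimal NumPy reference of that behaviour (a sketch for illustration only, not the library's code path):

import numpy as np

def embedding_reference(input_ids: np.ndarray, weight: np.ndarray) -> np.ndarray:
    # weight: (V, embedding_dim); input_ids: any integer-typed shape with values in [0, V)
    # result: input_ids.shape + (embedding_dim,), one weight row gathered per index
    return weight[input_ids]

ids = np.array([[1, 3], [0, 2]], dtype=np.int64)
table = np.arange(20, dtype=np.float32).reshape(5, 4)  # V = 5, embedding_dim = 4
print(embedding_reference(ids, table).shape)           # (2, 2, 4)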
Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
+from infinicore.lib import _infinicore
+from infinicore.tensor import Tensor
+
+__all__ = ["rope", "RopeAlgo"]
+
+
+class RopeAlgo:
+    r"""Different types of RoPE algorithms."""
+
+    GPT_J = _infinicore.Algo.GPT_J
+    GPT_NEOX = _infinicore.Algo.GPT_NEOX
+
+
+def rope(
+    x: Tensor,
+    pos_ids: Tensor,
+    sin_table: Tensor,
+    cos_table: Tensor,
+    algo: RopeAlgo = RopeAlgo.GPT_NEOX,
+    *,
+    out=None,
+) -> Tensor:
+    r"""Rotary Position Embedding(RoPE)."""
+
+    if out is None:
+        return Tensor(
+            _infinicore.rope(
+                x._underlying,
+                pos_ids._underlying,
+                sin_table._underlying,
+                cos_table._underlying,
+                algo,
+            )
+        )
+
+    _infinicore.rope_(
+        out._underlying,
+        x._underlying,
+        pos_ids._underlying,
+        sin_table._underlying,
+        cos_table._underlying,
+        algo,
+    )
+    return out
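For readers unfamiliar with the two algorithms exposed by `RopeAlgo`: GPT_NEOX rotates the two halves of each head dimension together, while GPT_J rotates interleaved even/odd pairs. A rough NumPy sketch of the GPT-NeoX variant follows; the shapes assumed here (x as (seq_len, n_heads, head_dim), pos_ids as (seq_len,), sin/cos tables as (max_pos, head_dim // 2)) are illustrative assumptions, not something this diff pins down:

import numpy as np

def rope_neox_reference(x, pos_ids, sin_table, cos_table):
    # Gather the per-position rotation angles and broadcast over heads.
    sin = sin_table[pos_ids][:, None, :]   # (seq_len, 1, head_dim // 2)
    cos = cos_table[pos_ids][:, None, :]
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]  # GPT-NeoX style: split the head dim in half
    out = np.empty_like(x)
    out[..., :half] = x1 * cos - x2 * sin  # rotate each (x1, x2) pair by the angle
    out[..., half:] = x1 * sin + x2 * cos
    return out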
Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
+#include "infinicore/ops/embedding.hpp"
+#include "infinicore/context/context.hpp"
+#include <cstring>
+
+namespace infinicore::op {
+
+Tensor embedding(Tensor input,  // LongTensor of arbitrary shape containing the indices to extract
+                 Tensor weight  // Weight: Embedding matrix of floating point type with shape (V, embedding_dim), where V = maximum index + 1
+) {
+    auto input_shape = input->shape();
+    auto weight_shape = weight->shape();
+    auto vocab_size = weight_shape[0];
+    auto embedding_dim = weight_shape[1];
+
+    // Assign memory to out variables
+    auto output_shape = input_shape;
+    output_shape.push_back(embedding_dim);
+    Tensor inputs_embeds = Tensor::empty(output_shape, weight->dtype(), weight->device());
+
+    embedding_(inputs_embeds, input, weight);
+    return inputs_embeds;
+}
+
+void embedding_(Tensor out, Tensor input, Tensor weight) {
+    assert(infinicore::DataType::I64 == input->dtype() || (infinicore::DataType::I32 == input->dtype()));
+    assert(infinicore::Device::Type::CPU == input->device());
+
+    auto input_shape = input->shape();
+    auto weight_shape = weight->shape();
+    auto vocab_size = weight_shape[0];
+    auto embedding_dim = weight_shape[1];
+
+    // Calculate the number of token
+    Size counts = 1;
+    for (auto &v : input_shape) {
+        counts *= v;
+    }
+
+    // the bytes of one token
+    const Size bytes = dsize(weight->dtype()) * embedding_dim;
+    auto *weight_ptr = weight->data();
+    auto *out_ptr = out->data();
+
+    // copies
+    if (weight->device().getType() == Device::Type::CPU) {
+        if (infinicore::DataType::I64 == input->dtype()) {
+            const int64_t *input_arr = reinterpret_cast<const int64_t *>(input->data());
+            for (Size i = 0; i < counts; ++i) {
+                int64_t idx = input_arr[i];
+                assert((idx >= 0) && (idx < vocab_size));
+                std::memcpy(out_ptr + i * bytes,
+                            weight_ptr + idx * bytes,
+                            bytes);
+            }
+        } else if (infinicore::DataType::I32 == input->dtype()) {
+            const int32_t *input_arr = reinterpret_cast<const int32_t *>(input->data());
+
+            for (Size i = 0; i < counts; ++i) {
+                int32_t idx = input_arr[i];
+                assert((idx >= 0) && (idx < vocab_size));
+                std::memcpy(out_ptr + i * bytes,
+                            weight_ptr + idx * bytes,
+                            bytes);
+            }
+        }
+
+    } else {
+        if (infinicore::DataType::I64 == input->dtype()) {
+            const int64_t *input_arr = reinterpret_cast<const int64_t *>(input->data());
+            for (Size i = 0; i < counts; ++i) {
+                int64_t idx = input_arr[i];
+                assert((idx >= 0) && (idx < vocab_size));
+                context::memcpyD2D(out_ptr + i * bytes,
+                                   weight_ptr + idx * bytes,
+                                   bytes);
+            }
+        } else if (infinicore::DataType::I32 == input->dtype()) {
+            const int32_t *input_arr = reinterpret_cast<const int32_t *>(input->data());
+            for (Size i = 0; i < counts; ++i) {
+                int32_t idx = input_arr[i];
+                assert((idx >= 0) && (idx < vocab_size));
+                context::memcpyD2D(out_ptr + i * bytes,
+                                   weight_ptr + idx * bytes,
+                                   bytes);
+            }
+        }
+    }
+}
+
+} // namespace infinicore::op
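The kernel above treats the index tensor as a flat run of `counts` tokens and copies `embedding_dim * dsize(dtype)` bytes per token, using `std::memcpy` when the weight is on the CPU and `context::memcpyD2D` otherwise. A rough Python mirror of that control flow (illustrative only, written against NumPy arrays rather than infinicore tensors):

import numpy as np

def embedding_kernel_mirror(input_ids: np.ndarray, weight: np.ndarray) -> np.ndarray:
    vocab_size, embedding_dim = weight.shape
    counts = input_ids.size                  # product of input_shape, as in the C++ loop
    flat_ids = input_ids.reshape(-1)
    out = np.empty((counts, embedding_dim), dtype=weight.dtype)
    for i in range(counts):
        idx = int(flat_ids[i])
        assert 0 <= idx < vocab_size         # same bounds check as the kernel's assert
        out[i] = weight[idx]                 # stands in for the per-row memcpy / memcpyD2D
    return out.reshape(input_ids.shape + (embedding_dim,))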

src/infinicore/ops/rope/rope.cc

Lines changed: 6 additions & 6 deletions
@@ -9,25 +9,25 @@ common::OpDispatcher<RoPE::schema> &RoPE::dispatcher() {
     return dispatcher_;
 };

-void RoPE::execute(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_cache, const Tensor &cos_cache, infinicore::nn::RoPE::Algo algo) {
+void RoPE::execute(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_table, const Tensor &cos_table, infinicore::nn::RoPE::Algo algo) {
     auto device_type = context::getDevice().getType();
     auto func = dispatcher().lookup(device_type);

     if (func == nullptr) {
         throw std::runtime_error("No RoPE implementation found for device type: " + std::to_string(static_cast<int>(device_type)));
     }

-    func(x_out, x, pos, sin_cache, cos_cache, algo);
+    func(x_out, x, pos, sin_table, cos_table, algo);
 }

-void rope_(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_cache, const Tensor &cos_cache, infinicore::nn::RoPE::Algo algo) {
-    RoPE::execute(x_out, x, pos, sin_cache, cos_cache, algo);
+void rope_(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_table, const Tensor &cos_table, infinicore::nn::RoPE::Algo algo) {
+    RoPE::execute(x_out, x, pos, sin_table, cos_table, algo);
 }

-Tensor rope(const Tensor &x, const Tensor &pos, const Tensor &sin_cache, const Tensor &cos_cache, infinicore::nn::RoPE::Algo algo) {
+Tensor rope(const Tensor &x, const Tensor &pos, const Tensor &sin_table, const Tensor &cos_table, infinicore::nn::RoPE::Algo algo) {
     Shape shape = x->shape();
     auto x_out = Tensor::empty(shape, x->dtype(), x->device());
-    rope_(x_out, x, pos, sin_cache, cos_cache, algo);
+    rope_(x_out, x, pos, sin_table, cos_table, algo);
     return x_out;
 }

src/infinicore/pybind11/ops.hpp

Lines changed: 4 additions & 0 deletions
@@ -5,12 +5,14 @@
 #include "ops/add.hpp"
 #include "ops/attention.hpp"
 #include "ops/causal_softmax.hpp"
+#include "ops/embedding.hpp"
 #include "ops/linear.hpp"
 #include "ops/matmul.hpp"
 #include "ops/mul.hpp"
 #include "ops/random_sample.hpp"
 #include "ops/rearrange.hpp"
 #include "ops/rms_norm.hpp"
+#include "ops/rope.hpp"
 #include "ops/silu.hpp"
 #include "ops/swiglu.hpp"

@@ -30,6 +32,8 @@ inline void bind(py::module &m) {
     bind_rms_norm(m);
     bind_silu(m);
     bind_swiglu(m);
+    bind_rope(m);
+    bind_embedding(m);
 }

 } // namespace infinicore::ops
src/infinicore/pybind11/ops/embedding.hpp

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+#pragma once
+
+#include "infinicore/ops/embedding.hpp"
+#include <pybind11/pybind11.h>
+
+namespace py = pybind11;
+
+namespace infinicore::ops {
+
+inline void bind_embedding(py::module &m) {
+
+    m.def("embedding",
+          &op::embedding,
+          py::arg("input"),
+          py::arg("weight"),
+          R"doc(Generate a simple lookup table that looks up embeddings in a fixed dictionary and size..)doc");
+
+    m.def("embedding_",
+          &op::embedding_,
+          py::arg("out"),
+          py::arg("input"),
+          py::arg("weight"),
+          R"doc(In-place, Generate a simple lookup table that looks up embeddings in a fixed dictionary and size..)doc");
+}
+
+} // namespace infinicore::ops
