diff --git a/include/infinicore/ops/embedding.hpp b/include/infinicore/ops/embedding.hpp
new file mode 100644
index 000000000..4fd9991c4
--- /dev/null
+++ b/include/infinicore/ops/embedding.hpp
@@ -0,0 +1,9 @@
+#pragma once
+
+#include "common/op.hpp"
+
+namespace infinicore::op {
+
+Tensor embedding(Tensor input, Tensor weight);
+void embedding_(Tensor out, Tensor input, Tensor weight);
+} // namespace infinicore::op
diff --git a/include/infinicore/ops/linear.hpp b/include/infinicore/ops/linear.hpp
index 81cb61986..d69842be3 100644
--- a/include/infinicore/ops/linear.hpp
+++ b/include/infinicore/ops/linear.hpp
@@ -1,6 +1,7 @@
 #pragma once

 #include "common/op.hpp"
+#include <optional>

 namespace infinicore::op {

diff --git a/include/infinicore/ops/rope.hpp b/include/infinicore/ops/rope.hpp
index 339f95e45..a5f7792b9 100644
--- a/include/infinicore/ops/rope.hpp
+++ b/include/infinicore/ops/rope.hpp
@@ -1,21 +1,21 @@
 #pragma once

 #include "../device.hpp"
-#include "../tensor.hpp"
 #include "../nn/rope.hpp"
+#include "../tensor.hpp"
 #include "common/op.hpp"

 namespace infinicore::op {

 class RoPE {
 public:
     using schema = void (*)(Tensor, const Tensor &, const Tensor &, const Tensor &, const Tensor &, infinicore::nn::RoPE::Algo);
-    static void execute(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_cache, const Tensor &cos_cache, infinicore::nn::RoPE::Algo algo);
+    static void execute(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_table, const Tensor &cos_table, infinicore::nn::RoPE::Algo algo);
     static common::OpDispatcher &dispatcher();
 };

 // Internal function
-void rope_(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_cache, const Tensor &cos_cache, infinicore::nn::RoPE::Algo algo);
+void rope_(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_table, const Tensor &cos_table, infinicore::nn::RoPE::Algo algo);

 // Public API that uses infinicore::nn::RoPE::Algo
-Tensor rope(const Tensor &x, const Tensor &pos, const Tensor &sin_cache, const Tensor &cos_cache, infinicore::nn::RoPE::Algo algo);
+Tensor rope(const Tensor &x, const Tensor &pos, const Tensor &sin_table, const Tensor &cos_table, infinicore::nn::RoPE::Algo algo);

 } // namespace infinicore::op
diff --git a/python/infinicore/nn/functional/__init__.py b/python/infinicore/nn/functional/__init__.py
index d35257e2f..255079790 100644
--- a/python/infinicore/nn/functional/__init__.py
+++ b/python/infinicore/nn/functional/__init__.py
@@ -1,8 +1,20 @@
 from .causal_softmax import causal_softmax
+from .embedding import embedding
 from .linear import linear
 from .random_sample import random_sample
 from .rms_norm import rms_norm
+from .rope import RopeAlgo, rope
 from .silu import silu
 from .swiglu import swiglu

-__all__ = ["causal_softmax", "random_sample", "rms_norm", "silu", "swiglu", "linear"]
+__all__ = [
+    "causal_softmax",
+    "random_sample",
+    "rms_norm",
+    "silu",
+    "swiglu",
+    "linear",
+    "embedding",
+    "rope",
+    "RopeAlgo",
+]
diff --git a/python/infinicore/nn/functional/embedding.py b/python/infinicore/nn/functional/embedding.py
new file mode 100644
index 000000000..73749e32c
--- /dev/null
+++ b/python/infinicore/nn/functional/embedding.py
@@ -0,0 +1,35 @@
+from infinicore.lib import _infinicore
+from infinicore.tensor import Tensor
+
+__all__ = ["embedding"]
+
+
+def embedding(
+    input: Tensor,
+    weight: Tensor,
+    padding_idx=None,
+    max_norm=None,
+    norm_type=2.0,
+    scale_grad_by_freq=False,
+    sparse=False,
+    *,
+    out=None,
+) -> Tensor:
+    r"""Generate a simple lookup table that looks up embeddings in a fixed dictionary and size.
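+
+    Only the plain lookup is implemented: ``input`` holds int32/int64 indices and
+    must currently reside on the CPU, ``weight`` is the ``(num_embeddings,
+    embedding_dim)`` embedding matrix, and the result has shape
+    ``input.shape + (embedding_dim,)`` with ``weight``'s dtype and device.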
+    """
+
+    assert (
+        (padding_idx is None)
+        and (max_norm is None)
+        and (scale_grad_by_freq is False)
+        and (sparse is False)
+    ), "padding_idx, max_norm, scale_grad_by_freq and sparse are not supported."
+
+    assert "cpu" == input.device.type, (
+        "The 'input' tensor must reside on the CPU."
+    )
+
+    if out is None:
+        return Tensor(_infinicore.embedding(input._underlying, weight._underlying))
+
+    _infinicore.embedding_(out._underlying, input._underlying, weight._underlying)
+    return out
diff --git a/python/infinicore/nn/functional/rope.py b/python/infinicore/nn/functional/rope.py
new file mode 100644
index 000000000..93c76c963
--- /dev/null
+++ b/python/infinicore/nn/functional/rope.py
@@ -0,0 +1,44 @@
+from infinicore.lib import _infinicore
+from infinicore.tensor import Tensor
+
+__all__ = ["rope", "RopeAlgo"]
+
+
+class RopeAlgo:
+    r"""Supported RoPE algorithm variants."""
+
+    GPT_J = _infinicore.Algo.GPT_J
+    GPT_NEOX = _infinicore.Algo.GPT_NEOX
+
+
+def rope(
+    x: Tensor,
+    pos_ids: Tensor,
+    sin_table: Tensor,
+    cos_table: Tensor,
+    algo: RopeAlgo = RopeAlgo.GPT_NEOX,
+    *,
+    out=None,
+) -> Tensor:
+    r"""Apply Rotary Position Embedding (RoPE) to ``x``."""
+
+    if out is None:
+        return Tensor(
+            _infinicore.rope(
+                x._underlying,
+                pos_ids._underlying,
+                sin_table._underlying,
+                cos_table._underlying,
+                algo,
+            )
+        )
+
+    _infinicore.rope_(
+        out._underlying,
+        x._underlying,
+        pos_ids._underlying,
+        sin_table._underlying,
+        cos_table._underlying,
+        algo,
+    )
+    return out
diff --git a/src/infinicore/ops/embedding/embedding.cc b/src/infinicore/ops/embedding/embedding.cc
new file mode 100644
index 000000000..9500548f7
--- /dev/null
+++ b/src/infinicore/ops/embedding/embedding.cc
@@ -0,0 +1,90 @@
+#include "infinicore/ops/embedding.hpp"
+#include "infinicore/context/context.hpp"
+#include <cassert>
+#include <cstring>
+
+namespace infinicore::op {
+
+Tensor embedding(Tensor input,  // index tensor of arbitrary shape (int32/int64)
+                 Tensor weight  // embedding matrix of floating-point dtype, shape (num_embeddings, embedding_dim)
+) {
+    auto input_shape = input->shape();
+    auto weight_shape = weight->shape();
+    auto embedding_dim = weight_shape[1];
+
+    // Allocate the output: the input shape extended by embedding_dim.
+    auto output_shape = input_shape;
+    output_shape.push_back(embedding_dim);
+    Tensor inputs_embeds = Tensor::empty(output_shape, weight->dtype(), weight->device());
+
+    embedding_(inputs_embeds, input, weight);
+    return inputs_embeds;
+}
+
+void embedding_(Tensor out, Tensor input, Tensor weight) {
+    assert(infinicore::DataType::I64 == input->dtype() || (infinicore::DataType::I32 == input->dtype()));
+    assert(input->device().getType() == Device::Type::CPU);
+
+    auto input_shape = input->shape();
+    auto weight_shape = weight->shape();
+    auto vocab_size = static_cast<int64_t>(weight_shape[0]);
+    auto embedding_dim = weight_shape[1];
+
+    // Total number of indices (tokens) in the input.
+    Size counts = 1;
+    for (auto &v : input_shape) {
+        counts *= v;
+    }
+
+    // Bytes of one embedding row.
+    const Size bytes = dsize(weight->dtype()) * embedding_dim;
+    auto *weight_ptr = weight->data();
+    auto *out_ptr = out->data();
+
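+    // The index tensor is read on the host (asserted to be a CPU tensor above),
+    // so indices can be dereferenced directly; only the row-copy routine below
+    // depends on the weight's device.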
+    // Copy one embedding row per token.
+    if (weight->device().getType() == Device::Type::CPU) {
+        if (infinicore::DataType::I64 == input->dtype()) {
+            const int64_t *input_arr = reinterpret_cast<const int64_t *>(input->data());
+            for (Size i = 0; i < counts; ++i) {
+                int64_t idx = input_arr[i];
+                assert((idx >= 0) && (idx < vocab_size));
+                std::memcpy(out_ptr + i * bytes,
+                            weight_ptr + idx * bytes,
+                            bytes);
+            }
+        } else if (infinicore::DataType::I32 == input->dtype()) {
+            const int32_t *input_arr = reinterpret_cast<const int32_t *>(input->data());
+
+            for (Size i = 0; i < counts; ++i) {
+                int32_t idx = input_arr[i];
+                assert((idx >= 0) && (idx < vocab_size));
+                std::memcpy(out_ptr + i * bytes,
+                            weight_ptr + idx * bytes,
+                            bytes);
+            }
+        }
+
+    } else {
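+        // Weight (and therefore out) resides on a device: gather each selected
+        // row with a device-to-device copy instead of std::memcpy.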
+        if (infinicore::DataType::I64 == input->dtype()) {
+            const int64_t *input_arr = reinterpret_cast<const int64_t *>(input->data());
+            for (Size i = 0; i < counts; ++i) {
+                int64_t idx = input_arr[i];
+                assert((idx >= 0) && (idx < vocab_size));
+                context::memcpyD2D(out_ptr + i * bytes,
+                                   weight_ptr + idx * bytes,
+                                   bytes);
+            }
+        } else if (infinicore::DataType::I32 == input->dtype()) {
+            const int32_t *input_arr = reinterpret_cast<const int32_t *>(input->data());
+            for (Size i = 0; i < counts; ++i) {
+                int32_t idx = input_arr[i];
+                assert((idx >= 0) && (idx < vocab_size));
+                context::memcpyD2D(out_ptr + i * bytes,
+                                   weight_ptr + idx * bytes,
+                                   bytes);
+            }
+        }
+    }
+}
+
+} // namespace infinicore::op
diff --git a/src/infinicore/ops/rope/rope.cc b/src/infinicore/ops/rope/rope.cc
index 5d3389d22..1961cc3dd 100644
--- a/src/infinicore/ops/rope/rope.cc
+++ b/src/infinicore/ops/rope/rope.cc
@@ -9,7 +9,7 @@ common::OpDispatcher &RoPE::dispatcher() {
     return dispatcher_;
 };

-void RoPE::execute(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_cache, const Tensor &cos_cache, infinicore::nn::RoPE::Algo algo) {
+void RoPE::execute(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_table, const Tensor &cos_table, infinicore::nn::RoPE::Algo algo) {
     auto device_type = context::getDevice().getType();
     auto func = dispatcher().lookup(device_type);

@@ -17,17 +17,17 @@ void RoPE::execute(Tensor x_out, const Tensor &x, const Tensor &pos, const Tenso
         throw std::runtime_error("No RoPE implementation found for device type: " + std::to_string(static_cast<int>(device_type)));
     }

-    func(x_out, x, pos, sin_cache, cos_cache, algo);
+    func(x_out, x, pos, sin_table, cos_table, algo);
 }

-void rope_(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_cache, const Tensor &cos_cache, infinicore::nn::RoPE::Algo algo) {
-    RoPE::execute(x_out, x, pos, sin_cache, cos_cache, algo);
+void rope_(Tensor x_out, const Tensor &x, const Tensor &pos, const Tensor &sin_table, const Tensor &cos_table, infinicore::nn::RoPE::Algo algo) {
+    RoPE::execute(x_out, x, pos, sin_table, cos_table, algo);
 }

-Tensor rope(const Tensor &x, const Tensor &pos, const Tensor &sin_cache, const Tensor &cos_cache, infinicore::nn::RoPE::Algo algo) {
+Tensor rope(const Tensor &x, const Tensor &pos, const Tensor &sin_table, const Tensor &cos_table, infinicore::nn::RoPE::Algo algo) {
     Shape shape = x->shape();
     auto x_out = Tensor::empty(shape, x->dtype(), x->device());

-    rope_(x_out, x, pos, sin_cache, cos_cache, algo);
+    rope_(x_out, x, pos, sin_table, cos_table, algo);
     return x_out;
 }
diff --git a/src/infinicore/pybind11/ops.hpp b/src/infinicore/pybind11/ops.hpp
index 98adb88dd..978defa17 100644
--- a/src/infinicore/pybind11/ops.hpp
+++ b/src/infinicore/pybind11/ops.hpp
@@ -5,12 +5,14 @@
 #include "ops/add.hpp"
 #include "ops/attention.hpp"
 #include "ops/causal_softmax.hpp"
+#include "ops/embedding.hpp"
 #include "ops/linear.hpp"
 #include "ops/matmul.hpp"
 #include "ops/mul.hpp"
 #include "ops/random_sample.hpp"
 #include "ops/rearrange.hpp"
 #include "ops/rms_norm.hpp"
+#include "ops/rope.hpp"
 #include "ops/silu.hpp"
 #include "ops/swiglu.hpp"

@@ -30,6 +32,8 @@ inline void bind(py::module &m) {
     bind_rms_norm(m);
     bind_silu(m);
     bind_swiglu(m);
+    bind_rope(m);
+    bind_embedding(m);
 }

 } // namespace infinicore::ops
diff --git a/src/infinicore/pybind11/ops/embedding.hpp b/src/infinicore/pybind11/ops/embedding.hpp
new file mode 100644
index 000000000..44e14b61c
--- /dev/null
+++ b/src/infinicore/pybind11/ops/embedding.hpp
@@ -0,0 +1,26 @@
+#pragma once
+
+#include "infinicore/ops/embedding.hpp"
+#include <pybind11/pybind11.h>
+
+namespace py = pybind11;
+
+namespace infinicore::ops {
+
+inline void bind_embedding(py::module &m) {
+
+    m.def("embedding",
+          &op::embedding,
+          py::arg("input"),
+          py::arg("weight"),
+          R"doc(Generate a simple lookup table that looks up embeddings in a fixed dictionary and size.)doc");
+
+    m.def("embedding_",
+          &op::embedding_,
+          py::arg("out"),
+          py::arg("input"),
+          py::arg("weight"),
+          R"doc(In-place version of embedding(), writing into a preallocated output tensor.)doc");
+}
+
+} // namespace infinicore::ops
diff --git a/src/infinicore/pybind11/ops/rope.hpp b/src/infinicore/pybind11/ops/rope.hpp
new file mode 100644
index 000000000..b3c411a89
--- /dev/null
+++ b/src/infinicore/pybind11/ops/rope.hpp
@@ -0,0 +1,37 @@
+#pragma once
+
+#include <pybind11/pybind11.h>
+
+#include "infinicore/ops/rope.hpp"
+
+namespace py = pybind11;
+
+namespace infinicore::ops {
+
+inline void bind_rope(py::module &m) {
+
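+    // Expose the algorithm enum to Python. GPT_J rotates interleaved (even, odd)
+    // element pairs; GPT_NEOX rotates the first and second halves of each head
+    // dimension (see the reference implementation in test/infinicore/ops/rope.py).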
+ """ + test_cases = [] + + for data in _TEST_CASES_DATA: + bs, ntok = data[0], data[1] + vocab_size, embedding_dim = data[2], data[3] + input_type = data[4] + + input_strides = None + weight_strides = None + + # Determine shapes + input_shape = (bs, ntok) + + weight_shape = (vocab_size, embedding_dim) + + # Check if tensors support in-place operations + # Generate test cases for all data types + for dtype in _TENSOR_DTYPES: + tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 1e-3}) + + # Create typed tensor specs + input_spec = TensorSpec.from_tensor( + input_shape, + input_strides, + input_type, + init_mode=TensorInitializer.RANDINT, + low=1, + high=9, + ) + weight_spec = TensorSpec.from_tensor(weight_shape, weight_strides, dtype) + + # Test Case 1: Out-of-place (return value) + test_cases.append( + TestCase( + inputs=[input_spec, weight_spec], + kwargs={}, + output_spec=None, + comparison_target=None, + tolerance=tolerance, + description=f"Embedding - OUT_OF_PLACE", + ) + ) + + return test_cases + + +class OpTest(BaseOperatorTest): + """Embedding operator test with simplified implementation""" + + def __init__(self): + super().__init__("Embedding") + + def get_test_cases(self): + return parse_test_cases() + + def torch_operator(self, *args, out=None, **kwargs): + """PyTorch Embedding implementation""" + + return torch.nn.functional.embedding(*args, **kwargs) + + def infinicore_operator(self, input, weight, out=None, **kwargs): + """InfiniCore Embedding implementation""" + + if input.device.type == "cpu": + input_cpu = input + else: + # 将 input的数据 转移到 cpu 上 + torch_reference = torch.zeros( + input.shape, + dtype=to_torch_dtype(input.dtype), + device="cpu" if "cpu" == input.device.type else "cuda", + ) + torch_reference = convert_infinicore_to_torch(input) + torch_reference = torch_reference.contiguous().cpu() + + # 创建cpu的 input + input_cpu = infinicore_tensor_from_torch(torch_reference) + + return infinicore.nn.functional.embedding(input_cpu, weight, out=out) + + +def main(): + """Main entry point""" + runner = GenericTestRunner(OpTest) + runner.run_and_exit() + + +if __name__ == "__main__": + main() diff --git a/test/infinicore/ops/rope.py b/test/infinicore/ops/rope.py new file mode 100644 index 000000000..62aaaefff --- /dev/null +++ b/test/infinicore/ops/rope.py @@ -0,0 +1,181 @@ +import os +import sys + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + +import torch +from framework.base import BaseOperatorTest, TensorSpec, TestCase +from framework.runner import GenericTestRunner +from framework.utils import infinicore_tensor_from_torch, is_broadcast +from infinicore.nn.functional import RopeAlgo + +import infinicore + +# ============================================================================== +# Operator-specific configuration +# ============================================================================== + + +_TEST_CASES_DATA = [ + # ntok, num, head_dim, Algo + (1, 1, 64, RopeAlgo.GPT_NEOX), + (5, 32, 64, RopeAlgo.GPT_NEOX), + (1, 1, 128, RopeAlgo.GPT_J), + (10, 1, 64, RopeAlgo.GPT_J), +] + +# Tolerance configuration +_TOLERANCE_MAP = { + infinicore.float16: {"atol": 1e-3, "rtol": 1e-2}, + infinicore.float32: {"atol": 1e-2, "rtol": 1e-3}, + infinicore.bfloat16: {"atol": 1e-2, "rtol": 5e-2}, +} + +# Data types to test +_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32] + + +def parse_test_cases(): + """ + Parse test case data and return list of TestCase objects for Rope operation. 
+        if input.device.type == "cpu":
+            input_cpu = input
+        else:
+            torch_reference = convert_infinicore_to_torch(input)
+            torch_reference = torch_reference.contiguous().cpu()
+            input_cpu = infinicore_tensor_from_torch(torch_reference)
+
+        return infinicore.nn.functional.embedding(input_cpu, weight, out=out)
+
+
+def main():
+    """Main entry point"""
+    runner = GenericTestRunner(OpTest)
+    runner.run_and_exit()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test/infinicore/ops/rope.py b/test/infinicore/ops/rope.py
new file mode 100644
index 000000000..62aaaefff
--- /dev/null
+++ b/test/infinicore/ops/rope.py
@@ -0,0 +1,181 @@
+import os
+import sys
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
+
+import torch
+from framework.base import BaseOperatorTest, TensorSpec, TestCase
+from framework.runner import GenericTestRunner
+from framework.utils import infinicore_tensor_from_torch, is_broadcast
+from infinicore.nn.functional import RopeAlgo
+
+import infinicore
+
+# ==============================================================================
+#  Operator-specific configuration
+# ==============================================================================
+
+
+_TEST_CASES_DATA = [
+    # ntok, num_heads, head_dim, algo
+    (1, 1, 64, RopeAlgo.GPT_NEOX),
+    (5, 32, 64, RopeAlgo.GPT_NEOX),
+    (1, 1, 128, RopeAlgo.GPT_J),
+    (10, 1, 64, RopeAlgo.GPT_J),
+]
+
+# Tolerance configuration
+_TOLERANCE_MAP = {
+    infinicore.float16: {"atol": 1e-3, "rtol": 1e-2},
+    infinicore.float32: {"atol": 1e-2, "rtol": 1e-3},
+    infinicore.bfloat16: {"atol": 1e-2, "rtol": 5e-2},
+}
+
+# Data types to test
+_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
+
+
+def parse_test_cases():
+    """
+    Parse test case data and return a list of TestCase objects for the RoPE
+    operation. Each test case carries everything needed for execution and
+    validation.
+    """
+    test_cases = []
+
+    for data in _TEST_CASES_DATA:
+        ntok, num, head_dim = data[0], data[1], data[2]
+        algo = data[3]
+
+        # Determine shapes
+        out_shape = (ntok, num, head_dim)
+        x_shape = (ntok, num, head_dim)
+        sin_table_shape = (ntok, head_dim // 2)
+        cos_table_shape = (ntok, head_dim // 2)
+
+        # Check if tensors support in-place operations
+        c_supports_inplace = not is_broadcast(out_shape)
+
+        # Generate test cases for all data types
+        for dtype in _TENSOR_DTYPES:
+            tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 1e-3})
+
+            # Create typed tensor specs
+            out_spec = TensorSpec.from_tensor(out_shape, None, dtype)
+            x_spec = TensorSpec.from_tensor(x_shape, None, dtype)
+            sin_table_spec = TensorSpec.from_tensor(sin_table_shape, None, dtype)
+            cos_table_spec = TensorSpec.from_tensor(cos_table_shape, None, dtype)
+
+            # Test Case 1: Out-of-place (return value)
+            test_cases.append(
+                TestCase(
+                    inputs=[x_spec, sin_table_spec, cos_table_spec],
+                    kwargs={"algo": algo},
+                    output_spec=None,
+                    comparison_target=None,
+                    tolerance=tolerance,
+                    description="Rope - OUT_OF_PLACE",
+                )
+            )
+
+            # Test Case 2: In-place with explicit output tensor
+            if c_supports_inplace:
+                test_cases.append(
+                    TestCase(
+                        inputs=[x_spec, sin_table_spec, cos_table_spec],
+                        kwargs={"algo": algo},
+                        output_spec=out_spec,  # Specify the output tensor spec
+                        comparison_target="out",
+                        tolerance=tolerance,
+                        description="Rope - INPLACE(out)",
+                    )
+                )
+
+    return test_cases
+
+
+def rotary_embedding(t, sin, cos, algo, *, out=None):
+    def _torch_rope(sin, cos, t1, t2):
+        cos = cos.unsqueeze(1)  # [seq_len, 1, dh // 2]
+        sin = sin.unsqueeze(1)  # [seq_len, 1, dh // 2]
+        t_out_1 = t1 * cos - t2 * sin
+        t_out_2 = t1 * sin + t2 * cos
+
+        return t_out_1, t_out_2
+
+    ans = t.clone()
+
+    dh = t.shape[-1]
+    dt = t.dtype
+    assert dh % 2 == 0, "Embedding dimension must be even."
+
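+    # Each (t1, t2) pair is rotated by its position angle:
+    #   (t1, t2) -> (t1*cos - t2*sin, t1*sin + t2*cos).
+    # GPT-J pairs interleaved even/odd elements; GPT-NeoX pairs element i with
+    # element i + dh // 2.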
+    if RopeAlgo.GPT_J == algo:
+        t_even = t[..., 0::2]  # [seq_len, n_head, dh // 2]
+        t_odd = t[..., 1::2]  # [seq_len, n_head, dh // 2]
+
+        t_out_even, t_out_odd = _torch_rope(sin, cos, t_even, t_odd)
+
+        ans[..., 0::2] = t_out_even.to(dt)
+        ans[..., 1::2] = t_out_odd.to(dt)
+    elif RopeAlgo.GPT_NEOX == algo:
+        half_dim = dh // 2
+        t_first = t[..., :half_dim]
+        t_second = t[..., half_dim:]
+
+        t_out_first, t_out_second = _torch_rope(sin, cos, t_first, t_second)
+
+        ans[..., :half_dim] = t_out_first.to(dt)
+        ans[..., half_dim:] = t_out_second.to(dt)
+    else:
+        raise ValueError(f"Unsupported RoPE algorithm: {algo}")
+
+    if out is not None:
+        out.copy_(ans)
+        return out
+    return ans
+
+
+class OpTest(BaseOperatorTest):
+    """Rope operator test with simplified implementation"""
+
+    def __init__(self):
+        super().__init__("Rope")
+
+    def get_test_cases(self):
+        return parse_test_cases()
+
+    def torch_operator(self, *args, **kwargs):
+        """PyTorch Rope implementation"""
+
+        return rotary_embedding(*args, **kwargs)
+
+    def infinicore_operator(self, x, sin_table, cos_table, algo, out=None, **kwargs):
+        """InfiniCore Rope implementation"""
+
+        ntok = x.shape[0]
+        torch_device = "cpu"
+        if x.device.type != "cpu":
+            torch_device = "cuda"
+
+        # Build the position-id tensor (0 .. ntok-1) on the same device as x.
+        pos_ids_torch = torch.arange(0, ntok, dtype=torch.int32, device=torch_device)
+        pos_ids_ref = infinicore_tensor_from_torch(pos_ids_torch)
+        pos_ids_infini = infinicore.empty(
+            list(pos_ids_ref.shape), dtype=pos_ids_ref.dtype, device=pos_ids_ref.device
+        )
+        pos_ids_infini.copy_(pos_ids_ref)
+
+        # Run the operator under test.
+        pos_ids = pos_ids_infini
+        return infinicore.nn.functional.rope(
+            x, pos_ids, sin_table, cos_table, algo=algo, out=out
+        )
+
+
+def main():
+    """Main entry point"""
+    runner = GenericTestRunner(OpTest)
+    runner.run_and_exit()
+
+
+if __name__ == "__main__":
+    main()