From 0609cd9c4c38d04a3fca0cdad78abb6a02f12d64 Mon Sep 17 00:00:00 2001
From: Your Name
Date: Sun, 16 Nov 2025 22:31:14 +0800
Subject: [PATCH] issue/586 - Add Python implementation and tests for scaled_dot_product_attention
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 include/infinicore/ops/attention.hpp          |  12 ++
 python/infinicore/nn/__init__.py              |   6 +-
 python/infinicore/nn/functional.py            |  67 ---------
 python/infinicore/nn/functional/__init__.py   |   3 +
 .../scaled_dot_product_attention.py           |  48 ++++++
 src/infinicore/ops/attention/attention.cc     |  89 ++++++++++-
 src/infinicore/pybind11/ops/attention.hpp     |  43 ++++++
 .../ops/scaled_dot_product_attention.py       | 138 ++++++++++++++++++
 8 files changed, 335 insertions(+), 71 deletions(-)
 delete mode 100644 python/infinicore/nn/functional.py
 create mode 100644 python/infinicore/nn/functional/__init__.py
 create mode 100644 python/infinicore/nn/functional/scaled_dot_product_attention.py
 create mode 100644 test/infinicore/ops/scaled_dot_product_attention.py

diff --git a/include/infinicore/ops/attention.hpp b/include/infinicore/ops/attention.hpp
index 1bc447c77..4b19f3e6b 100644
--- a/include/infinicore/ops/attention.hpp
+++ b/include/infinicore/ops/attention.hpp
@@ -2,6 +2,7 @@
 
 #include "../device.hpp"
 #include "common/op.hpp"
+#include <optional>
 
 namespace infinicore::op {
 class Attention {
@@ -13,4 +14,15 @@ class Attention {
 
 Tensor attention(Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos);
 void attention_(Tensor out, Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos);
+
+Tensor scaled_dot_product_attention(Tensor query,
+                                    Tensor key,
+                                    Tensor value,
+                                    std::optional<float> scale);
+
+void scaled_dot_product_attention_(Tensor out,
+                                   Tensor query,
+                                   Tensor key,
+                                   Tensor value,
+                                   std::optional<float> scale);
 } // namespace infinicore::op
diff --git a/python/infinicore/nn/__init__.py b/python/infinicore/nn/__init__.py
index 9a091e628..e08b88af4 100644
--- a/python/infinicore/nn/__init__.py
+++ b/python/infinicore/nn/__init__.py
@@ -1,3 +1,3 @@
-from infinicore.nn import (
-    functional as functional,
-)
+from infinicore.nn import functional
+
+__all__ = ["functional"]
diff --git a/python/infinicore/nn/functional.py b/python/infinicore/nn/functional.py
deleted file mode 100644
index ea969052c..000000000
--- a/python/infinicore/nn/functional.py
+++ /dev/null
@@ -1,67 +0,0 @@
-import infinicore
-from infinicore.lib import _infinicore
-from infinicore.tensor import Tensor
-
-__all__ = ["causal_softmax", "rms_norm", "silu", "swiglu"]
-
-
-def causal_softmax(input: Tensor, out=None) -> Tensor:
-    r"""Apply a causal softmax function."""
-
-    if out is None:
-        return Tensor(_infinicore.causal_softmax(input._underlying))
-
-    _infinicore.causal_softmax_(out._underlying, input._underlying)
-
-    return out
-
-
-def rms_norm(
-    input: Tensor,
-    normalized_shape: list[int],
-    weight: Tensor,
-    eps: float = 1e-5,
-    *,
-    out=None,
-) -> Tensor:
-    r"""Apply Root Mean Square Layer Normalization."""
-
-    assert normalized_shape == weight.shape, (
-        "normalized_shape does not match weight.shape."
-    )
-
-    if out is None:
-        return Tensor(_infinicore.rms_norm(input._underlying, weight._underlying, eps))
-
-    _infinicore.rms_norm_(out._underlying, input._underlying, weight._underlying, eps)
-
-    return out
-
-
-def silu(input: Tensor, inplace: bool = False, *, out=None) -> Tensor:
-    r"""Apply the Sigmoid Linear Unit (SiLU) function, element-wise."""
-
-    if infinicore.use_ntops and input.device.type in ("cuda", "musa") and out is None:
-        return infinicore.ntops.torch.silu(input, inplace=inplace)
-
-    if inplace:
-        _infinicore.silu_(input._underlying, input._underlying)
-        return input
-
-    if out is None:
-        return Tensor(_infinicore.silu(input._underlying))
-
-    _infinicore.silu_(out._underlying, input._underlying)
-
-    return out
-
-
-def swiglu(input: Tensor, other: Tensor, *, out=None):
-    r"""Apply the Swish-Gated Linear Unit (SwiGLU) function, element-wise."""
-
-    if out is None:
-        return Tensor(_infinicore.swiglu(input._underlying, other._underlying))
-
-    _infinicore.swiglu_(out._underlying, input._underlying, other._underlying)
-
-    return out
diff --git a/python/infinicore/nn/functional/__init__.py b/python/infinicore/nn/functional/__init__.py
new file mode 100644
index 000000000..e76eb8bb5
--- /dev/null
+++ b/python/infinicore/nn/functional/__init__.py
@@ -0,0 +1,3 @@
+from .scaled_dot_product_attention import scaled_dot_product_attention
+
+__all__ = ["scaled_dot_product_attention"]
diff --git a/python/infinicore/nn/functional/scaled_dot_product_attention.py b/python/infinicore/nn/functional/scaled_dot_product_attention.py
new file mode 100644
index 000000000..7aa0172d0
--- /dev/null
+++ b/python/infinicore/nn/functional/scaled_dot_product_attention.py
@@ -0,0 +1,48 @@
+from typing import Optional
+
+from infinicore.lib import _infinicore
+from infinicore.tensor import Tensor
+
+__all__ = ["scaled_dot_product_attention"]
+
+
+def scaled_dot_product_attention(
+    query: Tensor,
+    key: Tensor,
+    value: Tensor,
+    attn_mask: Optional[Tensor] = None,
+    dropout_p: float = 0.0,
+    is_causal: bool = False,
+    scale: Optional[float] = None,
+    enable_gqa: bool = False,
+    *,
+    out=None,
+) -> Tensor:
+    r"""Computes scaled dot product attention on query, key and value tensors."""
+
+    assert (attn_mask is None) and (0.0 == dropout_p), "attn_mask and dropout_p are not supported."
+    assert (enable_gqa is True) and (is_causal is True), "only is_causal=True with enable_gqa=True is supported."
+
+    ntoken = query.shape[-2]
+    total_token = key.shape[-2]
+
+    assert (1 == ntoken and total_token > 1) or (ntoken == total_token), (
+        "query length must equal the key length, or be 1 for single-token decode."
+ ) + + if out is None: + return Tensor( + _infinicore.scaled_dot_product_attention( + query._underlying, key._underlying, value._underlying, scale + ) + ) + + _infinicore.scaled_dot_product_attention_( + out._underlying, + query._underlying, + key._underlying, + value._underlying, + scale, + ) + + return out diff --git a/src/infinicore/ops/attention/attention.cc b/src/infinicore/ops/attention/attention.cc index bf4fd8203..5ff97b6b7 100644 --- a/src/infinicore/ops/attention/attention.cc +++ b/src/infinicore/ops/attention/attention.cc @@ -1,5 +1,7 @@ #include "infinicore/ops/attention.hpp" - +#include "infinicore/ops/causal_softmax.hpp" +#include "infinicore/ops/gemm.hpp" +#include namespace infinicore::op { common::OpDispatcher &Attention::dispatcher() { @@ -25,4 +27,89 @@ void attention_(Tensor out, Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor Attention::execute(out, q, k, v, k_cache, v_cache, pos); } +Tensor scaled_dot_product_attention(Tensor query_states, // [bs, num_attention_heads, ntoken, head_dim] + Tensor key_states, // [bs, num_key_value_heads, total_token, head_dim] + Tensor value_states, // [bs, num_key_value_heads, total_token, head_dim] + std::optional scale) { + + auto query_shape = query_states->shape(); + auto key_shape = key_states->shape(); + + Size batch_size = query_shape[0]; + Size num_attention_heads = query_shape[1]; + Size ntoken = query_shape[2]; + Size head_dim = key_shape[3]; + + Tensor output_values = Tensor::empty({batch_size, num_attention_heads, ntoken, head_dim}, query_states->dtype(), query_states->device()); + + scaled_dot_product_attention_(output_values, query_states, key_states, value_states, scale); + + return output_values; +} + +void scaled_dot_product_attention_(Tensor out, + Tensor query_states, + Tensor key_states, + Tensor value_states, + std::optional scale) { + + auto query_shape = query_states->shape(); + auto key_shape = key_states->shape(); + + Size batch_size = query_shape[0]; + Size num_attention_heads = query_shape[1]; + Size ntoken = query_shape[2]; + + Size num_key_value_heads = key_shape[1]; + Size total_token = key_shape[2]; + Size head_dim = key_shape[3]; + + assert(0 == (num_attention_heads % num_key_value_heads)); + Size ngroup = num_attention_heads / num_key_value_heads; + + float attention_scale{0.0f}; + if (scale.has_value()) { + attention_scale = scale.value(); + } else { + attention_scale = 1.f / float(sqrt(head_dim)); + } + + Tensor out_view = out->view({batch_size, num_key_value_heads, ngroup * ntoken, head_dim}); + for (Size ib = 0; ib < batch_size; ++ib) { + Tensor q = query_states->narrow({{0, ib, 1}})->view({num_attention_heads, ntoken, head_dim}); // [ num_attention_heads, ntoken, head_dim] + Tensor k = key_states->narrow({{0, ib, 1}})->view({num_key_value_heads, total_token, head_dim}); // [ num_key_value_heads, total_token, head_dim] + Tensor v = value_states->narrow({{0, ib, 1}})->view({num_key_value_heads, total_token, head_dim}); // [ num_key_value_heads, total_token, head_dim] + Tensor output_v = out_view->narrow({{0, ib, 1}})->view({num_key_value_heads, ngroup * ntoken, head_dim}); + { + /* + 输入: + q, [ num_attention_heads, ntoken, head_dim] + k, [ num_key_value_heads, total_token, head_dim] + v, [ num_key_value_heads, total_token, head_dim] + 输出: + att_val : {num_key_value_heads, ngroup * ntok, head_dim} + */ + + auto q_gemm = q->view({num_key_value_heads, ngroup * ntoken, head_dim}); // => {nkvh, ngroup * seq_len, dh} + auto k_gemm = k->permute({0, 2, 1}); // => { nkvh, dh, total_token} + auto v_gemm = 
+
+            // qk_score : => {nkvh, ngroup * ntoken, total_token}
+            Tensor qk_score = gemm(q_gemm, // {nkvh, ngroup * ntoken, dh}
+                                   k_gemm, // {nkvh, dh, total_token}
+                                   attention_scale, 0.f);
+
+            // softmax
+
+            auto qk_softmax = qk_score->view({num_attention_heads, ntoken, total_token});
+            causal_softmax_(qk_softmax, qk_softmax);
+
+            // values
+            gemm_(output_v, // {nkvh, ngroup * ntoken, dh}
+                  qk_score, // {nkvh, ngroup * ntoken, total_token}
+                  v_gemm,   // {nkvh, total_token, dh}
+                  1.0f, 0.0f);
+        }
+    }
+}
 } // namespace infinicore::op
diff --git a/src/infinicore/pybind11/ops/attention.hpp b/src/infinicore/pybind11/ops/attention.hpp
index 4af2d5f74..41373d507 100644
--- a/src/infinicore/pybind11/ops/attention.hpp
+++ b/src/infinicore/pybind11/ops/attention.hpp
@@ -6,6 +6,32 @@
 
 namespace py = pybind11;
 
+namespace infinicore::ops::attention {
+Tensor py_scaled_dot_product_attention(Tensor query,
+                                       Tensor key,
+                                       Tensor value,
+                                       pybind11::object scale) {
+    std::optional<float> scale_float = std::nullopt;
+    if (!scale.is_none()) {
+        scale_float = scale.cast<float>();
+    }
+    return op::scaled_dot_product_attention(query, key, value, scale_float);
+}
+
+void py_scaled_dot_product_attention_(Tensor out,
+                                      Tensor query,
+                                      Tensor key,
+                                      Tensor value,
+                                      pybind11::object scale) {
+    std::optional<float> scale_float = std::nullopt;
+    if (!scale.is_none()) {
+        scale_float = scale.cast<float>();
+    }
+    op::scaled_dot_product_attention_(out, query, key, value, scale_float);
+}
+
+} // namespace infinicore::ops::attention
+
 namespace infinicore::ops {
 
 inline void bind_attention(py::module &m) {
@@ -51,6 +77,23 @@ inline void bind_attention(py::module &m) {
             v_cache: Value cache tensor
             pos: Current position in the sequence
       )doc");
+
+    m.def("scaled_dot_product_attention",
+          &attention::py_scaled_dot_product_attention,
+          py::arg("query"),
+          py::arg("key"),
+          py::arg("value"),
+          py::arg("scale") = py::none(),
+          R"doc(Computes scaled dot product attention on query, key and value tensors)doc");
+
+    m.def("scaled_dot_product_attention_",
+          &attention::py_scaled_dot_product_attention_,
+          py::arg("out"),
+          py::arg("query"),
+          py::arg("key"),
+          py::arg("value"),
+          py::arg("scale") = py::none(),
+          R"doc(In-place variant: computes scaled dot product attention on query, key and value tensors)doc");
 }
 
 } // namespace infinicore::ops
diff --git a/test/infinicore/ops/scaled_dot_product_attention.py b/test/infinicore/ops/scaled_dot_product_attention.py
new file mode 100644
index 000000000..703daca09
--- /dev/null
+++ b/test/infinicore/ops/scaled_dot_product_attention.py
@@ -0,0 +1,138 @@
+import sys
+import os
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
+
+import torch
+import infinicore
+from framework.base import BaseOperatorTest, TensorSpec, TestCase
+from framework.runner import GenericTestRunner
+from framework.utils import is_broadcast
+
+
+# ==============================================================================
+# Operator-specific configuration
+# ==============================================================================
+_TEST_CASES_DATA = [
+    # bs, ntoken, total_token, num_attention_heads, num_key_value_heads, head_dim
+    (1, 4, 4, 8, 8, 64),
+    (1, 1, 4, 8, 8, 64),
+    (4, 16, 16, 32, 8, 64),
+    (4, 1, 128, 32, 8, 64),
+]
+
+
+# Tolerance configuration
+_TOLERANCE_MAP = {
+    infinicore.float16: {"atol": 1e-2, "rtol": 1e-2},
+    infinicore.float32: {"atol": 1e-2, "rtol": 1e-2},
+    infinicore.bfloat16: {"atol": 5e-2, "rtol": 5e-2},
+}
+
+
+# Data types to test
+_TENSOR_DTYPES = [infinicore.float16, infinicore.bfloat16, infinicore.float32]
+# _TENSOR_DTYPES = [infinicore.bfloat16]
+
+
+def parse_test_cases():
+    """
+    Parse test case data and return list of TestCase objects for sdpa operation.
+    Each test case contains all necessary information for execution and validation.
+    """
+    test_cases = []
+
+    for data in _TEST_CASES_DATA:
+        bs = data[0]
+        ntoken, total_token = data[1], data[2]
+        num_attention_heads, num_key_value_heads = data[3], data[4]
+        head_dim = data[5]
+
+        # Determine shapes based on batch dimension
+        query_shape = (bs, num_attention_heads, ntoken, head_dim)
+        key_shape = (bs, num_key_value_heads, total_token, head_dim)
+        value_shape = (bs, num_key_value_heads, total_token, head_dim)
+        out_shape = (bs, num_attention_heads, ntoken, head_dim)
+
+        # Check if tensors support in-place operations
+        c_supports_inplace = not is_broadcast(out_shape)
+
+        # Generate test cases for all data types
+        for dtype in _TENSOR_DTYPES:
+            tolerance = _TOLERANCE_MAP.get(dtype, {"atol": 0, "rtol": 1e-3})
+
+            # Create typed tensor specs
+            query_spec = TensorSpec.from_tensor(query_shape, None, dtype)
+            key_spec = TensorSpec.from_tensor(key_shape, None, dtype)
+            value_spec = TensorSpec.from_tensor(value_shape, None, dtype)
+            out_spec = TensorSpec.from_tensor(out_shape, None, dtype)
+
+            # Test Case 1: Out-of-place (return value)
+            test_cases.append(
+                TestCase(
+                    inputs=[query_spec, key_spec, value_spec],
+                    kwargs={},
+                    output_spec=None,
+                    comparison_target=None,
+                    tolerance=tolerance,
+                    description="sdpa - OUT_OF_PLACE",
+                )
+            )
+
+            # Test Case 2: In-place with explicit output tensor
+            if c_supports_inplace:
+                test_cases.append(
+                    TestCase(
+                        inputs=[query_spec, key_spec, value_spec],
+                        kwargs=None,
+                        output_spec=out_spec,  # Specify the output tensor spec
+                        comparison_target="out",
+                        tolerance=tolerance,
+                        description="sdpa - INPLACE(out)",
+                    )
+                )
+
+    return test_cases
+
+
+class OpTest(BaseOperatorTest):
+    """sdpa operator test with simplified implementation"""
+
+    def __init__(self):
+        super().__init__("sdpa")
+
+    def get_test_cases(self):
+        return parse_test_cases()
+
+    def torch_operator(self, query, key, value, out=None, **kwargs):
+        """PyTorch sdpa implementation"""
+        ntoken = query.shape[-2]
+        total_token = key.shape[-2]
+
+        is_causal = True
+        if 1 == ntoken and total_token > 1:
+            is_causal = False
+
+        result = torch.nn.functional.scaled_dot_product_attention(
+            query, key, value, is_causal=is_causal, enable_gqa=True
+        )
+        if out is not None:
+            out.copy_(result)
+            return out
+        return result
+
+    def infinicore_operator(self, query, key, value, out=None, **kwargs):
+        """InfiniCore sdpa implementation"""
+        return infinicore.nn.functional.scaled_dot_product_attention(
+            query, key, value, is_causal=True, enable_gqa=True, out=out
+        )
+
+
+def main():
+    """Main entry point"""
+    runner = GenericTestRunner(OpTest)
+    runner.run_and_exit()
+
+
+if __name__ == "__main__":
+    main()
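
Reviewer note (not part of the patch): the new scaled_dot_product_attention_ in attention.cc computes GQA attention by viewing the query as {num_key_value_heads, ngroup * ntoken, head_dim}, folding the scale into the Q.K^T GEMM, applying a causal softmax, and multiplying by V. The following self-contained PyTorch sketch mirrors that layout reasoning and cross-checks it against torch.nn.functional.scaled_dot_product_attention (PyTorch >= 2.5 for enable_gqa). The shapes below are illustrative assumptions, not values taken from this change.

import math

import torch

# Assumed GQA shapes for the sketch: 1 batch, 32 query heads, 8 KV heads,
# 16 tokens, head_dim 64 (prefill case: query length == key length).
bs, nh, nkvh, ntoken, head_dim = 1, 32, 8, 16, 64
ngroup = nh // nkvh

q = torch.randn(bs, nh, ntoken, head_dim)
k = torch.randn(bs, nkvh, ntoken, head_dim)
v = torch.randn(bs, nkvh, ntoken, head_dim)

# Grouped layout used by attention.cc: fold the query heads that share a KV head
# into one GEMM batch, so K/V need not be replicated.
q_grouped = q.view(bs, nkvh, ngroup * ntoken, head_dim)
scale = 1.0 / math.sqrt(head_dim)
scores = torch.matmul(q_grouped, k.transpose(-1, -2)) * scale  # [bs, nkvh, ngroup*ntoken, ntoken]

# Causal masking per query position (the C++ code delegates this to causal_softmax_).
causal = torch.tril(torch.ones(ntoken, ntoken, dtype=torch.bool))
scores = scores.view(bs, nh, ntoken, ntoken).masked_fill(~causal, float("-inf"))
probs = torch.softmax(scores, dim=-1).view(bs, nkvh, ngroup * ntoken, ntoken)

out = torch.matmul(probs, v).view(bs, nh, ntoken, head_dim)

# Cross-check against PyTorch's built-in GQA path.
ref = torch.nn.functional.scaled_dot_product_attention(
    q, k, v, is_causal=True, enable_gqa=True
)
print(torch.allclose(out, ref, atol=1e-4))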