12 changes: 12 additions & 0 deletions include/infinicore/ops/attention.hpp
@@ -2,6 +2,7 @@

#include "../device.hpp"
#include "common/op.hpp"
#include <optional>

namespace infinicore::op {
class Attention {
@@ -13,4 +14,15 @@ class Attention {

Tensor attention(Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos);
void attention_(Tensor out, Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor v_cache, size_t pos);

Tensor scaled_dot_product_attention(Tensor query,
Tensor key,
Tensor value,
std::optional<float> scale);

void scaled_dot_product_attention_(Tensor out,
Tensor query,
Tensor key,
Tensor value,
std::optional<float> scale);
} // namespace infinicore::op
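
For reference, the entry point added here computes standard scaled dot-product attention with a causal softmax; per the implementation later in this PR, when scale is std::nullopt it falls back to the usual 1/sqrt(head_dim):

\mathrm{Attention}(Q, K, V) = \mathrm{causal\_softmax}\left(\mathrm{scale} \cdot Q K^{\top}\right) V, \qquad \mathrm{scale} = \frac{1}{\sqrt{d_{\mathrm{head}}}} \ \text{when not provided}
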
6 changes: 3 additions & 3 deletions python/infinicore/nn/__init__.py
@@ -1,3 +1,3 @@
from infinicore.nn import (
functional as functional,
)
from infinicore.nn import functional

__all__ = ["functional"]
67 changes: 0 additions & 67 deletions python/infinicore/nn/functional.py

This file was deleted.

3 changes: 3 additions & 0 deletions python/infinicore/nn/functional/__init__.py
@@ -0,0 +1,3 @@
from .scaled_dot_product_attention import scaled_dot_product_attention

__all__ = ["scaled_dot_product_attention"]
48 changes: 48 additions & 0 deletions python/infinicore/nn/functional/scaled_dot_product_attention.py
@@ -0,0 +1,48 @@
from typing import Optional

from infinicore.lib import _infinicore
from infinicore.tensor import Tensor

__all__ = ["scaled_dot_product_attention"]


def scaled_dot_product_attention(
query: Tensor,
key: Tensor,
value: Tensor,
attn_mask: Optional[Tensor] = None,
dropout_p: float = 0.0,
is_causal: bool = False,
scale: Optional[float] = None,
enable_gqa: bool = False,
*,
out=None,
) -> Tensor:
r"""Computes scaled dot product attention on query, key and value tensors."""

assert (attn_mask is None) and (0.0 == dropout_p), (
"attn_mask and non-zero dropout_p are not supported."
)
assert (enable_gqa is True) and (is_causal is True), (
"only causal attention with enable_gqa=True is supported."
)

ntoken = query.shape[-2]
total_token = key.shape[-2]

assert (1 == ntoken and total_token > 1) or (ntoken == total_token), (
"only single-token decode (ntoken == 1) or full prefill (ntoken == total_token) is supported."
)

if out is None:
return Tensor(
_infinicore.scaled_dot_product_attention(
query._underlying, key._underlying, value._underlying, scale
)
)

_infinicore.scaled_dot_product_attention_(
out._underlying,
query._underlying,
key._underlying,
value._underlying,
scale,
)

return out
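
A minimal usage sketch of the new functional API (not part of this PR; make_tensor is a hypothetical placeholder, since tensor construction helpers are not shown in this diff):

from infinicore.nn import functional as F

# Assumed shapes: query [bs, num_heads, ntoken, head_dim],
# key/value [bs, num_kv_heads, total_token, head_dim].
query = make_tensor((1, 8, 16, 64))   # hypothetical factory, stands in for the real infinicore API
key = make_tensor((1, 2, 16, 64))     # hypothetical
value = make_tensor((1, 2, 16, 64))   # hypothetical

# Only the causal, GQA-enabled path passes the asserts above;
# scale defaults to 1/sqrt(head_dim) inside the C++ implementation.
out = F.scaled_dot_product_attention(query, key, value, is_causal=True, enable_gqa=True)
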
89 changes: 88 additions & 1 deletion src/infinicore/ops/attention/attention.cc
@@ -1,5 +1,7 @@
#include "infinicore/ops/attention.hpp"

#include "infinicore/ops/causal_softmax.hpp"
#include "infinicore/ops/gemm.hpp"
#include <cassert>
#include <cmath>

namespace infinicore::op {

common::OpDispatcher<Attention::schema> &Attention::dispatcher() {
@@ -25,4 +27,89 @@ void attention_(Tensor out, Tensor q, Tensor k, Tensor v, Tensor k_cache, Tensor
Attention::execute(out, q, k, v, k_cache, v_cache, pos);
}

Tensor scaled_dot_product_attention(Tensor query_states, // [bs, num_attention_heads, ntoken, head_dim]
Tensor key_states, // [bs, num_key_value_heads, total_token, head_dim]
Tensor value_states, // [bs, num_key_value_heads, total_token, head_dim]
std::optional<float> scale) {

auto query_shape = query_states->shape();
auto key_shape = key_states->shape();

Size batch_size = query_shape[0];
Size num_attention_heads = query_shape[1];
Size ntoken = query_shape[2];
Size head_dim = key_shape[3];

Tensor output_values = Tensor::empty({batch_size, num_attention_heads, ntoken, head_dim}, query_states->dtype(), query_states->device());

scaled_dot_product_attention_(output_values, query_states, key_states, value_states, scale);

return output_values;
}

void scaled_dot_product_attention_(Tensor out,
Tensor query_states,
Tensor key_states,
Tensor value_states,
std::optional<float> scale) {

auto query_shape = query_states->shape();
auto key_shape = key_states->shape();

Size batch_size = query_shape[0];
Size num_attention_heads = query_shape[1];
Size ntoken = query_shape[2];

Size num_key_value_heads = key_shape[1];
Size total_token = key_shape[2];
Size head_dim = key_shape[3];

assert(0 == (num_attention_heads % num_key_value_heads));
Size ngroup = num_attention_heads / num_key_value_heads;

float attention_scale = scale.value_or(1.f / std::sqrt(static_cast<float>(head_dim)));

Tensor out_view = out->view({batch_size, num_key_value_heads, ngroup * ntoken, head_dim});
for (Size ib = 0; ib < batch_size; ++ib) {
Tensor q = query_states->narrow({{0, ib, 1}})->view({num_attention_heads, ntoken, head_dim}); // [ num_attention_heads, ntoken, head_dim]
Tensor k = key_states->narrow({{0, ib, 1}})->view({num_key_value_heads, total_token, head_dim}); // [ num_key_value_heads, total_token, head_dim]
Tensor v = value_states->narrow({{0, ib, 1}})->view({num_key_value_heads, total_token, head_dim}); // [ num_key_value_heads, total_token, head_dim]
Tensor output_v = out_view->narrow({{0, ib, 1}})->view({num_key_value_heads, ngroup * ntoken, head_dim});
{
/*
Inputs:
q, [num_attention_heads, ntoken, head_dim]
k, [num_key_value_heads, total_token, head_dim]
v, [num_key_value_heads, total_token, head_dim]
Output:
att_val: [num_key_value_heads, ngroup * ntoken, head_dim]
*/

auto q_gemm = q->view({num_key_value_heads, ngroup * ntoken, head_dim}); // => {nkvh, ngroup * ntoken, dh}
auto k_gemm = k->permute({0, 2, 1}); // => { nkvh, dh, total_token}
auto v_gemm = v; // => { nkvh, total_token, dh}

// qk_score : => {nkvh, ngroup * ntoken, total_token}
Tensor qk_score = gemm(q_gemm, // {nkvh, ngroup * ntoken, dh}
k_gemm, // {nkvh, dh, total_token}
attention_scale, 0.f);

// causal softmax, applied per attention head via a shared-storage view

auto qk_softmax = qk_score->view({num_attention_heads, ntoken, total_token});
causal_softmax_(qk_softmax, qk_softmax);

// values
gemm_(output_v, // {nkvh, ngroup * ntoken, dh}
qk_score, // {nkvh, ngroup * ntoken, total_token}
v_gemm, // { nkvh, total_token, dh}
1.0f, 0.0f);
}
}
}
} // namespace infinicore::op
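
For clarity, a NumPy sketch (not part of the PR) of what the per-batch loop above computes for one batch element; shapes follow the comments in the C++ code, and the causal-mask semantics (query rows are the last ntoken of the total_token positions) are an assumption about causal_softmax_:

import numpy as np

def sdpa_reference(q, k, v, scale=None):
    # q: [num_heads, ntoken, head_dim]; k, v: [num_kv_heads, total_token, head_dim]
    nh, ntoken, dh = q.shape
    nkvh, total_token, _ = k.shape
    ngroup = nh // nkvh
    if scale is None:
        scale = 1.0 / np.sqrt(dh)

    # group query heads per kv head, mirroring the {nkvh, ngroup * ntoken, dh} view
    q_g = q.reshape(nkvh, ngroup * ntoken, dh)
    score = scale * (q_g @ k.transpose(0, 2, 1))             # [nkvh, ngroup*ntoken, total_token]

    # causal softmax applied per attention head on a [nh, ntoken, total_token] view
    s = score.reshape(nh, ntoken, total_token)
    mask = np.tril(np.ones((ntoken, total_token), dtype=bool), k=total_token - ntoken)
    s = np.where(mask, s, -np.inf)
    s = np.exp(s - s.max(axis=-1, keepdims=True))
    s = s / s.sum(axis=-1, keepdims=True)

    # weighted sum of values, then back to per-head layout
    out = s.reshape(nkvh, ngroup * ntoken, total_token) @ v  # [nkvh, ngroup*ntoken, dh]
    return out.reshape(nh, ntoken, dh)
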
43 changes: 43 additions & 0 deletions src/infinicore/pybind11/ops/attention.hpp
@@ -6,6 +6,32 @@

namespace py = pybind11;

namespace infinicore::ops::attention {
Tensor py_scaled_dot_product_attention(Tensor query,
Tensor key,
Tensor value,
pybind11::object scale) {
std::optional<float> scale_float = std::nullopt;
if (!scale.is_none()) {
scale_float = scale.cast<float>();
}
return op::scaled_dot_product_attention(query, key, value, scale_float);
}

void py_scaled_dot_product_attention_(Tensor out,
Tensor query,
Tensor key,
Tensor value,
pybind11::object scale) {
std::optional<float> scale_float = std::nullopt;
if (!scale.is_none()) {
scale_float = scale.cast<float>();
}
op::scaled_dot_product_attention_(out, query, key, value, scale_float);
}

} // namespace infinicore::ops::attention

namespace infinicore::ops {

inline void bind_attention(py::module &m) {
@@ -51,6 +77,23 @@ inline void bind_attention(py::module &m) {
v_cache: Value cache tensor
pos: Current position in the sequence
)doc");

m.def("scaled_dot_product_attention",
&attention::py_scaled_dot_product_attention,
py::arg("query"),
py::arg("key"),
py::arg("value"),
py::arg("scale") = py::none(),
R"doc(Computes scaled dot product attention on query, key and value tensors)doc");

m.def("scaled_dot_product_attention_",
&attention::py_scaled_dot_product_attention_,
py::arg("out"),
py::arg("query"),
py::arg("key"),
py::arg("value"),
py::arg("scale") = py::none(),
R"doc(In-place, Computes scaled dot product attention on query, key and value tensors)doc");
}

} // namespace infinicore::ops