Commit f452199

Add Adaptive R-KV reference op implementation

1 parent 7de8ff8 commit f452199

1 file changed: +204 -0

// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <cmath>
#include <cstddef>
#include <cstring>
#include <limits>
#include <memory>
#include <queue>
#include <vector>

#include "openvino/core/except.hpp"
#include "openvino/core/shape.hpp"
#include "openvino/reference/matmul.hpp"
#include "openvino/reference/normalize_l2.hpp"
#include "openvino/reference/reduce_mean.hpp"
#include "openvino/reference/slice.hpp"
#include "openvino/runtime/tensor.hpp"

namespace ov::reference {

/** @brief Reference implementation of the block diversity estimation step of the Adaptive R-KV cache
 * eviction mechanism. */
template <typename T>
class AdaptiveRKVDiversityCalculator {
public:
    /** @param start_size Number of tokens at the beginning of the key cache that precede the evictable
     * area. Diversity values are not computed for these tokens. Must be a multiple of `block_size`.
     * @param eviction_size Number of tokens in the evictable area of the key cache for which the
     * block-wise diversity values are computed. Must be a multiple of `block_size`.
     * @param block_size Size, in tokens, of the blocks into which the evictable area is subdivided when
     * aggregating per-token diversity values, i.e. the granularity of the returned block diversity
     * values.
     * */
    AdaptiveRKVDiversityCalculator(size_t start_size, size_t eviction_size, size_t block_size)
        : m_start_size(start_size),
          m_eviction_size(eviction_size),
          m_block_size(block_size) {
        OPENVINO_ASSERT(start_size % block_size == 0);
        OPENVINO_ASSERT(eviction_size % block_size == 0);
    }
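    // For example, with start_size = 32, eviction_size = 128 and block_size = 16 (hypothetical values),
    // diversity is estimated for the 128 evictable key tokens following the first 32 tokens and
    // aggregated into 128 / 16 = 8 blocks.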

    /** Sets the main diagonal of each head's square matrix in a rank-3 tensor to the given value, in
     * place.
     * @param in_out Pointer to the tensor data to be modified in place.
     * @param in_out_shape Shape of the tensor. Expected shape is [num_heads, token_dim, token_dim], i.e.
     * the last two dimensions must be equal.
     * @param val The value to write into each diagonal element.
     */
    void fill_diagonal_(T* in_out, const Shape& in_out_shape, T val) {
        OPENVINO_ASSERT(in_out_shape.size() == 3);            // [num_heads, token_dim, token_dim]
        OPENVINO_ASSERT(in_out_shape[1] == in_out_shape[2]);  // square along the token dimensions

        for (size_t head_idx = 0; head_idx < in_out_shape[0]; head_idx++) {
            size_t in_head_offset = head_idx * in_out_shape[1] * in_out_shape[2];
            for (size_t token_dim_idx = 0; token_dim_idx < in_out_shape[1]; token_dim_idx++) {
                size_t diagonal_element_offset = token_dim_idx * in_out_shape[2] + token_dim_idx;
                auto diagonal_element_ptr = in_out + in_head_offset + diagonal_element_offset;
                *diagonal_element_ptr = val;
            }
        }
    }

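    /** Zeroes out, in place, every element of the similarity tensor that is below the mean value of its
     * row; elements greater than or equal to the row mean are kept unchanged.
     * @param in_out Pointer to the similarity tensor data to be modified in place. Expected shape is
     * [num_heads, token_dim, token_dim].
     * @param in_out_shape Shape of the similarity tensor.
     * @param means Pointer to the per-row mean values. Expected shape is [num_heads, token_dim].
     * @param means_shape Shape of the means tensor.
     */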
    void fill_low_values_with_zeros_(T* in_out, const Shape& in_out_shape, const T* means, const Shape& means_shape) {
        OPENVINO_ASSERT(in_out_shape.size() == 3);  // [num_heads, token_dim, token_dim]
        OPENVINO_ASSERT(in_out_shape[1] == in_out_shape[2]);
        OPENVINO_ASSERT(means_shape.size() == 2);  // [num_heads, token_dim]
        OPENVINO_ASSERT(means_shape[0] == in_out_shape[0]);
        OPENVINO_ASSERT(means_shape[1] == in_out_shape[1]);

        for (size_t head_idx = 0; head_idx < in_out_shape[0]; head_idx++) {
            size_t in_head_offset = head_idx * in_out_shape[1] * in_out_shape[2];
            size_t means_head_offset = head_idx * means_shape[1];
            for (size_t token_dim_idx = 0; token_dim_idx < in_out_shape[1]; token_dim_idx++) {
                T mean_val = means[means_head_offset + token_dim_idx];
                size_t token_offset = token_dim_idx * in_out_shape[2];
                for (size_t reduced_dim_idx = 0; reduced_dim_idx < in_out_shape[2]; reduced_dim_idx++) {
                    size_t target_offset = in_head_offset + token_offset + reduced_dim_idx;
                    T filled_val = in_out[target_offset];
                    in_out[target_offset] = filled_val >= mean_val ? filled_val : static_cast<T>(0);
                }
            }
        }
    }

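    /** Sums the values of a square token-similarity matrix within each group of `block_size` consecutive
     * rows, producing one row of column-wise block sums per group.
     * @param processed_similarity_token_data Pointer to the input data. Expected shape is
     * [token_dim, token_dim], where `token_dim` must be a multiple of `block_size`.
     * @param processed_similarity_token_data_shape Shape of the input data.
     * @param out Pointer to the output block sums. Expected shape is [token_dim / block_size, token_dim].
     * @param out_shape Shape of the output data.
     */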
    void block_sum_diversity_values(const T* processed_similarity_token_data,
                                    const Shape& processed_similarity_token_data_shape,
                                    T* out,
                                    const Shape& out_shape) {
        OPENVINO_ASSERT(processed_similarity_token_data_shape.size() == 2);  // [token_dim, token_dim]
        OPENVINO_ASSERT(processed_similarity_token_data_shape[0] == processed_similarity_token_data_shape[1]);
        OPENVINO_ASSERT(processed_similarity_token_data_shape[0] % m_block_size == 0);

        OPENVINO_ASSERT(out_shape.size() == 2);  // [block_dim, token_dim]
        OPENVINO_ASSERT(out_shape[0] == processed_similarity_token_data_shape[0] / m_block_size);
        OPENVINO_ASSERT(out_shape[1] == processed_similarity_token_data_shape[1]);

        std::memset(out, 0, out_shape[0] * out_shape[1] * sizeof(T));

        for (size_t out_block_dim_idx = 0; out_block_dim_idx < out_shape[0]; out_block_dim_idx++) {
            size_t out_block_offset = out_block_dim_idx * out_shape[1];
            for (size_t out_token_dim_idx = 0; out_token_dim_idx < out_shape[1]; out_token_dim_idx++) {
                size_t in_block_offset = (out_block_dim_idx * m_block_size) * out_shape[1];
                for (size_t in_token_in_block_idx = 0; in_token_in_block_idx < m_block_size; in_token_in_block_idx++) {
                    size_t source_offset = in_block_offset +
                                           in_token_in_block_idx * processed_similarity_token_data_shape[1] +
                                           out_token_dim_idx;
                    out[out_block_offset + out_token_dim_idx] += processed_similarity_token_data[source_offset];
                }
            }
        }
    }

    /** Computes block-wise diversity statistics for the evictable area of the key cache based on pairwise
     * cosine similarities of the key vectors. The keys are L2-normalized, pairwise similarities are
     * computed via K x K^T, the evictable [eviction_size x eviction_size] sub-matrix is extracted, its
     * diagonal is zeroed, values below each row's mean are zeroed, the result is averaged across heads,
     * and the per-token values are summed within each block of `block_size` rows.
     * @param key_data Pointer to the key input tensor data.
     * @param key_shape Shape of the key input tensor data. Expected shape is [num_heads, num_key_tokens,
     * head_size], where `num_key_tokens` must be at least `start_size + eviction_size`.
     * @return A vector of `eviction_size / block_size` entries, one per block of the evictable area, each
     * holding `eviction_size` values: the aggregated similarity between the tokens of that block and each
     * token of the evictable area.
     */
    std::vector<std::vector<T>> calculate_block_diversity(const T* key_data, const Shape& key_shape) {
        OPENVINO_ASSERT(key_shape.size() == 3);  // [num_heads, key_token_len, head_dim]
        OPENVINO_ASSERT(key_shape[1] >= m_start_size + m_eviction_size);

        // L2-normalize the keys along the head dimension so that K x K^T yields cosine similarities
        // (done out-of-place since `key_data` is a const pointer)
        auto normalized_key_buf = allocate_buf(key_shape);
        ov::reference::normalize_l2(key_data, normalized_key_buf.get(), key_shape, {2}, std::numeric_limits<T>::epsilon());

        Shape cos_similar_shape = {key_shape[0], key_shape[1], key_shape[1]};
        auto cos_similar_buf = allocate_buf(cos_similar_shape);
        ov::reference::matmul(normalized_key_buf.get(),
                              normalized_key_buf.get(),
                              cos_similar_buf.get(),
                              key_shape,
                              key_shape,
                              cos_similar_shape,
                              /* transpose_arg0 = */ false,
                              /* transpose_arg1 = */ true);
        normalized_key_buf.reset();

        // Keep only the evictable [eviction_size x eviction_size] part of the similarity matrix,
        // skipping the starting area along both token dimensions
        Shape evictable_subset_shape = {key_shape[0], m_eviction_size, m_eviction_size};
        auto evictable_subset_buf = allocate_buf(evictable_subset_shape);
        int64_t start = static_cast<int64_t>(m_start_size);
        ov::reference::slice(reinterpret_cast<const char*>(cos_similar_buf.get()),
                             cos_similar_shape,
                             reinterpret_cast<char*>(evictable_subset_buf.get()),
                             evictable_subset_shape,
                             sizeof(T),
                             /* starts = */ {start, start},
                             /* steps = */ {1, 1},
                             /* axes = */ {1, 2});
        cos_similar_buf.reset();

        // Self-similarity should not contribute to the diversity estimate
        fill_diagonal_(evictable_subset_buf.get(), evictable_subset_shape, static_cast<T>(0));

        Shape means_shape = {key_shape[0], m_eviction_size};
        auto means_buf = allocate_buf(means_shape);
        ov::reference::reduce_mean(evictable_subset_buf.get(), means_buf.get(), evictable_subset_shape, {2});

        fill_low_values_with_zeros_(evictable_subset_buf.get(), evictable_subset_shape, means_buf.get(), means_shape);

        // Average the processed similarities over the head dimension
        Shape aggregated_token_similarities_shape = {m_eviction_size, m_eviction_size};
        auto aggregated_token_similarities_buf = allocate_buf(aggregated_token_similarities_shape);
        ov::reference::reduce_mean(evictable_subset_buf.get(),
                                   aggregated_token_similarities_buf.get(),
                                   evictable_subset_shape,
                                   {0});
        evictable_subset_buf.reset();

        Shape block_diversity_shape = {m_eviction_size / m_block_size, m_eviction_size};
        auto block_diversity_buf = allocate_buf(block_diversity_shape);
        block_sum_diversity_values(aggregated_token_similarities_buf.get(),
                                   aggregated_token_similarities_shape,
                                   block_diversity_buf.get(),
                                   block_diversity_shape);

        std::vector<std::vector<T>> retval(block_diversity_shape[0], std::vector<T>(block_diversity_shape[1]));
        for (size_t block_idx = 0; block_idx < block_diversity_shape[0]; block_idx++) {
            for (size_t token_idx = 0; token_idx < block_diversity_shape[1]; token_idx++) {
                retval[block_idx][token_idx] = block_diversity_buf.get()[block_idx * block_diversity_shape[1] + token_idx];
            }
        }

        return retval;
    }

    /**
     * @param shape Shape of a tensor.
     * @return A shared_ptr owning a buffer that can be used to store tensor data for the given shape.
     * */
    std::shared_ptr<T[]> allocate_buf(const Shape& shape) {
        return std::shared_ptr<T[]>(new T[ov::shape_size(shape)]);
    }

    size_t m_start_size;
    size_t m_eviction_size;
    size_t m_block_size;
};

} // namespace ov::reference
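
// Illustrative usage sketch (values are hypothetical): computing per-block diversity values for a
// single-head key tensor of 160 tokens with a 32-token non-evictable starting area, a 128-token
// evictable area and 16-token blocks. Assumes `keys` points to 160 * 64 floats.
//
//   ov::reference::AdaptiveRKVDiversityCalculator<float> calc(/* start_size = */ 32,
//                                                             /* eviction_size = */ 128,
//                                                             /* block_size = */ 16);
//   ov::Shape key_shape{1, 160, 64};
//   auto block_diversity = calc.calculate_block_diversity(keys, key_shape);
//   // block_diversity has 128 / 16 = 8 entries, each holding 128 per-token values.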
