// Copyright (C) 2018-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <cmath>
#include <cstddef>
#include <cstring>
#include <limits>
#include <memory>
#include <queue>
#include <vector>

#include "openvino/core/except.hpp"
#include "openvino/core/shape.hpp"
#include "openvino/reference/matmul.hpp"
#include "openvino/reference/normalize_l2.hpp"
#include "openvino/reference/reduce_mean.hpp"
#include "openvino/reference/slice.hpp"
#include "openvino/runtime/tensor.hpp"

namespace ov::reference {

/** @brief Reference implementation of the block-wise key diversity calculation used by the adaptive R-KV cache
 * eviction mechanism. Diversity is estimated from the pairwise cosine similarity of the cached key vectors within
 * the evictable area of the KV cache. */
template <typename T>
class AdaptiveRKVDiversityCalculator {
public:
    /** @param start_size Number of tokens at the beginning of the KV cache that are never evicted. The evictable
     * area of the cache starts right after these tokens. Must be a multiple of `block_size`.
     * @param eviction_size Number of tokens in the evictable area of the KV cache for which the diversity values
     * are computed. Must be a multiple of `block_size`.
     * @param block_size The size of the token blocks into which the evictable area is subdivided. Per-token
     * similarity values are aggregated (summed) into blocks of this size along the token dimension, so `block_size`
     * defines the granularity of the eventual eviction decisions.
     * */
    AdaptiveRKVDiversityCalculator(size_t start_size, size_t eviction_size, size_t block_size)
        : m_start_size(start_size),
          m_eviction_size(eviction_size),
          m_block_size(block_size) {
        OPENVINO_ASSERT(start_size % block_size == 0, "start_size must be a multiple of block_size");
        OPENVINO_ASSERT(eviction_size % block_size == 0, "eviction_size must be a multiple of block_size");
    }
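
    // Note: both start_size and eviction_size are expressed in tokens. For example (hypothetical values),
    // start_size = 32 and eviction_size = 128 are valid with block_size = 16, while eviction_size = 100 would
    // trip the assert above since 100 is not a multiple of 16.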

    /** Fills, for each head, the main diagonal of the corresponding square similarity matrix with the given value,
     * in-place.
     * @param in_out Pointer to the tensor data to be modified.
     * @param in_out_shape Shape of the tensor to be modified. Expected shape is [num_heads, token_dim, token_dim],
     * i.e. the last two dimensions must be equal.
     * @param val The value to write into the diagonal elements.
     */
    void fill_diagonal_(T* in_out, const Shape& in_out_shape, T val) {
        OPENVINO_ASSERT(in_out_shape.size() == 3);            // [num_heads, token_dim, token_dim]
        OPENVINO_ASSERT(in_out_shape[1] == in_out_shape[2]);  // the diagonal is only defined for square matrices

        for (size_t head_idx = 0; head_idx < in_out_shape[0]; head_idx++) {
            size_t in_head_offset = head_idx * in_out_shape[1] * in_out_shape[2];
            for (size_t token_dim_idx = 0; token_dim_idx < in_out_shape[1]; token_dim_idx++) {
                size_t diagonal_element_offset = token_dim_idx * in_out_shape[2] + token_dim_idx;
                in_out[in_head_offset + diagonal_element_offset] = val;
            }
        }
    }
| 78 | + |
    /** Zeroes out, in-place and for each head, the elements of each row that are below that row's mean value.
     * Elements greater than or equal to the mean are kept as-is.
     * @param in_out Pointer to the tensor data to be modified. Expected shape is [num_heads, token_dim, token_dim].
     * @param in_out_shape Shape of the tensor to be modified.
     * @param means Pointer to the per-row mean values. Expected shape is [num_heads, token_dim].
     * @param means_shape Shape of the mean value tensor.
     */
    void fill_low_values_with_zeros_(T* in_out, const Shape& in_out_shape, const T* means, const Shape& means_shape) {
        OPENVINO_ASSERT(in_out_shape.size() == 3);  // [num_heads, token_dim, token_dim]
        OPENVINO_ASSERT(in_out_shape[1] == in_out_shape[2]);
        OPENVINO_ASSERT(means_shape.size() == 2);  // [num_heads, token_dim]
        OPENVINO_ASSERT(means_shape[0] == in_out_shape[0]);
        OPENVINO_ASSERT(means_shape[1] == in_out_shape[1]);

        for (size_t head_idx = 0; head_idx < in_out_shape[0]; head_idx++) {
            size_t in_head_offset = head_idx * in_out_shape[1] * in_out_shape[2];
            size_t means_head_offset = head_idx * means_shape[1];
            for (size_t token_dim_idx = 0; token_dim_idx < in_out_shape[1]; token_dim_idx++) {
                T mean_val = means[means_head_offset + token_dim_idx];
                size_t token_offset = token_dim_idx * in_out_shape[2];
                for (size_t reduced_dim_idx = 0; reduced_dim_idx < in_out_shape[2]; reduced_dim_idx++) {
                    size_t target_offset = in_head_offset + token_offset + reduced_dim_idx;
                    if (in_out[target_offset] < mean_val) {
                        in_out[target_offset] = T{0};
                    }
                }
            }
        }
    }
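
    // For instance (hypothetical values), a row {0.1, 0.4, 0.2, 0.3} with a row mean of 0.25 would become
    // {0.0, 0.4, 0.0, 0.3} after fill_low_values_with_zeros_, keeping only the above-average similarities.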

    /** Sums per-token similarity values within consecutive blocks of `block_size` rows, producing a
     * block-downsampled version of the input along its first dimension.
     * @param processed_similarity_token_data Pointer to the head-aggregated, thresholded similarity values.
     * Expected shape is [token_dim, token_dim], where `token_dim` must be a multiple of `block_size`.
     * @param processed_similarity_token_data_shape Shape of the similarity input tensor.
     * @param out Pointer to the output tensor data (per-block sums).
     * @param out_shape Shape of the output tensor data. Expected shape is [token_dim / block_size, token_dim].
     */
    void block_sum_diversity_values(const T* processed_similarity_token_data,
                                    const Shape& processed_similarity_token_data_shape,
                                    T* out,
                                    const Shape& out_shape) {
        OPENVINO_ASSERT(processed_similarity_token_data_shape.size() == 2);  // [token_dim, token_dim]
        OPENVINO_ASSERT(processed_similarity_token_data_shape[0] == processed_similarity_token_data_shape[1]);
        OPENVINO_ASSERT(processed_similarity_token_data_shape[0] % m_block_size == 0);

        OPENVINO_ASSERT(out_shape.size() == 2);  // [block_dim, token_dim]
        OPENVINO_ASSERT(out_shape[0] == processed_similarity_token_data_shape[0] / m_block_size);
        OPENVINO_ASSERT(out_shape[1] == processed_similarity_token_data_shape[1]);

        std::memset(out, 0, out_shape[0] * out_shape[1] * sizeof(T));

        for (size_t out_block_dim_idx = 0; out_block_dim_idx < out_shape[0]; out_block_dim_idx++) {
            size_t out_block_offset = out_block_dim_idx * out_shape[1];
            size_t in_block_offset = (out_block_dim_idx * m_block_size) * processed_similarity_token_data_shape[1];
            for (size_t out_token_dim_idx = 0; out_token_dim_idx < out_shape[1]; out_token_dim_idx++) {
                for (size_t in_token_in_block_idx = 0; in_token_in_block_idx < m_block_size; in_token_in_block_idx++) {
                    size_t source_offset = in_block_offset +
                                           in_token_in_block_idx * processed_similarity_token_data_shape[1] +
                                           out_token_dim_idx;
                    out[out_block_offset + out_token_dim_idx] += processed_similarity_token_data[source_offset];
                }
            }
        }
    }
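
    // As a small illustration (hypothetical numbers): with m_block_size == 2 and a 4x4 input, the output has shape
    // [2, 4], where out[0][j] = in[0][j] + in[1][j] and out[1][j] = in[2][j] + in[3][j] for each column j.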

    /** Computes block-wise diversity values for the evictable area of the KV cache. The key vectors are
     * L2-normalized along the head dimension, their pairwise cosine similarities are computed per head, the
     * evictable [eviction_size, eviction_size] sub-matrix is extracted (skipping the first `start_size` tokens), the
     * diagonal and the below-row-mean similarities are zeroed out, the result is averaged over the heads, and the
     * remaining similarities are summed within blocks of `block_size` rows.
     * @param key_data Pointer to the key input tensor data.
     * @param key_shape Shape of the key input tensor data. Expected shape is [num_heads, num_key_tokens, head_size],
     * where `num_key_tokens` must be at least `start_size + eviction_size`.
     * @return A matrix of shape [eviction_size / block_size, eviction_size] where the element at [i][j] is the
     * aggregated similarity of the i-th evictable block to the j-th evictable token. Lower values correspond to
     * more "diverse" blocks.
     */
    std::vector<std::vector<T>> calculate_block_diversity(const T* key_data, const Shape& key_shape) {
        OPENVINO_ASSERT(key_shape.size() == 3);  // [num_heads, key_token_len, head_dim]
        OPENVINO_ASSERT(key_shape[1] >= m_start_size + m_eviction_size);

        // Normalize the keys into a separate buffer so that the caller-provided data stays untouched
        auto normalized_key_buf = allocate_buf(key_shape);
        ov::reference::normalize_l2(key_data,
                                    normalized_key_buf.get(),
                                    key_shape,
                                    {2},
                                    std::numeric_limits<T>::epsilon());

        Shape cos_similar_shape = {key_shape[0], key_shape[1], key_shape[1]};
        auto cos_similar_buf = allocate_buf(cos_similar_shape);
        ov::reference::matmul(normalized_key_buf.get(),
                              normalized_key_buf.get(),
                              cos_similar_buf.get(),
                              key_shape,
                              key_shape,
                              cos_similar_shape,
                              /* transpose_arg0 = */ false,
                              /* transpose_arg1 = */ true);
        normalized_key_buf.reset();

        // The slice stop points are implied by the output shape, so only starts, steps and axes are passed
        Shape evictable_subset_shape = {key_shape[0], m_eviction_size, m_eviction_size};
        auto evictable_subset_buf = allocate_buf(evictable_subset_shape);
        ov::reference::slice(reinterpret_cast<const char*>(cos_similar_buf.get()),
                             cos_similar_shape,
                             reinterpret_cast<char*>(evictable_subset_buf.get()),
                             evictable_subset_shape,
                             sizeof(T),
                             /* starts = */ {static_cast<int64_t>(m_start_size), static_cast<int64_t>(m_start_size)},
                             /* steps = */ {1, 1},
                             /* axes = */ {1, 2});
        cos_similar_buf.reset();

        fill_diagonal_(evictable_subset_buf.get(), evictable_subset_shape, T{0});

        Shape means_shape = {key_shape[0], m_eviction_size};
        auto means_buf = allocate_buf(means_shape);
        ov::reference::reduce_mean(evictable_subset_buf.get(), means_buf.get(), evictable_subset_shape, {2});

        fill_low_values_with_zeros_(evictable_subset_buf.get(), evictable_subset_shape, means_buf.get(), means_shape);

        Shape aggregated_token_similarities_shape = {m_eviction_size, m_eviction_size};
        auto aggregated_token_similarities_buf = allocate_buf(aggregated_token_similarities_shape);
        ov::reference::reduce_mean(evictable_subset_buf.get(),
                                   aggregated_token_similarities_buf.get(),
                                   evictable_subset_shape,
                                   {0});
        evictable_subset_buf.reset();

        Shape block_diversity_shape = {m_eviction_size / m_block_size, m_eviction_size};
        auto block_diversity_buf = allocate_buf(block_diversity_shape);
        block_sum_diversity_values(aggregated_token_similarities_buf.get(),
                                   aggregated_token_similarities_shape,
                                   block_diversity_buf.get(),
                                   block_diversity_shape);

        std::vector<std::vector<T>> retval(block_diversity_shape[0], std::vector<T>(block_diversity_shape[1]));
        for (size_t block_idx = 0; block_idx < block_diversity_shape[0]; block_idx++) {
            for (size_t token_idx = 0; token_idx < block_diversity_shape[1]; token_idx++) {
                retval[block_idx][token_idx] =
                    block_diversity_buf.get()[block_idx * block_diversity_shape[1] + token_idx];
            }
        }

        return retval;
    }

    /**
     * @param shape Shape of a tensor
     * @return A shared_ptr owning a buffer that can be used to store tensor data for the given shape.
     * */
    std::shared_ptr<T[]> allocate_buf(const Shape& shape) {
        return std::shared_ptr<T[]>(new T[ov::shape_size(shape)]);
    }

    size_t m_start_size;
    size_t m_eviction_size;
    size_t m_block_size;
};
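
// The sketch below illustrates how the calculator might be used; it is not part of the reference API, and the
// tensor names, sizes and the float element type are assumptions chosen for the example only.
//
//     ov::Tensor keys(ov::element::f32,
//                     ov::Shape{/* num_heads = */ 8, /* num_key_tokens = */ 512, /* head_size = */ 64});
//     // ... fill `keys` with the cached key vectors ...
//     ov::reference::AdaptiveRKVDiversityCalculator<float> calculator(/* start_size = */ 32,
//                                                                     /* eviction_size = */ 256,
//                                                                     /* block_size = */ 16);
//     auto block_diversity = calculator.calculate_block_diversity(keys.data<float>(), keys.get_shape());
//     // block_diversity has eviction_size / block_size == 16 rows of eviction_size == 256 values each.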

}  // namespace ov::reference