Skip to content

Commit 15cc823

Browse files
committed
Format
1 parent 77336b8 commit 15cc823

File tree

2 files changed

+86
-74
lines changed

2 files changed

+86
-74
lines changed

src/core/reference/include/openvino/reference/adaptive_rkv_diversity.hpp

Lines changed: 63 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,19 @@
1818

1919
namespace ov::reference {
2020

21-
2221
/** @brief Reference implementation of the Adaptive R-KV token diversity calculation mechanism
2322
* (https://arxiv.org/pdf/2505.24133v3) */
2423
template <typename T>
2524
class AdaptiveRKVDiversityCalculator {
2625
public:
2726
/** @param start_size Size, in tokens, of the key cache area that will be ignored for purposes of diversity
28-
* calculation, starting from the beginning of the token dimension ("start area"). Must be a multiple of `block_size`.
27+
* calculation, starting from the beginning of the token dimension ("start area"). Must be a multiple of
28+
* `block_size`.
2929
* @param eviction_size Size, in tokens, from the end of the start area, the tokens in which will be
30-
* considred for purposes of diversity calculation ("eviction area"). The rest of the tokens after the eviction area,
31-
* if any, are ignored. Must be a multiple of `block_size`.
32-
* @param block_size Block size of the underlying paged attention implementation. The diversity values will be sum-reduced
33-
* from per-token values to per-block values based on this number of tokens in a block.
30+
* considered for purposes of diversity calculation ("eviction area"). The rest of the tokens after the eviction
31+
* area, if any, are ignored. Must be a multiple of `block_size`.
32+
* @param block_size Block size of the underlying paged attention implementation. The diversity values will be
33+
* sum-reduced from per-token values to per-block values based on this number of tokens in a block.
3434
* */
3535
AdaptiveRKVDiversityCalculator(size_t start_size, size_t eviction_size, size_t block_size)
3636
: m_start_size(start_size),
@@ -46,13 +46,10 @@ class AdaptiveRKVDiversityCalculator {
4646
* @param in_out_shape Shape of the matrix data. Expected shape is [num_heads, token_dim, token_dim].
4747
* @param val Value to fill in the diagonal positions.
4848
*/
49-
void fill_diagonal_(T* in_out,
50-
const Shape& in_out_shape,
51-
T val) {
52-
OPENVINO_ASSERT(in_out_shape.size() == 3); // [num_heads, token_dim, token_dim]
49+
void fill_diagonal_(T* in_out, const Shape& in_out_shape, T val) {
50+
OPENVINO_ASSERT(in_out_shape.size() == 3); // [num_heads, token_dim, token_dim]
5351
OPENVINO_ASSERT(in_out_shape[1] == in_out_shape[2]); // [num_heads, token_dim, token_dim]
5452

55-
5653
for (size_t head_idx = 0; head_idx < in_out_shape[0]; head_idx++) {
5754
size_t in_head_offset = head_idx * in_out_shape[1] * in_out_shape[2];
5855
for (size_t token_dim_idx = 0; token_dim_idx < in_out_shape[1]; token_dim_idx++) {
@@ -63,19 +60,19 @@ class AdaptiveRKVDiversityCalculator {
6360
}
6461
}
6562

66-
/** For a rank-3 tensor, zeroes out the values that are less than the mean of the values of the corresponding slice at rank 2 (zero-based). Ranks 1 and 2 of the input tensor must be equal. Mean values are computed and provided externally. The operation is done in-place.
63+
/** For a rank-3 tensor, zeroes out the values that are less than the mean of the values of the corresponding slice
64+
* at rank 2 (zero-based). Ranks 1 and 2 of the input tensor must be equal. Mean values are computed and provided
65+
* externally. The operation is done in-place.
6766
* @param in_out Pointer to the tensor data.
6867
* @param in_out_shape Shape of the tensor data. Expected shape is [num_heads, token_dim, token_dim].
69-
* @param means Pointer to the tensor data containing the means of each slice of the `in_out` tensor along its rank 2 (zero-based).
68+
* @param means Pointer to the tensor data containing the means of each slice of the `in_out` tensor along its rank
69+
* 2 (zero-based).
7070
* @param means_shape Shape of the means tensor. Expected shape is [num_heads, token_dim].
7171
*/
72-
void fill_low_values_with_zeros_(T* in_out,
73-
const Shape& in_out_shape,
74-
const T* means,
75-
const Shape& means_shape) {
72+
void fill_low_values_with_zeros_(T* in_out, const Shape& in_out_shape, const T* means, const Shape& means_shape) {
7673
OPENVINO_ASSERT(in_out_shape.size() == 3); // [num_heads, token_dim, token_dim]
7774
OPENVINO_ASSERT(in_out_shape[1] == in_out_shape[2]);
78-
OPENVINO_ASSERT(means_shape.size() == 2); // [num_heads, token_dim]
75+
OPENVINO_ASSERT(means_shape.size() == 2); // [num_heads, token_dim]
7976
OPENVINO_ASSERT(means_shape[0] == in_out_shape[0]);
8077
OPENVINO_ASSERT(means_shape[1] == in_out_shape[1]);
8178

@@ -96,14 +93,12 @@ class AdaptiveRKVDiversityCalculator {
9693

9794
/** For a square matrix, sums each `block_size`-sized group of matrix rows to produce a row in the output matrix.
9895
* @param in_data Pointer to the matrix data.
99-
* @param in_shape Shape of the matrix data. Expected shape is [token_dim, token_dim], where token_dim must be a multiple of `block_size`.
96+
* @param in_shape Shape of the matrix data. Expected shape is [token_dim, token_dim], where token_dim must be a
97+
* multiple of `block_size`.
10098
* @param out Pointer to the output matrix data.
10199
* @param out_shape Shape of the output matrix. Expected shape is [token_dim / block_size, token_dim].
102100
*/
103-
void block_sum_diversity_values(const T* in_data,
104-
const Shape& in_shape,
105-
T* out,
106-
const Shape& out_shape) {
101+
void block_sum_diversity_values(const T* in_data, const Shape& in_shape, T* out, const Shape& out_shape) {
107102
OPENVINO_ASSERT(in_shape.size() == 2); // [token_dim, token_dim]
108103
OPENVINO_ASSERT(in_shape[0] == in_shape[1]);
109104
OPENVINO_ASSERT(in_shape[0] % m_block_size == 0);
@@ -117,11 +112,11 @@ class AdaptiveRKVDiversityCalculator {
117112
for (size_t out_block_dim_idx = 0; out_block_dim_idx < out_shape[0]; out_block_dim_idx++) {
118113
size_t out_block_offset = out_block_dim_idx * out_shape[1];
119114
for (size_t out_token_dim_idx = 0; out_token_dim_idx < out_shape[1]; out_token_dim_idx++) {
120-
size_t in_block_offset = (out_block_dim_idx * m_block_size) * out_shape[1];
121-
for (size_t in_token_in_block_idx = 0; in_token_in_block_idx < m_block_size; in_token_in_block_idx++) {
122-
size_t source_offset = in_block_offset + in_token_in_block_idx * in_shape[1] + out_token_dim_idx;
123-
out[out_block_offset + out_token_dim_idx] -= in_data[source_offset];
124-
}
115+
size_t in_block_offset = (out_block_dim_idx * m_block_size) * out_shape[1];
116+
for (size_t in_token_in_block_idx = 0; in_token_in_block_idx < m_block_size; in_token_in_block_idx++) {
117+
size_t source_offset = in_block_offset + in_token_in_block_idx * in_shape[1] + out_token_dim_idx;
118+
out[out_block_offset + out_token_dim_idx] -= in_data[source_offset];
119+
}
125120
}
126121
}
127122
}
@@ -131,37 +126,54 @@ class AdaptiveRKVDiversityCalculator {
131126
* that the 1-st rank is left unaggregated when compared to the full diversity calculation algorithm. The reason
132127
* for this is as follows. The final per-block diversity value computation relies on knowing the subset of blocks
133128
* in the eviction area that will be retained regardless of calculated diversity. This subset must be filtered out
134-
* from the rank-1 dimension when performing reduce-mean in the original algorithm to get 1 diversity value per block
135-
* in the eviction area. Due to implementation specifics the paged attention kernel does not know ahead of time which
136-
* blocks will be "retained" - this information is only available on the openvino.genai level after the PA kernel has executed.
137-
* Therefore the PA kernel will provide raw per-token values on the rank 1 of the returned diversity value matrix and delegatei
138-
* the final reduce-mean and filtering to the openvino.genai level.
129+
* from the rank-1 dimension when performing reduce-mean in the original algorithm to get 1 diversity value per
130+
* block in the eviction area. Due to implementation specifics the paged attention kernel does not know ahead of
131+
* time which blocks will be "retained" - this information is only available on the openvino.genai level after the
132+
* PA kernel has executed. Therefore the PA kernel will provide raw per-token values on the rank 1 of the returned
133+
* diversity value matrix and delegate the final reduce-mean and filtering to the openvino.genai level.
139134
* @param key_data Pointer to the key cache tensor data
140135
* @param key_shape Shape of the key input tensor data. Expected shape is [num_heads, num_key_tokens, head_size],
141136
* where `num_key_tokens` must be no less than `start_size + eviction_size`.
142-
* @return A rank-2 matrix in the std::vector representation with dimensions [eviction_size / block_size, eviction_size] containing
143-
* the diversity values. The values are expected to be further mean-reduced along rank 1 (zero-based) at the point in time when the
144-
* subset of blocks to be exclusively retained is known.
137+
* @return A rank-2 matrix in the std::vector representation with dimensions [eviction_size / block_size,
138+
* eviction_size] containing the diversity values. The values are expected to be further mean-reduced along rank 1
139+
* (zero-based) at the point in time when the subset of blocks to be exclusively retained is known.
145140
*/
146-
std::vector<std::vector<T>> calculate_block_diversity(const T* key_data,
147-
const Shape& key_shape) {
148-
OPENVINO_ASSERT(key_shape.size() == 3); // [num_heads, key_token_len, head_dim]
141+
std::vector<std::vector<T>> calculate_block_diversity(const T* key_data, const Shape& key_shape) {
142+
OPENVINO_ASSERT(key_shape.size() == 3); // [num_heads, key_token_len, head_dim]
149143
OPENVINO_ASSERT(key_shape[1] >= m_start_size + m_eviction_size);
150144

151-
152145
auto normalized_key_data_buf = allocate_buf(key_shape);
153146
// Should be safe to use this in-place
154-
ov::reference::normalize_l2(key_data, normalized_key_data_buf.get(), key_shape, {2}, std::numeric_limits<float>::epsilon(), ov::op::EpsMode::ADD);
147+
ov::reference::normalize_l2(key_data,
148+
normalized_key_data_buf.get(),
149+
key_shape,
150+
{2},
151+
std::numeric_limits<float>::epsilon(),
152+
ov::op::EpsMode::ADD);
155153

156154
Shape cos_similar_shape = {key_shape[0], key_shape[1], key_shape[1]};
157155
auto cos_similar_buf = allocate_buf(cos_similar_shape);
158-
ov::reference::matmul(normalized_key_data_buf.get(), normalized_key_data_buf.get(), cos_similar_buf.get(), key_shape, key_shape, cos_similar_shape, /* transpose_arg0 = */ false, /* transpose_arg1 = */ true);
156+
ov::reference::matmul(normalized_key_data_buf.get(),
157+
normalized_key_data_buf.get(),
158+
cos_similar_buf.get(),
159+
key_shape,
160+
key_shape,
161+
cos_similar_shape,
162+
/* transpose_arg0 = */ false,
163+
/* transpose_arg1 = */ true);
159164
normalized_key_data_buf.reset();
160165

161166
Shape evictable_subset_shape = {key_shape[0], m_eviction_size, m_eviction_size};
162167
auto evictable_subset_buf = allocate_buf(evictable_subset_shape);
163168
// stops?
164-
ov::reference::slice(reinterpret_cast<char*>(cos_similar_buf.get()), cos_similar_shape, reinterpret_cast<char*>(evictable_subset_buf.get()), evictable_subset_shape, sizeof(T), /* starts = */ {m_start_size, m_start_size}, /* steps = */ {1, 1}, /* axes = */{1, 2});
169+
ov::reference::slice(reinterpret_cast<char*>(cos_similar_buf.get()),
170+
cos_similar_shape,
171+
reinterpret_cast<char*>(evictable_subset_buf.get()),
172+
evictable_subset_shape,
173+
sizeof(T),
174+
/* starts = */ {m_start_size, m_start_size},
175+
/* steps = */ {1, 1},
176+
/* axes = */ {1, 2});
165177
cos_similar_buf.reset();
166178

167179
fill_diagonal_(evictable_subset_buf.get(), evictable_subset_shape, 0.0);
@@ -175,12 +187,18 @@ class AdaptiveRKVDiversityCalculator {
175187

176188
Shape aggregated_token_similarities_shape = {m_eviction_size, m_eviction_size};
177189
auto aggregated_token_similarities_buf = allocate_buf(aggregated_token_similarities_shape);
178-
ov::reference::reduce_mean(evictable_subset_buf.get(), aggregated_token_similarities_buf.get(), evictable_subset_shape, {0});
190+
ov::reference::reduce_mean(evictable_subset_buf.get(),
191+
aggregated_token_similarities_buf.get(),
192+
evictable_subset_shape,
193+
{0});
179194
evictable_subset_buf.reset();
180195

181196
Shape block_diversity_shape = {m_eviction_size / m_block_size, m_eviction_size};
182197
auto block_diversity_buf = allocate_buf(block_diversity_shape);
183-
block_sum_diversity_values(aggregated_token_similarities_buf.get(), aggregated_token_similarities_shape, block_diversity_buf.get(), block_diversity_shape);
198+
block_sum_diversity_values(aggregated_token_similarities_buf.get(),
199+
aggregated_token_similarities_shape,
200+
block_diversity_buf.get(),
201+
block_diversity_shape);
184202
std::vector<std::vector<T>> retval(block_diversity_shape[0], std::vector<T>(block_diversity_shape[1]));
185203
for (size_t block_idx = 0; block_idx < block_diversity_shape[0]; block_idx++) {
186204
for (size_t token_idx = 0; token_idx < block_diversity_shape[1]; token_idx++) {
@@ -199,7 +217,6 @@ class AdaptiveRKVDiversityCalculator {
199217
return std::shared_ptr<T[]>(new T[ov::shape_size(shape)]);
200218
}
201219

202-
203220
size_t m_start_size;
204221
size_t m_eviction_size;
205222
size_t m_block_size;

src/core/tests/reference/adaptive_rkv_diversity.cpp

Lines changed: 23 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ size_t DEFAULT_BLOCK_SIZE = 2;
1212
size_t DEFAULT_START_SIZE = 2;
1313
size_t DEFAULT_EVICTION_SIZE = 10;
1414

15-
1615
TEST(AdaptiveRKVE2ESmokeTest, CalculatesDiversityWithoutThrowing) {
1716
ov::reference::AdaptiveRKVDiversityCalculator<double> calculator(DEFAULT_START_SIZE,
1817
DEFAULT_EVICTION_SIZE,
@@ -24,7 +23,6 @@ TEST(AdaptiveRKVE2ESmokeTest, CalculatesDiversityWithoutThrowing) {
2423
EXPECT_NO_THROW(calculator.calculate_block_diversity(mock_data.data(), mock_shape));
2524
};
2625

27-
2826
struct FillDiagonalTestData {
2927
ov::Shape in_shape;
3028
std::vector<double> in_data;
@@ -33,10 +31,9 @@ struct FillDiagonalTestData {
3331

3432
using AdaptiveRKVDiversityFillDiagonalTest = ::testing::TestWithParam<FillDiagonalTestData>;
3533

36-
std::vector<FillDiagonalTestData> FILL_DIAGONAL_TEST_CASES = {
37-
{
38-
{2, 4, 4},
39-
// clang-format off
34+
std::vector<FillDiagonalTestData> FILL_DIAGONAL_TEST_CASES = {{
35+
{2, 4, 4},
36+
// clang-format off
4037
{
4138
3.144, 8.512, 8.518, -8.386,
4239
7.889, -5.721, 5.507, 4.295,
@@ -48,9 +45,9 @@ std::vector<FillDiagonalTestData> FILL_DIAGONAL_TEST_CASES = {
4845
3.469, 7.633, 7.244, -6.844,
4946
-7.173, 4.450, 6.705, -7.035
5047
},
51-
// clang-format on
48+
// clang-format on
5249

53-
// clang-format off
50+
// clang-format off
5451
{
5552
42.00, 8.512, 8.518, -8.386,
5653
7.889, 42.00, 5.507, 4.295,
@@ -62,21 +59,20 @@ std::vector<FillDiagonalTestData> FILL_DIAGONAL_TEST_CASES = {
6259
3.469, 7.633, 42.00, -6.844,
6360
-7.173, 4.450, 6.705, 42.00
6461
},
65-
// clang-format on
66-
}
67-
};
62+
// clang-format on
63+
}};
6864

6965
TEST_P(AdaptiveRKVDiversityFillDiagonalTest, FillsDiagonal) {
7066
auto test_struct = GetParam();
7167
ASSERT_EQ(test_struct.in_data.size(), ov::shape_size(test_struct.in_shape));
7268
ASSERT_EQ(test_struct.ref_out_data.size(), ov::shape_size(test_struct.in_shape));
7369

74-
ov::reference::AdaptiveRKVDiversityCalculator<double> calculator(DEFAULT_START_SIZE, DEFAULT_EVICTION_SIZE, DEFAULT_BLOCK_SIZE);
70+
ov::reference::AdaptiveRKVDiversityCalculator<double> calculator(DEFAULT_START_SIZE,
71+
DEFAULT_EVICTION_SIZE,
72+
DEFAULT_BLOCK_SIZE);
7573

7674
std::vector<double> test_out_data = test_struct.in_data;
77-
calculator.fill_diagonal_(test_out_data.data(),
78-
test_struct.in_shape,
79-
42.0);
75+
calculator.fill_diagonal_(test_out_data.data(), test_struct.in_shape, 42.0);
8076
EXPECT_EQ(test_out_data, test_struct.ref_out_data);
8177
}
8278

@@ -152,7 +148,10 @@ TEST_P(AdaptiveRKVFillLowValuesWithZerosTest, FillsLowValuesWithZero) {
152148
DEFAULT_EVICTION_SIZE,
153149
DEFAULT_BLOCK_SIZE);
154150
std::vector<double> test_out_data = test_struct.in_data;
155-
calculator.fill_low_values_with_zeros_(test_out_data.data(), test_struct.in_shape, test_struct.means.data(), test_struct.means_shape);
151+
calculator.fill_low_values_with_zeros_(test_out_data.data(),
152+
test_struct.in_shape,
153+
test_struct.means.data(),
154+
test_struct.means_shape);
156155

157156
EXPECT_THAT(test_out_data, ::testing::Pointwise(::testing::DoubleNear(1e-8), test_struct.ref_out_data));
158157
}
@@ -161,7 +160,6 @@ INSTANTIATE_TEST_SUITE_P(VariousInputs,
161160
AdaptiveRKVFillLowValuesWithZerosTest,
162161
::testing::ValuesIn(FILL_LOW_VALUES_WITH_ZEROS_TEST_CASES));
163162

164-
165163
struct BlockSumTestData {
166164
ov::Shape in_shape;
167165
std::vector<double> in_data;
@@ -409,17 +407,14 @@ std::vector<E2EDiversityTestData> E2E_DIVERSITY_TEST_CASES = {
409407
-9.120, -7.228, -9.186, 3.202,
410408
-9.304, -0.401, -5.287, 6.834
411409
},
412-
// clang-format on
410+
// clang-format on
413411

414412
/* start_size = */ 2,
415413
/* eviction_size = */ 6,
416-
{
417-
{-0.237145, -0.237145, -0.352696, -0.487902, -0.072365, -0.707192},
418-
{-0.334657, -0.505941, 0, 0.036135, -0.634881,-0.490221},
419-
{-0.380811, -0.398746801, -0.432080003, -0.693021748, 0, 0.067216441}
420-
},
421-
}
422-
};
414+
{{-0.237145, -0.237145, -0.352696, -0.487902, -0.072365, -0.707192},
415+
{-0.334657, -0.505941, 0, 0.036135, -0.634881, -0.490221},
416+
{-0.380811, -0.398746801, -0.432080003, -0.693021748, 0, 0.067216441}},
417+
}};
423418

424419
TEST_P(AdaptiveRKVE2EDiversityTest, CalculatesDiversityCorrectly) {
425420
auto test_struct = GetParam();
@@ -434,10 +429,10 @@ TEST_P(AdaptiveRKVE2EDiversityTest, CalculatesDiversityCorrectly) {
434429
}
435430

436431
for (size_t i = 0; i < test_diversity.size(); i++) {
437-
EXPECT_THAT(test_diversity[i], ::testing::Pointwise(::testing::DoubleNear(1e-6), test_struct.ref_diversity_data[i]));
432+
EXPECT_THAT(test_diversity[i],
433+
::testing::Pointwise(::testing::DoubleNear(1e-6), test_struct.ref_diversity_data[i]));
438434
}
439-
440435
};
441436

442437
INSTANTIATE_TEST_SUITE_P(VariousInputs, AdaptiveRKVE2EDiversityTest, ::testing::ValuesIn(E2E_DIVERSITY_TEST_CASES));
443-
}
438+
} // namespace adaptive_rkv_test

0 commit comments

Comments
 (0)