Skip to content

Commit 6581f58

Browse files
committed
Add basic tests
1 parent f452199 commit 6581f58

File tree

2 files changed

+308
-8
lines changed

2 files changed

+308
-8
lines changed

src/core/reference/include/openvino/reference/adaptive_rkv_diversity.hpp

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <memory>
1010
#include <queue>
1111

12+
#include "openvino/op/util/attr_types.hpp"
1213
#include "openvino/reference/matmul.hpp"
1314
#include "openvino/reference/normalize_l2.hpp"
1415
#include "openvino/reference/reduce_mean.hpp"
@@ -59,7 +60,7 @@ class AdaptiveRKVDiversityCalculator {
5960
* @param out_shape Shape of the output tensor data. Expected shape is [num_heads, num_query_tokens / block_size,
6061
* num_key_tokens / block_size].
6162
*/
62-
void fill_diagonal_(const T* in_out,
63+
void fill_diagonal_(T* in_out,
6364
const Shape& in_out_shape,
6465
T val) {
6566
OPENVINO_ASSERT(in_out_shape.size() == 3); // [num_heads, token_dim, token_dim]
@@ -76,7 +77,7 @@ class AdaptiveRKVDiversityCalculator {
7677
}
7778
}
7879

79-
void fill_low_values_with_zeros_(const T* in_out,
80+
void fill_low_values_with_zeros_(T* in_out,
8081
const Shape& in_out_shape,
8182
const T* means,
8283
const Shape& means_shape) {
@@ -121,7 +122,7 @@ class AdaptiveRKVDiversityCalculator {
121122
size_t in_block_offset = (out_block_dim_idx * m_block_size) * out_shape[1];
122123
for (size_t in_token_in_block_idx = 0; in_token_in_block_idx < m_block_size; in_token_in_block_idx++) {
123124
size_t source_offset = in_block_offset + in_token_in_block_idx * processed_similarity_token_data_shape[1] + out_token_dim_idx;
124-
out[out_block_offset + out_token_dim_idx] += processed_similarity_token_data[source_offset];
125+
out[out_block_offset + out_token_dim_idx] -= processed_similarity_token_data[source_offset];
125126
}
126127
}
127128
}
@@ -146,19 +147,22 @@ class AdaptiveRKVDiversityCalculator {
146147
std::vector<std::vector<T>> calculate_block_diversity(const T* key_data,
147148
const Shape& key_shape) {
148149
OPENVINO_ASSERT(key_shape.size() == 3); // [num_heads, key_token_len, head_dim]
149-
OPENVINO_ASSERT(key_shape[1] >= m_block_size * (m_start_size + m_eviction_size));
150+
OPENVINO_ASSERT(key_shape[1] >= m_start_size + m_eviction_size);
150151

152+
153+
auto normalized_key_data_buf = allocate_buf(key_shape);
151154
// Should be safe to use this in-place
152-
ov::reference::normalize_l2(key_data, key_data, key_shape, {2}, std::numeric_limits<T>::epsilon());
155+
ov::reference::normalize_l2(key_data, normalized_key_data_buf.get(), key_shape, {2}, std::numeric_limits<float>::epsilon(), ov::op::EpsMode::ADD);
153156

154157
Shape cos_similar_shape = {key_shape[0], key_shape[1], key_shape[1]};
155158
auto cos_similar_buf = allocate_buf(cos_similar_shape);
156-
ov::reference::matmul(key_data, key_data, cos_similar_buf.get(), key_shape, key_shape, cos_similar_shape, /* transpose_arg0 = */ false, /* transpose_arg1 = */ true);
159+
ov::reference::matmul(normalized_key_data_buf.get(), normalized_key_data_buf.get(), cos_similar_buf.get(), key_shape, key_shape, cos_similar_shape, /* transpose_arg0 = */ false, /* transpose_arg1 = */ true);
160+
normalized_key_data_buf.reset();
157161

158162
Shape evictable_subset_shape = {key_shape[0], m_eviction_size, m_eviction_size};
159163
auto evictable_subset_buf = allocate_buf(evictable_subset_shape);
160164
// stops?
161-
ov::reference::slice(cos_similar_buf.get(), cos_similar_shape, evictable_subset_buf.get(), evictable_subset_shape, sizeof(T), /* starts = */ {m_start_size, m_start_size}, /* steps = */ {1, 1}, /* axes = */{1, 2});
165+
ov::reference::slice(reinterpret_cast<char*>(cos_similar_buf.get()), cos_similar_shape, reinterpret_cast<char*>(evictable_subset_buf.get()), evictable_subset_shape, sizeof(T), /* starts = */ {m_start_size, m_start_size}, /* steps = */ {1, 1}, /* axes = */{1, 2});
162166
cos_similar_buf.reset();
163167

164168
fill_diagonal_(evictable_subset_buf.get(), evictable_subset_shape, 0.0);
@@ -168,6 +172,7 @@ class AdaptiveRKVDiversityCalculator {
168172
ov::reference::reduce_mean(evictable_subset_buf.get(), means_buf.get(), evictable_subset_shape, {2});
169173

170174
fill_low_values_with_zeros_(evictable_subset_buf.get(), evictable_subset_shape, means_buf.get(), means_shape);
175+
means_buf.reset();
171176

172177
Shape aggregated_token_similarities_shape = {m_eviction_size, m_eviction_size};
173178
auto aggregated_token_similarities_buf = allocate_buf(aggregated_token_similarities_shape);
@@ -180,7 +185,7 @@ class AdaptiveRKVDiversityCalculator {
180185
std::vector<std::vector<T>> retval(block_diversity_shape[0], std::vector<T>(block_diversity_shape[1]));
181186
for (size_t block_idx = 0; block_idx < block_diversity_shape[0]; block_idx++) {
182187
for (size_t token_idx = 0; token_idx < block_diversity_shape[1]; token_idx++) {
183-
retval[block_idx][token_idx] = block_diversity_buf.get() + block_idx * block_diversity_shape[1] + token_idx;
188+
retval[block_idx][token_idx] = block_diversity_buf[block_idx * block_diversity_shape[1] + token_idx];
184189
}
185190
}
186191

Lines changed: 295 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,295 @@
1+
// Copyright (C) 2018-2025 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include <gmock/gmock.h>
6+
#include <gtest/gtest.h>
7+
8+
#include <openvino/reference/adaptive_rkv_diversity.hpp>
9+
10+
namespace adaptive_rkv_test {
11+
size_t DEFAULT_BLOCK_SIZE = 2;
12+
size_t DEFAULT_START_SIZE = 2;
13+
size_t DEFAULT_EVICTION_SIZE = 10;
14+
15+
16+
TEST(AdaptiveRKVE2ESmokeTest, CalculatesDiversityWithoutThrowing) {
17+
ov::reference::AdaptiveRKVDiversityCalculator<double> calculator(DEFAULT_START_SIZE,
18+
DEFAULT_EVICTION_SIZE,
19+
DEFAULT_BLOCK_SIZE);
20+
21+
ov::Shape mock_shape{2, (DEFAULT_START_SIZE + DEFAULT_EVICTION_SIZE) * 2, 8};
22+
std::vector<double> mock_data(ov::shape_size(mock_shape), 1.0);
23+
24+
EXPECT_NO_THROW(calculator.calculate_block_diversity(mock_data.data(), mock_shape));
25+
};
26+
27+
28+
struct FillDiagonalTestData {
29+
ov::Shape in_shape;
30+
std::vector<double> in_data;
31+
std::vector<double> ref_out_data;
32+
};
33+
34+
using AdaptiveRKVDiversityFillDiagonalTest = ::testing::TestWithParam<FillDiagonalTestData>;
35+
36+
std::vector<FillDiagonalTestData> FILL_DIAGONAL_TEST_CASES = {
37+
{
38+
{2, 4, 4},
39+
// clang-format off
40+
{
41+
3.144, 8.512, 8.518, -8.386,
42+
7.889, -5.721, 5.507, 4.295,
43+
-6.624, -8.463, 7.474, 9.879,
44+
4.534, -5.908, -9.388, 2.356,
45+
46+
7.497, 8.186, -8.658, -4.796,
47+
-8.248, -9.797, -7.907, -4.513,
48+
3.469, 7.633, 7.244, -6.844,
49+
-7.173, 4.450, 6.705, -7.035
50+
},
51+
// clang-format on
52+
53+
// clang-format off
54+
{
55+
42.00, 8.512, 8.518, -8.386,
56+
7.889, 42.00, 5.507, 4.295,
57+
-6.624, -8.463, 42.00, 9.879,
58+
4.534, -5.908, -9.388, 42.00,
59+
60+
42.00, 8.186, -8.658, -4.796,
61+
-8.248, 42.00, -7.907, -4.513,
62+
3.469, 7.633, 42.00, -6.844,
63+
-7.173, 4.450, 6.705, 42.00
64+
},
65+
// clang-format on
66+
}
67+
};
68+
69+
TEST_P(AdaptiveRKVDiversityFillDiagonalTest, FillsDiagonal) {
70+
auto test_struct = GetParam();
71+
ASSERT_EQ(test_struct.in_data.size(), ov::shape_size(test_struct.in_shape));
72+
ASSERT_EQ(test_struct.ref_out_data.size(), ov::shape_size(test_struct.in_shape));
73+
74+
ov::reference::AdaptiveRKVDiversityCalculator<double> calculator(DEFAULT_START_SIZE, DEFAULT_EVICTION_SIZE, DEFAULT_BLOCK_SIZE);
75+
76+
std::vector<double> test_out_data = test_struct.in_data;
77+
calculator.fill_diagonal_(test_out_data.data(),
78+
test_struct.in_shape,
79+
42.0);
80+
EXPECT_EQ(test_out_data, test_struct.ref_out_data);
81+
}
82+
83+
INSTANTIATE_TEST_SUITE_P(VariousInputs,
84+
AdaptiveRKVDiversityFillDiagonalTest,
85+
::testing::ValuesIn(FILL_DIAGONAL_TEST_CASES));
86+
87+
struct FillLowValuesWithZerosTestData {
88+
ov::Shape in_shape;
89+
std::vector<double> in_data;
90+
ov::Shape means_shape;
91+
std::vector<double> means;
92+
std::vector<double> ref_out_data;
93+
};
94+
95+
using AdaptiveRKVFillLowValuesWithZerosTest = ::testing::TestWithParam<FillLowValuesWithZerosTestData>;
96+
97+
std::vector<FillLowValuesWithZerosTestData> FILL_LOW_VALUES_WITH_ZEROS_TEST_CASES = {
98+
{
99+
{2, 4, 4},
100+
// clang-format off
101+
{
102+
4.534, -5.908, -9.388, 2.356,
103+
-6.624, -8.463, 7.474, 9.879,
104+
7.889, -5.721, 5.507, 4.295,
105+
3.144, 8.512, 8.518, -8.386,
106+
107+
-7.173, 4.450, 6.705, -7.035,
108+
3.469, 7.633, 7.244, -6.844,
109+
-8.248, -9.797, -7.907, -4.513,
110+
7.497, 8.186, -8.658, -4.796,
111+
},
112+
// clang-format on
113+
114+
{2, 4},
115+
116+
// clang-format off
117+
{
118+
-2.1015,
119+
0.5665,
120+
2.9925,
121+
2.947,
122+
123+
-0.76325,
124+
2.8755,
125+
-7.61625,
126+
0.55725
127+
},
128+
129+
// clang-format off
130+
{
131+
4.534, 0.000, 0.000, 2.356,
132+
0.000, 0.000, 7.474, 9.879,
133+
7.889, 0.000, 5.507, 4.295,
134+
3.144, 8.512, 8.518, 0.000,
135+
136+
0.000, 4.450, 6.705, 0.000,
137+
3.469, 7.633, 7.244, 0.000,
138+
0.000, 0.000, 0.000, -4.513,
139+
7.497, 8.186, 0.000, 0.000,
140+
},
141+
// clang-format on
142+
},
143+
};
144+
145+
TEST_P(AdaptiveRKVFillLowValuesWithZerosTest, FillsLowValuesWithZero) {
146+
auto test_struct = GetParam();
147+
ASSERT_EQ(test_struct.in_data.size(), ov::shape_size(test_struct.in_shape));
148+
ASSERT_EQ(test_struct.means.size(), ov::shape_size(test_struct.means_shape));
149+
ASSERT_EQ(test_struct.ref_out_data.size(), ov::shape_size(test_struct.in_shape));
150+
151+
ov::reference::AdaptiveRKVDiversityCalculator<double> calculator(DEFAULT_START_SIZE,
152+
DEFAULT_EVICTION_SIZE,
153+
DEFAULT_BLOCK_SIZE);
154+
std::vector<double> test_out_data = test_struct.in_data;
155+
calculator.fill_low_values_with_zeros_(test_out_data.data(), test_struct.in_shape, test_struct.means.data(), test_struct.means_shape);
156+
157+
EXPECT_THAT(test_out_data, ::testing::Pointwise(::testing::DoubleNear(1e-8), test_struct.ref_out_data));
158+
}
159+
160+
INSTANTIATE_TEST_SUITE_P(VariousInputs,
161+
AdaptiveRKVFillLowValuesWithZerosTest,
162+
::testing::ValuesIn(FILL_LOW_VALUES_WITH_ZEROS_TEST_CASES));
163+
164+
165+
struct BlockSumTestData {
166+
ov::Shape in_shape;
167+
std::vector<double> in_data;
168+
size_t block_size;
169+
ov::Shape out_shape;
170+
std::vector<double> ref_out_data;
171+
};
172+
173+
using AdaptiveRKVBlockSumTest = ::testing::TestWithParam<BlockSumTestData>;
174+
175+
std::vector<BlockSumTestData> BLOCK_SUM_TEST_CASES = {
176+
{
177+
{8, 8},
178+
// clang-format off
179+
{
180+
0.1117, 0.0780, 0.1347, 0.0885, 0.1942, 0.0922, 0.1184, 0.1824,
181+
0.1488, 0.1766, 0.0852, 0.1239, 0.0930, 0.1220, 0.1367, 0.1138,
182+
0.1410, 0.0861, 0.0774, 0.1325, 0.1478, 0.1689, 0.0885, 0.1579,
183+
0.1248, 0.1038, 0.1842, 0.0935, 0.1813, 0.0890, 0.0897, 0.1336,
184+
0.0905, 0.1049, 0.1263, 0.0953, 0.1018, 0.1297, 0.1659, 0.1855,
185+
0.1373, 0.1791, 0.1005, 0.1286, 0.1492, 0.1373, 0.0820, 0.0860,
186+
0.0997, 0.1285, 0.0786, 0.1366, 0.1963, 0.0904, 0.1488, 0.1211,
187+
0.1859, 0.1174, 0.1364, 0.0930, 0.1028, 0.1034, 0.1699, 0.0912
188+
},
189+
// clang-format on
190+
191+
/* block_size = */ 2,
192+
193+
{4, 8},
194+
195+
// clang-format off
196+
{
197+
-0.2605, -0.2546, -0.2199, -0.2124, -0.2872, -0.2142, -0.2551, -0.2962,
198+
-0.2658, -0.1899, -0.2616, -0.226, -0.3291, -0.2579, -0.1782, -0.2915,
199+
-0.2278, -0.284 , -0.2268, -0.2239, -0.251, -0.267, -0.2479, -0.2715,
200+
-0.2856, -0.2459, -0.215, -0.2296, -0.2991, -0.1938, -0.3187, -0.2123
201+
202+
},
203+
},
204+
};
205+
206+
TEST_P(AdaptiveRKVBlockSumTest, BlockSumIsCorrect) {
207+
auto test_struct = GetParam();
208+
ASSERT_EQ(test_struct.in_data.size(), ov::shape_size(test_struct.in_shape));
209+
ASSERT_EQ(test_struct.ref_out_data.size(), ov::shape_size(test_struct.out_shape));
210+
211+
ov::reference::AdaptiveRKVDiversityCalculator<double> calculator(DEFAULT_START_SIZE,
212+
DEFAULT_EVICTION_SIZE,
213+
test_struct.block_size);
214+
std::vector<double> test_out_data(test_struct.ref_out_data.size());
215+
calculator.block_sum_diversity_values(test_struct.in_data.data(), test_struct.in_shape, test_out_data.data(), test_struct.out_shape);
216+
217+
EXPECT_THAT(test_out_data, ::testing::Pointwise(::testing::DoubleNear(1e-5), test_struct.ref_out_data));
218+
}
219+
220+
INSTANTIATE_TEST_SUITE_P(VariousInputs,
221+
AdaptiveRKVBlockSumTest,
222+
::testing::ValuesIn(BLOCK_SUM_TEST_CASES));
223+
224+
struct DiversityCalculateTestData {
225+
ov::Shape in_shape;
226+
std::vector<double> in_data;
227+
double threshold;
228+
229+
};
230+
231+
struct E2EDiversityTestData {
232+
ov::Shape k_shape;
233+
std::vector<double> k_data;
234+
size_t start_size;
235+
size_t eviction_size;
236+
std::vector<std::vector<double>> ref_diversity_data;
237+
};
238+
239+
using AdaptiveRKVE2EDiversityTest = ::testing::TestWithParam<E2EDiversityTestData>;
240+
241+
std::vector<E2EDiversityTestData> E2E_DIVERSITY_TEST_CASES = {{
242+
{2, 10, 4},
243+
// clang-format off
244+
{
245+
4.949, -7.294, -6.330, 3.757,
246+
-3.561, 1.029, 5.030, -9.483,
247+
5.350, -2.745, -1.404, -7.788,
248+
-1.086, 4.576, -8.726, -8.815,
249+
3.144, 8.512, 8.518, -8.386,
250+
7.889, -5.721, 5.507, 4.295,
251+
-6.624, -8.463, 7.474, 9.879,
252+
4.534, -5.908, -9.388, 2.356,
253+
7.497, 8.186, -8.658, -4.796,
254+
-8.248, -9.797, -7.907, -4.513,
255+
256+
3.469, 7.633, 7.244, -6.844,
257+
-7.173, 4.450, 6.705, -7.035,
258+
8.773, -7.571, -9.878, -9.584,
259+
0.807, 8.059, -7.172, 4.303,
260+
-3.323, -8.852, 1.167, -1.126,
261+
-4.428, 9.678, -6.547, 0.037,
262+
-8.152, -9.865, 3.694, -7.650,
263+
0.359, 8.018, -7.152, -6.242,
264+
-9.120, -7.228, -9.186, 3.202,
265+
-9.304, -0.401, -5.287, 6.834
266+
},
267+
// clang-format on
268+
269+
/* start_size = */ 2,
270+
/* eviction_size = */ 6,
271+
{
272+
{}
273+
},
274+
}};
275+
276+
TEST_P(AdaptiveRKVE2EDiversityTest, CalculatesDiversityCorrectly) {
277+
auto test_struct = GetParam();
278+
ov::reference::AdaptiveRKVDiversityCalculator<double> calculator(test_struct.start_size,
279+
test_struct.eviction_size,
280+
DEFAULT_BLOCK_SIZE);
281+
282+
auto test_diversity = calculator.calculate_block_diversity(test_struct.k_data.data(), test_struct.k_shape);
283+
ASSERT_EQ(test_diversity.size(), test_struct.ref_diversity_data.size());
284+
for (size_t i = 0; i < test_diversity.size(); i++) {
285+
ASSERT_EQ(test_diversity[i].size(), test_struct.ref_diversity_data[i].size());
286+
}
287+
288+
for (size_t i = 0; i < test_diversity.size(); i++) {
289+
EXPECT_THAT(test_diversity[i], ::testing::Pointwise(::testing::DoubleNear(1e-8), test_struct.ref_diversity_data[i]));
290+
}
291+
292+
};
293+
294+
INSTANTIATE_TEST_SUITE_P(VariousInputs, AdaptiveRKVE2EDiversityTest, ::testing::ValuesIn(E2E_DIVERSITY_TEST_CASES));
295+
}

0 commit comments

Comments
 (0)