diff --git a/src/core/reference/include/openvino/reference/adaptive_rkv_diversity.hpp b/src/core/reference/include/openvino/reference/adaptive_rkv_diversity.hpp
new file mode 100644
index 00000000000000..2a0c4b7bca56a2
--- /dev/null
+++ b/src/core/reference/include/openvino/reference/adaptive_rkv_diversity.hpp
@@ -0,0 +1,220 @@
+// Copyright (C) 2018-2025 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include <cstring>
+#include <limits>
+#include <memory>
+#include <vector>
+
+#include "openvino/op/util/attr_types.hpp"
+#include "openvino/reference/matmul.hpp"
+#include "openvino/reference/normalize_l2.hpp"
+#include "openvino/reference/reduce_mean.hpp"
+#include "openvino/reference/slice.hpp"
+
+namespace ov::reference {
+
+/** @brief Reference implementation of the Adaptive R-KV token diversity calculation mechanism
+ * (https://arxiv.org/pdf/2505.24133v3) */
+template <class T>
+class AdaptiveRKVDiversityCalculator {
+public:
+    /** @param start_size Size, in tokens, of the key cache area that will be ignored for purposes of diversity
+     * calculation, starting from the beginning of the token dimension ("start area"). Must be a multiple of
+     * `block_size`.
+     * @param eviction_size Size, in tokens, of the area immediately following the start area; the tokens in it will
+     * be considered for purposes of diversity calculation ("eviction area"). The rest of the tokens after the
+     * eviction area, if any, are ignored. Must be a multiple of `block_size`.
+     * @param block_size Block size of the underlying paged attention implementation. The diversity values will be
+     * sum-reduced from per-token values to per-block values based on this number of tokens in a block.
+     * */
+    AdaptiveRKVDiversityCalculator(size_t start_size, size_t eviction_size, size_t block_size)
+        : m_start_size(start_size),
+          m_eviction_size(eviction_size),
+          m_block_size(block_size) {
+        OPENVINO_ASSERT(start_size % block_size == 0);
+        OPENVINO_ASSERT(eviction_size % block_size == 0);
+    }
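+
+    // Illustrative layout example (assumed values mirroring the unit tests, not part of the algorithm definition):
+    // with start_size = 2, eviction_size = 4 and block_size = 2, tokens [0, 1] form the start area and are ignored,
+    // tokens [2, 5] form the eviction area (two evictable blocks), and any tokens past index 5 are not considered
+    // by calculate_block_diversity().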
+
+    /** Fills the diagonal of each square matrix slice (at ranks 1 and 2, zero-based) of the input rank-3 tensor with
+     * a provided value. The operation is done in-place.
+     * @param in_out Pointer to the matrix data.
+     * @param in_out_shape Shape of the matrix data. Expected shape is [num_heads, token_dim, token_dim].
+     * @param val Value to fill in the diagonal positions.
+     */
+    void fill_diagonal_(T* in_out, const Shape& in_out_shape, T val) {
+        OPENVINO_ASSERT(in_out_shape.size() == 3);            // [num_heads, token_dim, token_dim]
+        OPENVINO_ASSERT(in_out_shape[1] == in_out_shape[2]);  // [num_heads, token_dim, token_dim]
+
+        for (size_t head_idx = 0; head_idx < in_out_shape[0]; head_idx++) {
+            size_t in_head_offset = head_idx * in_out_shape[1] * in_out_shape[2];
+            for (size_t token_dim_idx = 0; token_dim_idx < in_out_shape[1]; token_dim_idx++) {
+                size_t diagonal_element_offset = token_dim_idx + token_dim_idx * in_out_shape[1];
+                auto diagonal_element_ptr = in_out + in_head_offset + diagonal_element_offset;
+                *diagonal_element_ptr = val;
+            }
+        }
+    }
+
+    /** For a rank-3 tensor, zeroes out the values that are less than the mean of the values of the corresponding
+     * slice at rank 2 (zero-based). Ranks 1 and 2 of the input tensor must be equal. Mean values are computed and
+     * provided externally. The operation is done in-place.
+     * @param in_out Pointer to the tensor data.
+     * @param in_out_shape Shape of the tensor data. Expected shape is [num_heads, token_dim, token_dim].
+     * @param means Pointer to the tensor data containing the means of each slice of the `in_out` tensor along its
+     * rank 2 (zero-based).
+     * @param means_shape Shape of the means tensor. Expected shape is [num_heads, token_dim].
+     */
+    void fill_low_values_with_zeros_(T* in_out, const Shape& in_out_shape, const T* means, const Shape& means_shape) {
+        OPENVINO_ASSERT(in_out_shape.size() == 3);  // [num_heads, token_dim, token_dim]
+        OPENVINO_ASSERT(in_out_shape[1] == in_out_shape[2]);
+        OPENVINO_ASSERT(means_shape.size() == 2);  // [num_heads, token_dim]
+        OPENVINO_ASSERT(means_shape[0] == in_out_shape[0]);
+        OPENVINO_ASSERT(means_shape[1] == in_out_shape[1]);
+
+        for (size_t head_idx = 0; head_idx < in_out_shape[0]; head_idx++) {
+            size_t in_head_offset = head_idx * in_out_shape[1] * in_out_shape[2];
+            size_t means_head_offset = head_idx * means_shape[1];
+            for (size_t token_dim_idx = 0; token_dim_idx < in_out_shape[1]; token_dim_idx++) {
+                T mean_val = means[means_head_offset + token_dim_idx];
+                size_t token_offset = token_dim_idx * in_out_shape[2];
+                for (size_t reduced_dim_idx = 0; reduced_dim_idx < in_out_shape[2]; reduced_dim_idx++) {
+                    size_t target_offset = in_head_offset + token_offset + reduced_dim_idx;
+                    T filled_val = in_out[target_offset];
+                    in_out[target_offset] = filled_val >= mean_val ? filled_val : 0.0;
+                }
+            }
+        }
+    }
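+
+    // Worked example (values taken from the corresponding unit test): for the row {4.534, -5.908, -9.388, 2.356}
+    // the externally supplied row mean is -2.1015, so only the values below the mean are zeroed and the row becomes
+    // {4.534, 0.0, 0.0, 2.356}.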
+
+    /** For a square matrix, sums each `block_size`-sized group of matrix rows to produce a row in the output matrix.
+     * In the overall algorithm context, each summed value represents diversity (the negative of inter-token cosine
+     * similarity), where larger (less negative) values indicate greater diversity.
+     * @param in_data Pointer to the matrix data.
+     * @param in_shape Shape of the matrix data. Expected shape is [token_dim, token_dim], where token_dim must be a
+     * multiple of `block_size`.
+     * @param out Pointer to the output matrix data.
+     * @param out_shape Shape of the output matrix. Expected shape is [token_dim / block_size, token_dim].
+     */
+    void block_sum_diversity_values(const T* in_data, const Shape& in_shape, T* out, const Shape& out_shape) {
+        OPENVINO_ASSERT(in_shape.size() == 2);  // [token_dim, token_dim]
+        OPENVINO_ASSERT(in_shape[0] == in_shape[1]);
+        OPENVINO_ASSERT(in_shape[0] % m_block_size == 0);
+
+        OPENVINO_ASSERT(out_shape.size() == 2);  // [block_dim, token_dim]
+        OPENVINO_ASSERT(out_shape[0] == in_shape[0] / m_block_size);
+        OPENVINO_ASSERT(out_shape[1] == in_shape[1]);
+
+        std::memset(out, 0, out_shape[0] * out_shape[1] * sizeof(T));
+
+        for (size_t out_block_dim_idx = 0; out_block_dim_idx < out_shape[0]; out_block_dim_idx++) {
+            size_t out_block_offset = out_block_dim_idx * out_shape[1];
+            for (size_t out_token_dim_idx = 0; out_token_dim_idx < out_shape[1]; out_token_dim_idx++) {
+                size_t in_block_offset = (out_block_dim_idx * m_block_size) * out_shape[1];
+                for (size_t in_token_in_block_idx = 0; in_token_in_block_idx < m_block_size; in_token_in_block_idx++) {
+                    size_t source_offset = in_block_offset + in_token_in_block_idx * in_shape[1] + out_token_dim_idx;
+                    out[out_block_offset + out_token_dim_idx] -= in_data[source_offset];
+                }
+            }
+        }
+    }
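+
+    // Worked example (assuming block_size = 2): output row 0, column j is -(in[0][j] + in[1][j]). For the similarity
+    // rows {0, 1, 1, 1} and {1, 0, 1, 1} (identical tokens with the diagonal zeroed out) this produces
+    // {-1, -1, -2, -2}, matching the "larger basic" end-to-end unit test.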
+
+    /** Calculates token diversity in the eviction area, partially aggregating the results per-block. The resulting
+     * diversity values have the shape of [num_eviction_blocks (== eviction_size / block_size), eviction_size]. Note
+     * that rank 1 (zero-based) is left unaggregated when compared to the full diversity calculation algorithm, for
+     * the following reason. The final per-block diversity value computation relies on knowing the subset of blocks
+     * in the eviction area that will be retained regardless of calculated diversity. This subset must be filtered
+     * out from the rank-1 dimension when performing reduce-mean in the original algorithm to get 1 diversity value
+     * per block in the eviction area. Due to implementation specifics, the paged attention kernel does not know
+     * ahead of time which blocks will be "retained" - this information is only available on the openvino.genai
+     * level after the PA kernel has executed. Therefore the PA kernel will provide raw per-token values on rank 1 of
+     * the returned diversity value matrix and delegate the final reduce-mean and filtering to the openvino.genai
+     * level.
+     * @param key_data Pointer to the key cache tensor data.
+     * @param key_shape Shape of the key input tensor data. Expected shape is [num_heads, num_key_tokens, head_size],
+     * where `num_key_tokens` must be no less than `start_size + eviction_size`.
+     * @return A rank-2 matrix in the std::vector representation with dimensions [eviction_size / block_size,
+     * eviction_size] containing the diversity values. The values are expected to be further mean-reduced along rank
+     * 1 (zero-based) at the point in time when the subset of blocks to be exclusively retained is known.
+     */
+    std::vector<std::vector<T>> calculate_block_diversity(const T* key_data, const Shape& key_shape) {
+        OPENVINO_ASSERT(key_shape.size() == 3);  // [num_heads, key_token_len, head_dim]
+        OPENVINO_ASSERT(key_shape[1] >= m_start_size + m_eviction_size);
+
+        auto normalized_key_data_buf = allocate_buf(key_shape);
+        // Should be safe to use this in-place
+        ov::reference::normalize_l2(key_data,
+                                    normalized_key_data_buf.get(),
+                                    key_shape,
+                                    {2},
+                                    std::numeric_limits<T>::epsilon(),
+                                    ov::op::EpsMode::ADD);
+
+        Shape cos_similar_shape = {key_shape[0], key_shape[1], key_shape[1]};
+        auto cos_similar_buf = allocate_buf(cos_similar_shape);
+        ov::reference::matmul(normalized_key_data_buf.get(),
+                              normalized_key_data_buf.get(),
+                              cos_similar_buf.get(),
+                              key_shape,
+                              key_shape,
+                              cos_similar_shape,
+                              /* transpose_arg0 = */ false,
+                              /* transpose_arg1 = */ true);
+        normalized_key_data_buf.reset();
+
+        Shape evictable_subset_shape = {key_shape[0], m_eviction_size, m_eviction_size};
+        auto evictable_subset_buf = allocate_buf(evictable_subset_shape);
+        ov::reference::slice(reinterpret_cast<const char*>(cos_similar_buf.get()),
+                             cos_similar_shape,
+                             reinterpret_cast<char*>(evictable_subset_buf.get()),
+                             evictable_subset_shape,
+                             sizeof(T),
+                             /* starts = */ {static_cast<int64_t>(m_start_size), static_cast<int64_t>(m_start_size)},
+                             /* steps = */ {1, 1},
+                             /* axes = */ {1, 2});  // stops are defined by output shape
+        cos_similar_buf.reset();
+
+        fill_diagonal_(evictable_subset_buf.get(), evictable_subset_shape, 0.0);
+
+        Shape means_shape = {key_shape[0], m_eviction_size};
+        auto means_buf = allocate_buf(means_shape);
+        ov::reference::reduce_mean(evictable_subset_buf.get(), means_buf.get(), evictable_subset_shape, {2});
+
+        fill_low_values_with_zeros_(evictable_subset_buf.get(), evictable_subset_shape, means_buf.get(), means_shape);
+        means_buf.reset();
+
+        Shape aggregated_token_similarities_shape = {m_eviction_size, m_eviction_size};
+        auto aggregated_token_similarities_buf = allocate_buf(aggregated_token_similarities_shape);
+        ov::reference::reduce_mean(evictable_subset_buf.get(),
+                                   aggregated_token_similarities_buf.get(),
+                                   evictable_subset_shape,
+                                   {0});
+        evictable_subset_buf.reset();
+
+        Shape block_diversity_shape = {m_eviction_size / m_block_size, m_eviction_size};
+        auto block_diversity_buf = allocate_buf(block_diversity_shape);
+        block_sum_diversity_values(aggregated_token_similarities_buf.get(),
+                                   aggregated_token_similarities_shape,
+                                   block_diversity_buf.get(),
+                                   block_diversity_shape);
+        std::vector<std::vector<T>> retval(block_diversity_shape[0], std::vector<T>(block_diversity_shape[1]));
+        for (size_t block_idx = 0; block_idx < block_diversity_shape[0]; block_idx++) {
+            for (size_t token_idx = 0; token_idx < block_diversity_shape[1]; token_idx++) {
+                retval[block_idx][token_idx] = block_diversity_buf[block_idx * block_diversity_shape[1] + token_idx];
+            }
+        }
+
+        return retval;
+    }
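+
+    // Illustrative (non-normative) sketch of the final aggregation step that is delegated to the openvino.genai
+    // level once the set of unconditionally retained tokens in the eviction area is known. The names
+    // `block_diversity` (the matrix returned above) and `retained_token_mask` are hypothetical; columns flagged in
+    // the mask are skipped and each row is mean-reduced to a single diversity value per block:
+    //
+    //     std::vector<T> per_block_diversity(block_diversity.size(), T{0});
+    //     for (size_t block_idx = 0; block_idx < block_diversity.size(); block_idx++) {
+    //         size_t kept_columns = 0;
+    //         for (size_t token_idx = 0; token_idx < block_diversity[block_idx].size(); token_idx++) {
+    //             if (!retained_token_mask[token_idx]) {
+    //                 per_block_diversity[block_idx] += block_diversity[block_idx][token_idx];
+    //                 kept_columns++;
+    //             }
+    //         }
+    //         if (kept_columns != 0) {
+    //             per_block_diversity[block_idx] /= static_cast<T>(kept_columns);
+    //         }
+    //     }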
+
+    /**
+     * @param shape Shape of a tensor
+     * @return A shared_ptr owning a buffer that can be used to store tensor data for the given shape.
+     * */
+    std::shared_ptr<T[]> allocate_buf(const Shape& shape) {
+        return std::shared_ptr<T[]>(new T[ov::shape_size(shape)]);
+    }
+
+    size_t m_start_size;
+    size_t m_eviction_size;
+    size_t m_block_size;
+};
+
+}  // namespace ov::reference
diff --git a/src/core/tests/reference/adaptive_rkv_diversity.cpp b/src/core/tests/reference/adaptive_rkv_diversity.cpp
new file mode 100644
index 00000000000000..d3f71cdd3758c9
--- /dev/null
+++ b/src/core/tests/reference/adaptive_rkv_diversity.cpp
@@ -0,0 +1,438 @@
+// Copyright (C) 2018-2025 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+#include <openvino/reference/adaptive_rkv_diversity.hpp>
+
+namespace adaptive_rkv_test {
+size_t DEFAULT_BLOCK_SIZE = 2;
+size_t DEFAULT_START_SIZE = 2;
+size_t DEFAULT_EVICTION_SIZE = 10;
+
+TEST(AdaptiveRKVE2ESmokeTest, CalculatesDiversityWithoutThrowing) {
+    ov::reference::AdaptiveRKVDiversityCalculator<double> calculator(DEFAULT_START_SIZE,
+                                                                     DEFAULT_EVICTION_SIZE,
+                                                                     DEFAULT_BLOCK_SIZE);
+
+    ov::Shape mock_shape{2, (DEFAULT_START_SIZE + DEFAULT_EVICTION_SIZE) * 2, 8};
+    std::vector<double> mock_data(ov::shape_size(mock_shape), 1.0);
+
+    EXPECT_NO_THROW(calculator.calculate_block_diversity(mock_data.data(), mock_shape));
+};
+
+struct FillDiagonalTestData {
+    ov::Shape in_shape;
+    std::vector<double> in_data;
+    std::vector<double> ref_out_data;
+};
+
+using AdaptiveRKVDiversityFillDiagonalTest = ::testing::TestWithParam<FillDiagonalTestData>;
+
+std::vector<FillDiagonalTestData> FILL_DIAGONAL_TEST_CASES = {{
+    {2, 4, 4},
+    // clang-format off
+    {
+        3.144, 8.512, 8.518, -8.386,
+        7.889, -5.721, 5.507, 4.295,
+        -6.624, -8.463, 7.474, 9.879,
+        4.534, -5.908, -9.388, 2.356,
+
+        7.497, 8.186, -8.658, -4.796,
+        -8.248, -9.797, -7.907, -4.513,
+        3.469, 7.633, 7.244, -6.844,
+        -7.173, 4.450, 6.705, -7.035
+    },
+    // clang-format on
+
+    // clang-format off
+    {
+        42.00, 8.512, 8.518, -8.386,
+        7.889, 42.00, 5.507, 4.295,
+        -6.624, -8.463, 42.00, 9.879,
+        4.534, -5.908, -9.388, 42.00,
+
+        42.00, 8.186, -8.658, -4.796,
+        -8.248, 42.00, -7.907, -4.513,
+        3.469, 7.633, 42.00, -6.844,
+        -7.173, 4.450, 6.705, 42.00
+    },
+    // clang-format on
+}};
+
+TEST_P(AdaptiveRKVDiversityFillDiagonalTest, FillsDiagonal) {
+    auto test_struct = GetParam();
+    ASSERT_EQ(test_struct.in_data.size(), ov::shape_size(test_struct.in_shape));
+    ASSERT_EQ(test_struct.ref_out_data.size(), ov::shape_size(test_struct.in_shape));
+
+    ov::reference::AdaptiveRKVDiversityCalculator<double> calculator(DEFAULT_START_SIZE,
+                                                                     DEFAULT_EVICTION_SIZE,
+                                                                     DEFAULT_BLOCK_SIZE);
+
+    std::vector<double> test_out_data = test_struct.in_data;
+    calculator.fill_diagonal_(test_out_data.data(), test_struct.in_shape, 42.0);
+    EXPECT_EQ(test_out_data, test_struct.ref_out_data);
+}
+
+INSTANTIATE_TEST_SUITE_P(VariousInputs,
+                         AdaptiveRKVDiversityFillDiagonalTest,
+                         ::testing::ValuesIn(FILL_DIAGONAL_TEST_CASES));
+
+struct FillLowValuesWithZerosTestData {
+    ov::Shape in_shape;
+    std::vector<double> in_data;
+    ov::Shape means_shape;
+    std::vector<double> means;
+    std::vector<double> ref_out_data;
+};
+
+using AdaptiveRKVFillLowValuesWithZerosTest = ::testing::TestWithParam<FillLowValuesWithZerosTestData>;
+
+std::vector<FillLowValuesWithZerosTestData> FILL_LOW_VALUES_WITH_ZEROS_TEST_CASES = {
+    {
+        {2, 4, 4},
+        // clang-format off
+        {
+            4.534, -5.908, -9.388, 2.356,
+            -6.624, -8.463, 7.474, 9.879,
+            7.889, -5.721, 5.507, 4.295,
+            3.144, 8.512, 8.518, -8.386,
+
+            -7.173, 4.450, 6.705, -7.035,
+            3.469, 7.633, 7.244, -6.844,
+            -8.248, -9.797, -7.907, -4.513,
+            7.497, 8.186, -8.658, -4.796,
+        },
+        // clang-format on
+
+        {2, 4},
+
+        // clang-format off
+        {
+            -2.1015,
+            0.5665,
+            2.9925,
+            2.947,
+
+            -0.76325,
+            2.8755,
+            -7.61625,
+            0.55725
+        },
+        // clang-format on
+
+        // clang-format off
+        {
+            4.534, 0.000, 0.000, 2.356,
+            0.000, 0.000, 7.474, 9.879,
+            7.889, 0.000, 5.507, 4.295,
+            3.144, 8.512, 8.518, 0.000,
+
+            0.000, 4.450, 6.705, 0.000,
+            3.469, 7.633, 7.244, 0.000,
+            0.000, 0.000, 0.000, -4.513,
+            7.497, 8.186, 0.000, 0.000,
+        },
+        // clang-format on
+    },
+};
+
+TEST_P(AdaptiveRKVFillLowValuesWithZerosTest, FillsLowValuesWithZero) {
+    auto test_struct = GetParam();
+    ASSERT_EQ(test_struct.in_data.size(), ov::shape_size(test_struct.in_shape));
+    ASSERT_EQ(test_struct.means.size(), ov::shape_size(test_struct.means_shape));
+    ASSERT_EQ(test_struct.ref_out_data.size(), ov::shape_size(test_struct.in_shape));
+
+    ov::reference::AdaptiveRKVDiversityCalculator<double> calculator(DEFAULT_START_SIZE,
+                                                                     DEFAULT_EVICTION_SIZE,
+                                                                     DEFAULT_BLOCK_SIZE);
+    std::vector<double> test_out_data = test_struct.in_data;
+    calculator.fill_low_values_with_zeros_(test_out_data.data(),
+                                           test_struct.in_shape,
+                                           test_struct.means.data(),
+                                           test_struct.means_shape);
+
+    EXPECT_THAT(test_out_data, ::testing::Pointwise(::testing::DoubleNear(1e-8), test_struct.ref_out_data));
+}
+
+INSTANTIATE_TEST_SUITE_P(VariousInputs,
+                         AdaptiveRKVFillLowValuesWithZerosTest,
+                         ::testing::ValuesIn(FILL_LOW_VALUES_WITH_ZEROS_TEST_CASES));
+
+struct BlockSumTestData {
+    ov::Shape in_shape;
+    std::vector<double> in_data;
+    size_t block_size;
+    ov::Shape out_shape;
+    std::vector<double> ref_out_data;
+};
+
+using AdaptiveRKVBlockSumTest = ::testing::TestWithParam<BlockSumTestData>;
+
+std::vector<BlockSumTestData> BLOCK_SUM_TEST_CASES = {
+    {
+        {8, 8},
+        // clang-format off
+        {
+            0.1117, 0.0780, 0.1347, 0.0885, 0.1942, 0.0922, 0.1184, 0.1824,
+            0.1488, 0.1766, 0.0852, 0.1239, 0.0930, 0.1220, 0.1367, 0.1138,
+            0.1410, 0.0861, 0.0774, 0.1325, 0.1478, 0.1689, 0.0885, 0.1579,
+            0.1248, 0.1038, 0.1842, 0.0935, 0.1813, 0.0890, 0.0897, 0.1336,
+            0.0905, 0.1049, 0.1263, 0.0953, 0.1018, 0.1297, 0.1659, 0.1855,
+            0.1373, 0.1791, 0.1005, 0.1286, 0.1492, 0.1373, 0.0820, 0.0860,
+            0.0997, 0.1285, 0.0786, 0.1366, 0.1963, 0.0904, 0.1488, 0.1211,
+            0.1859, 0.1174, 0.1364, 0.0930, 0.1028, 0.1034, 0.1699, 0.0912
+        },
+        // clang-format on
+
+        /* block_size = */ 2,
+
+        {4, 8},
+
+        // clang-format off
+        {
+            -0.2605, -0.2546, -0.2199, -0.2124, -0.2872, -0.2142, -0.2551, -0.2962,
+            -0.2658, -0.1899, -0.2616, -0.226, -0.3291, -0.2579, -0.1782, -0.2915,
+            -0.2278, -0.284, -0.2268, -0.2239, -0.251, -0.267, -0.2479, -0.2715,
+            -0.2856, -0.2459, -0.215, -0.2296, -0.2991, -0.1938, -0.3187, -0.2123
+        },
+        // clang-format on
+    },
+};
+
+TEST_P(AdaptiveRKVBlockSumTest, BlockSumIsCorrect) {
+    auto test_struct = GetParam();
+    ASSERT_EQ(test_struct.in_data.size(), ov::shape_size(test_struct.in_shape));
+    ASSERT_EQ(test_struct.ref_out_data.size(), ov::shape_size(test_struct.out_shape));
+
+    ov::reference::AdaptiveRKVDiversityCalculator<double> calculator(DEFAULT_START_SIZE,
+                                                                     DEFAULT_EVICTION_SIZE,
+                                                                     test_struct.block_size);
+    std::vector<double> test_out_data(test_struct.ref_out_data.size());
+    calculator.block_sum_diversity_values(test_struct.in_data.data(),
+                                          test_struct.in_shape,
+                                          test_out_data.data(),
+                                          test_struct.out_shape);
+
+    EXPECT_THAT(test_out_data, ::testing::Pointwise(::testing::DoubleNear(1e-5), test_struct.ref_out_data));
+}
+
+INSTANTIATE_TEST_SUITE_P(VariousInputs,
+                         AdaptiveRKVBlockSumTest,
+                         ::testing::ValuesIn(BLOCK_SUM_TEST_CASES));
+
+struct DiversityCalculateTestData {
+    ov::Shape in_shape;
+    std::vector<double> in_data;
+    double threshold;
+};
+
+struct E2EDiversityTestData {
+    ov::Shape k_shape;
+    std::vector<double> k_data;
+    size_t start_size;
+    size_t eviction_size;
+    std::vector<std::vector<double>> ref_diversity_data;
+};
+
+using AdaptiveRKVE2EDiversityTest = ::testing::TestWithParam<E2EDiversityTestData>;
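+
+// Note (illustrative, not exhaustive): the simpler reference values below can be checked by hand. Identical
+// evictable tokens have pairwise cosine similarity 1, so with block_size = 2 and the diagonal zeroed out each block
+// row ends up at -1 for its own token columns and -2 elsewhere; mutually orthogonal evictable tokens have
+// similarity 0 and therefore a diversity value of 0 everywhere.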
+
+std::vector<E2EDiversityTestData> E2E_DIVERSITY_TEST_CASES = {
+    // basic
+    {
+        {1, 4, 1},
+        // clang-format off
+        {
+            1.0,
+            1.0,
+            1.0,
+            1.0
+        },
+        // clang-format on
+        /* start_size = */ 2,
+        /* eviction_size = */ 2,
+        {{-1.0, -1.0}}
+    },
+    // larger basic
+    {
+        {1, 6, 1},
+        // clang-format off
+        {
+            6.5,
+            -11.0,
+            1.0,
+            1.0,
+            1.0,
+            1.0,
+        },
+        // clang-format on
+        /* start_size = */ 2,
+        /* eviction_size = */ 4,
+        {{-1.0, -1.0, -2.0, -2.0},
+         {-2.0, -2.0, -1.0, -1.0}}
+    },
+    // two heads basic
+    {
+        {2, 8, 1},
+        // clang-format off
+        {
+            6.5,
+            -11.0,
+            1.0,
+            1.0,
+            1.0,
+            1.0,
+            42.0,
+            -13.7,
+
+            1337.0,
+            -1256.9,
+            -1.0,
+            -1.0,
+            -1.0,
+            -1.0,
+            0.2,
+            0.0
+        },
+        // clang-format on
+        /* start_size = */ 2,
+        /* eviction_size = */ 4,
+        {{-1.0, -1.0, -2.0, -2.0},
+         {-2.0, -2.0, -1.0, -1.0}}
+    },
+    // zeroed second head (where it matters)
+    {
+        {2, 8, 1},
+        // clang-format off
+        {
+            6.5,
+            -11.0,
+            1.0,
+            1.0,
+            1.0,
+            1.0,
+            42.0,
+            -13.7,
+
+            1337.0,
+            -1256.9,
+            0.0,
+            0.0,
+            0.0,
+            0.0,
+            0.2,
+            0.0
+        },
+        // clang-format on
+        /* start_size = */ 2,
+        /* eviction_size = */ 4,
+        {{-0.5, -0.5, -1.0, -1.0},
+         {-1.0, -1.0, -0.5, -0.5}}
+    },
+    // more embedding dimensions
+    {
+        {2, 8, 4},
+        // clang-format off
+        {
+            6.5, 8.3, 5.1, -7.4,
+            -11.0, 1.9, 7.1, 4.8,
+            8.0, 8.0, 8.0, 8.0,
+            8.0, 8.0, 8.0, 8.0,
+            8.0, 8.0, 8.0, 8.0,
+            8.0, 8.0, 8.0, 8.0,
+            42.0, -41.7, 8.3, 1.0,
+            -13.7, 0.0, 0.0, 15.1,
+
+            1337.0, -1.9, -1.4, 475.1,
+            -1256.9, 1.0, 789.0, 1421.3,
+            -2.0, -2.0, -2.0, -2.0,
+            -2.0, -2.0, -2.0, -2.0,
+            -2.0, -2.0, -2.0, -2.0,
+            -2.0, -2.0, -2.0, -2.0,
+            0.2, -81.3, 74.3, -641.1,
+            0.0, 14.7, 98.1, -27.7
+        },
+        // clang-format on
+        /* start_size = */ 2,
+        /* eviction_size = */ 4,
+        {{-1.0, -1.0, -2.0, -2.0},
+         {-2.0, -2.0, -1.0, -1.0}}
+    },
+    // orthogonal tokens
+    {
+        {2, 8, 4},
+        // clang-format off
+        {
+            6.5, 8.3, 5.1, -7.4,
+            -11.0, 1.9, 7.1, 4.8,
+            8.0, 0.0, 0.0, 0.0,
+            0.0, 0.0, -18.0, 0.0,
+            0.0, 0.0, 0.0, 0.1,
+            0.0, 1288.0, 0.0, 0.0,
+            42.0, -41.7, 8.3, 1.0,
+            -13.7, 0.0, 0.0, 15.1,
+
+            1337.0, -1.9, -1.4, 475.1,
+            -1256.9, 1.0, 789.0, 1421.3,
+            0.0, 0.0, 2.0, 0.0,
+            0.0, -12.0, 0.0, 0.0,
+            12.8, 0.0, 0.0, 0.0,
+            0.0, 0.0, 0.0, 65.5,
+            0.2, -81.3, 74.3, -641.1,
+            0.0, 14.7, 98.1, -27.7
+        },
+        // clang-format on
+        /* start_size = */ 2,
+        /* eviction_size = */ 4,
+        {{0.0, 0.0, 0.0, 0.0},
+         {0.0, 0.0, 0.0, 0.0}}
+    },
+    // random excel-checked golden
+    {
+        {2, 10, 4},
+        // clang-format off
+        {
+            4.949, -7.294, -6.330, 3.757,
+            -3.561, 1.029, 5.030, -9.483,
+            5.350, -2.745, -1.404, -7.788,
+            -1.086, 4.576, -8.726, -8.815,
+            3.144, 8.512, 8.518, -8.386,
+            7.889, -5.721, 5.507, 4.295,
+            -6.624, -8.463, 7.474, 9.879,
+            4.534, -5.908, -9.388, 2.356,
+            7.497, 8.186, -8.658, -4.796,
+            -8.248, -9.797, -7.907, -4.513,
+
+            3.469, 7.633, 7.244, -6.844,
+            -7.173, 4.450, 6.705, -7.035,
+            8.773, -7.571, -9.878, -9.584,
+            0.807, 8.059, -7.172, 4.303,
+            -3.323, -8.852, 1.167, -1.126,
+            -4.428, 9.678, -6.547, 0.037,
+            -8.152, -9.865, 3.694, -7.650,
+            0.359, 8.018, -7.152, -6.242,
+            -9.120, -7.228, -9.186, 3.202,
+            -9.304, -0.401, -5.287, 6.834
+        },
+        // clang-format on
+
+        /* start_size = */ 2,
+        /* eviction_size = */ 6,
+        {{-0.237145, -0.237145, -0.352696, -0.487902, -0.072365, -0.707192},
+         {-0.334657, -0.505941, 0, 0.036135, -0.634881, -0.490221},
+         {-0.380811, -0.398746801, -0.432080003, -0.693021748, 0, 0.067216441}},
+    }};
+
+TEST_P(AdaptiveRKVE2EDiversityTest, CalculatesDiversityCorrectly) {
+    auto test_struct = GetParam();
+    ov::reference::AdaptiveRKVDiversityCalculator<double> calculator(test_struct.start_size,
+                                                                     test_struct.eviction_size,
+                                                                     DEFAULT_BLOCK_SIZE);
+
+    auto test_diversity = calculator.calculate_block_diversity(test_struct.k_data.data(), test_struct.k_shape);
+    ASSERT_EQ(test_diversity.size(), test_struct.ref_diversity_data.size());
+    for (size_t i = 0; i < test_diversity.size(); i++) {
+        ASSERT_EQ(test_diversity[i].size(), test_struct.ref_diversity_data[i].size());
+    }
+
+    for (size_t i = 0; i < test_diversity.size(); i++) {
+        EXPECT_THAT(test_diversity[i],
+                    ::testing::Pointwise(::testing::DoubleNear(1e-6), test_struct.ref_diversity_data[i]));
+    }
+};
+
+INSTANTIATE_TEST_SUITE_P(VariousInputs, AdaptiveRKVE2EDiversityTest, ::testing::ValuesIn(E2E_DIVERSITY_TEST_CASES));
+}  // namespace adaptive_rkv_test