Skip to content

Commit f1c12d5

Browse files
committed
Add basic tests
1 parent 05cf16b commit f1c12d5

File tree

2 files changed

+456
-8
lines changed

2 files changed

+456
-8
lines changed

src/core/reference/include/openvino/reference/adaptive_rkv_diversity.hpp

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include <memory>
1010
#include <queue>
1111

12+
#include "openvino/op/util/attr_types.hpp"
1213
#include "openvino/reference/matmul.hpp"
1314
#include "openvino/reference/normalize_l2.hpp"
1415
#include "openvino/reference/reduce_mean.hpp"
@@ -59,7 +60,7 @@ class AdaptiveRKVDiversityCalculator {
5960
* @param out_shape Shape of the output tensor data. Expected shape is [num_heads, num_query_tokens / block_size,
6061
* num_key_tokens / block_size].
6162
*/
62-
void fill_diagonal_(const T* in_out,
63+
void fill_diagonal_(T* in_out,
6364
const Shape& in_out_shape,
6465
T val) {
6566
OPENVINO_ASSERT(in_out_shape.size() == 3); // [num_heads, token_dim, token_dim]
@@ -76,7 +77,7 @@ class AdaptiveRKVDiversityCalculator {
7677
}
7778
}
7879

79-
void fill_low_values_with_zeros_(const T* in_out,
80+
void fill_low_values_with_zeros_(T* in_out,
8081
const Shape& in_out_shape,
8182
const T* means,
8283
const Shape& means_shape) {
@@ -121,7 +122,7 @@ class AdaptiveRKVDiversityCalculator {
121122
size_t in_block_offset = (out_block_dim_idx * m_block_size) * out_shape[1];
122123
for (size_t in_token_in_block_idx = 0; in_token_in_block_idx < m_block_size; in_token_in_block_idx++) {
123124
size_t source_offset = in_block_offset + in_token_in_block_idx * processed_similarity_token_data_shape[1] + out_token_dim_idx;
124-
out[out_block_offset + out_token_dim_idx] += processed_similarity_token_data[source_offset];
125+
out[out_block_offset + out_token_dim_idx] -= processed_similarity_token_data[source_offset];
125126
}
126127
}
127128
}
@@ -146,19 +147,22 @@ class AdaptiveRKVDiversityCalculator {
146147
std::vector<std::vector<T>> calculate_block_diversity(const T* key_data,
147148
const Shape& key_shape) {
148149
OPENVINO_ASSERT(key_shape.size() == 3); // [num_heads, key_token_len, head_dim]
149-
OPENVINO_ASSERT(key_shape[1] >= m_block_size * (m_start_size + m_eviction_size));
150+
OPENVINO_ASSERT(key_shape[1] >= m_start_size + m_eviction_size);
150151

152+
153+
auto normalized_key_data_buf = allocate_buf(key_shape);
151154
// Should be safe to use this in-place
152-
ov::reference::normalize_l2(key_data, key_data, key_shape, {2}, std::numeric_limits<T>::epsilon());
155+
ov::reference::normalize_l2(key_data, normalized_key_data_buf.get(), key_shape, {2}, std::numeric_limits<float>::epsilon(), ov::op::EpsMode::ADD);
153156

154157
Shape cos_similar_shape = {key_shape[0], key_shape[1], key_shape[1]};
155158
auto cos_similar_buf = allocate_buf(cos_similar_shape);
156-
ov::reference::matmul(key_data, key_data, cos_similar_buf.get(), key_shape, key_shape, cos_similar_shape, /* transpose_arg0 = */ false, /* transpose_arg1 = */ true);
159+
ov::reference::matmul(normalized_key_data_buf.get(), normalized_key_data_buf.get(), cos_similar_buf.get(), key_shape, key_shape, cos_similar_shape, /* transpose_arg0 = */ false, /* transpose_arg1 = */ true);
160+
normalized_key_data_buf.reset();
157161

158162
Shape evictable_subset_shape = {key_shape[0], m_eviction_size, m_eviction_size};
159163
auto evictable_subset_buf = allocate_buf(evictable_subset_shape);
160164
// stops?
161-
ov::reference::slice(cos_similar_buf.get(), cos_similar_shape, evictable_subset_buf.get(), evictable_subset_shape, sizeof(T), /* starts = */ {m_start_size, m_start_size}, /* steps = */ {1, 1}, /* axes = */{1, 2});
165+
ov::reference::slice(reinterpret_cast<char*>(cos_similar_buf.get()), cos_similar_shape, reinterpret_cast<char*>(evictable_subset_buf.get()), evictable_subset_shape, sizeof(T), /* starts = */ {m_start_size, m_start_size}, /* steps = */ {1, 1}, /* axes = */{1, 2});
162166
cos_similar_buf.reset();
163167

164168
fill_diagonal_(evictable_subset_buf.get(), evictable_subset_shape, 0.0);
@@ -168,6 +172,7 @@ class AdaptiveRKVDiversityCalculator {
168172
ov::reference::reduce_mean(evictable_subset_buf.get(), means_buf.get(), evictable_subset_shape, {2});
169173

170174
fill_low_values_with_zeros_(evictable_subset_buf.get(), evictable_subset_shape, means_buf.get(), means_shape);
175+
means_buf.reset();
171176

172177
Shape aggregated_token_similarities_shape = {m_eviction_size, m_eviction_size};
173178
auto aggregated_token_similarities_buf = allocate_buf(aggregated_token_similarities_shape);
@@ -180,7 +185,7 @@ class AdaptiveRKVDiversityCalculator {
180185
std::vector<std::vector<T>> retval(block_diversity_shape[0], std::vector<T>(block_diversity_shape[1]));
181186
for (size_t block_idx = 0; block_idx < block_diversity_shape[0]; block_idx++) {
182187
for (size_t token_idx = 0; token_idx < block_diversity_shape[1]; token_idx++) {
183-
retval[block_idx][token_idx] = block_diversity_buf.get() + block_idx * block_diversity_shape[1] + token_idx;
188+
retval[block_idx][token_idx] = block_diversity_buf[block_idx * block_diversity_shape[1] + token_idx];
184189
}
185190
}
186191

0 commit comments

Comments
 (0)