Skip to content

Commit 2d47114

Browse files
itikhono and praasz
authored
PagedAttention op, validate refactoring (#30446)
### Details: The validate method of the PagedAttention op was becoming hard to review and maintain, so it has been refactored. Unit tests were updated. ### Tickets: - *N/A* --------- Co-authored-by: Pawel Raasz <[email protected]>
1 parent 1412558 commit 2d47114

File tree

3 files changed

+170
-158
lines changed

3 files changed

+170
-158
lines changed

src/core/include/openvino/core/type/element_type.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,6 @@ class OPENVINO_API Type {
9797
std::string get_type_name() const;
9898
friend OPENVINO_API std::ostream& operator<<(std::ostream&, const Type&);
9999

100-
OPENVINO_DEPRECATED("This function is deprecated and will be removed in 2026.0.")
101100
static std::vector<const Type*> get_known_types();
102101

103102
/// \brief Checks whether this element type is merge-compatible with `t`.

src/core/src/op/paged_attention.cpp

Lines changed: 77 additions & 157 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,68 @@
66

77
#include "dimension_util.hpp"
88
#include "itt.hpp"
9+
#include "openvino/core/validation_util.hpp"
910
#include "openvino/op/op.hpp"
1011

12+
namespace {
13+
14+
// Validates input rank and type for a node input.
15+
// A dynamic rank/type is always considered valid.
16+
// An empty list {} means any rank/type is accepted.
17+
inline void input_check(const ov::Node* node,
18+
size_t idx,
19+
const std::string_view input_name,
20+
std::initializer_list<ov::Rank>&& allowed_ranks,
21+
const std::vector<ov::element::Type>& allowed_types) {
22+
using namespace ov;
23+
using namespace ov::util;
24+
using namespace ov::element;
25+
26+
const auto& rank = node->get_input_partial_shape(idx).rank();
27+
const auto& tp = node->get_input_element_type(idx);
28+
29+
auto rank_check = [&](const Rank& rank) {
30+
return rank.is_dynamic() || empty(allowed_ranks) || is_rank_compatible_any_of(rank.get_length(), allowed_ranks);
31+
};
32+
33+
auto type_check = [&](const Type& type) {
34+
auto it = std::find(allowed_types.begin(), allowed_types.end(), tp);
35+
return type.is_dynamic() || allowed_types.empty() || it != allowed_types.end();
36+
};
37+
38+
NODE_VALIDATION_CHECK(node,
39+
rank_check(rank),
40+
"Rank of `",
41+
input_name,
42+
"` input should be in [dynamic, ",
43+
join(allowed_ranks),
44+
"] list, but it is ",
45+
rank,
46+
".");
47+
48+
NODE_VALIDATION_CHECK(node,
49+
type_check(tp),
50+
"Element type of `",
51+
input_name,
52+
"` input should be in [dynamic, ",
53+
join(allowed_types),
54+
"] list, but it is ",
55+
tp,
56+
".");
57+
}
58+
59+
std::vector<ov::element::Type> get_real_types() {
60+
std::vector<ov::element::Type> real_types;
61+
for (const auto& type : ov::element::Type::get_known_types()) {
62+
if (type->is_real()) {
63+
real_types.push_back(*type);
64+
}
65+
}
66+
return real_types;
67+
}
68+
69+
} // namespace
70+
1171
namespace ov {
1272
namespace op {
1373

@@ -23,165 +83,25 @@ void PagedAttentionExtension::validate_and_infer_types() {
2383
"PagedAttensionExtension expects 13 or 16 inputs, but it has ",
2484
get_input_size());
2585

26-
NODE_VALIDATION_CHECK(
27-
this,
28-
get_input_partial_shape(0).rank().is_dynamic() || get_input_partial_shape(0).rank().get_length() == 2,
29-
"Rank of `query` input should be 2, but it is ",
30-
get_input_partial_shape(0).rank().get_length(),
31-
".");
32-
NODE_VALIDATION_CHECK(
33-
this,
34-
get_input_partial_shape(1).rank().is_dynamic() || get_input_partial_shape(1).rank().get_length() == 2,
35-
"Rank of `key` input should be 2, but it is ",
36-
get_input_partial_shape(1).rank().get_length(),
37-
".");
38-
NODE_VALIDATION_CHECK(
39-
this,
40-
get_input_partial_shape(2).rank().is_dynamic() || get_input_partial_shape(2).rank().get_length() == 2,
41-
"Rank of `value` input should be 2, but it is ",
42-
get_input_partial_shape(2).rank().get_length(),
43-
".");
44-
45-
NODE_VALIDATION_CHECK(
46-
this,
47-
get_input_partial_shape(3).rank().is_dynamic() || get_input_partial_shape(3).rank().get_length() >= 2,
48-
"Rank of `key_cache` input should be at least 2, but it is ",
49-
get_input_partial_shape(3).rank().get_length(),
50-
".");
51-
NODE_VALIDATION_CHECK(
52-
this,
53-
get_input_partial_shape(4).rank().is_dynamic() || get_input_partial_shape(4).rank().get_length() >= 2,
54-
"Rank of `value_cache` input should be at least 2, but it is ",
55-
get_input_partial_shape(4).rank().get_length(),
56-
".");
57-
58-
NODE_VALIDATION_CHECK(
59-
this,
60-
get_input_partial_shape(5).rank().is_dynamic() || get_input_partial_shape(5).rank().get_length() == 1,
61-
"Rank of `past_lens` input should be 1, but it is ",
62-
get_input_partial_shape(5).rank().get_length(),
63-
".");
64-
NODE_VALIDATION_CHECK(this,
65-
get_input_element_type(5).is_dynamic() || get_input_element_type(5) == element::i32,
66-
"Element type of `past_lens` input should be i32, but it is ",
67-
get_input_element_type(5),
68-
".");
69-
NODE_VALIDATION_CHECK(
70-
this,
71-
get_input_partial_shape(6).rank().is_dynamic() || get_input_partial_shape(6).rank().get_length() == 1,
72-
"Rank of `subsequence_begins` input should be 1, but it is ",
73-
get_input_partial_shape(6).rank().get_length(),
74-
".");
75-
NODE_VALIDATION_CHECK(this,
76-
get_input_element_type(6).is_dynamic() || get_input_element_type(6) == element::i32,
77-
"Element type of `subsequence_begins` input should be i32, but it is ",
78-
get_input_element_type(6),
79-
".");
80-
81-
NODE_VALIDATION_CHECK(
82-
this,
83-
get_input_partial_shape(7).rank().is_dynamic() || get_input_partial_shape(7).rank().get_length() == 1,
84-
"Rank of `block_indices` input should be 1, but it is ",
85-
get_input_partial_shape(7).rank().get_length(),
86-
".");
87-
NODE_VALIDATION_CHECK(this,
88-
get_input_element_type(7).is_dynamic() || get_input_element_type(7) == element::i32,
89-
"Element type of `block_indices` input should be i32, but it is ",
90-
get_input_element_type(7),
91-
".");
92-
NODE_VALIDATION_CHECK(
93-
this,
94-
get_input_partial_shape(8).rank().is_dynamic() || get_input_partial_shape(8).rank().get_length() == 1,
95-
"Rank of `block_indices_begins` input should be 1, but it is ",
96-
get_input_partial_shape(8).rank().get_length(),
97-
".");
98-
NODE_VALIDATION_CHECK(this,
99-
get_input_element_type(8).is_dynamic() || get_input_element_type(8) == element::i32,
100-
"Element type of `block_indices_begins` input should be i32, but it is ",
101-
get_input_element_type(8),
102-
".");
103-
104-
NODE_VALIDATION_CHECK(
105-
this,
106-
get_input_partial_shape(9).rank().is_dynamic() || get_input_partial_shape(9).rank().get_length() == 0,
107-
"Input `scale` should be a scalar but it has rank ",
108-
get_input_partial_shape(9).rank().get_length(),
109-
".");
110-
NODE_VALIDATION_CHECK(this,
111-
get_input_element_type(9).is_dynamic() || get_input_element_type(9).is_real(),
112-
"Element type of `scale` input should be a floating type, but it is ",
113-
get_input_element_type(9),
114-
".");
115-
NODE_VALIDATION_CHECK(
116-
this,
117-
get_input_partial_shape(10).rank().is_dynamic() || get_input_partial_shape(10).rank().get_length() == 0,
118-
"Input `sliding_window` should be a scalar but it has rank ",
119-
get_input_partial_shape(10).rank().get_length(),
120-
".");
121-
NODE_VALIDATION_CHECK(this,
122-
get_input_element_type(10).is_dynamic() || get_input_element_type(10) == element::i32,
123-
"Element type of `sliding_window` input should be i32, but it is ",
124-
get_input_element_type(10),
125-
".");
126-
127-
NODE_VALIDATION_CHECK(
128-
this,
129-
get_input_partial_shape(11).rank().is_dynamic() || get_input_partial_shape(11).rank().get_length() == 1,
130-
"Rank of `alibi_slopes` input should be 1, but it is ",
131-
get_input_partial_shape(11).rank().get_length(),
132-
".");
133-
NODE_VALIDATION_CHECK(this,
134-
get_input_element_type(11).is_dynamic() || get_input_element_type(11).is_real(),
135-
"Element type of `alibi_slopes` input should be a floating type, but it is ",
136-
get_input_element_type(11),
137-
".");
138-
NODE_VALIDATION_CHECK(
139-
this,
140-
get_input_partial_shape(12).rank().is_dynamic() || get_input_partial_shape(12).rank().get_length() == 0,
141-
"Input `max_context_len` should be a scalar but it has rank ",
142-
get_input_partial_shape(12).rank().get_length(),
143-
".");
144-
NODE_VALIDATION_CHECK(this,
145-
get_input_element_type(12).is_dynamic() || get_input_element_type(12) == element::i32,
146-
"Element type of `max_context_len` input should be i32, but it is ",
147-
get_input_element_type(12),
148-
".");
86+
// format: Node*, input_idx, name, {rank_list}, {type_list}
87+
input_check(this, 0, "query", {2}, {});
88+
input_check(this, 1, "key", {2}, {});
89+
input_check(this, 2, "value", {2}, {});
90+
input_check(this, 3, "key_cache", {2, 3, 4, 5}, {});
91+
input_check(this, 4, "value_cache", {2, 3, 4, 5}, {});
92+
input_check(this, 5, "past_lens", {1}, {element::i32});
93+
input_check(this, 6, "subsequence_begins", {1}, {element::i32});
94+
input_check(this, 7, "block_indices", {1}, {element::i32});
95+
input_check(this, 8, "block_indices_begins", {1}, {element::i32});
96+
input_check(this, 9, "scale", {0}, get_real_types());
97+
input_check(this, 10, "sliding_window", {0}, {element::i32});
98+
input_check(this, 11, "alibi_slopes", {1}, get_real_types());
99+
input_check(this, 12, "max_context_len", {0}, {element::i32});
149100

150101
if (get_input_size() == 16) {
151-
NODE_VALIDATION_CHECK(
152-
this,
153-
get_input_partial_shape(13).rank().is_dynamic() || get_input_partial_shape(13).rank().get_length() == 1,
154-
"Input `rotated_block_indices` should either have rank 1 or be omitted, but it has rank ",
155-
get_input_partial_shape(13).rank().get_length(),
156-
".");
157-
NODE_VALIDATION_CHECK(this,
158-
get_input_element_type(13).is_dynamic() || get_input_element_type(13) == element::i32,
159-
"Element type of `rotated_block_indices` input should be i32, but it is ",
160-
get_input_element_type(13),
161-
".");
162-
NODE_VALIDATION_CHECK(
163-
this,
164-
get_input_partial_shape(14).rank().is_dynamic() || get_input_partial_shape(14).rank().get_length() == 2,
165-
"Input `rotation_deltas` should either have rank 2 or be omitted, but it has rank ",
166-
get_input_partial_shape(14).rank().get_length(),
167-
".");
168-
NODE_VALIDATION_CHECK(this,
169-
get_input_element_type(14).is_dynamic() || get_input_element_type(14) == element::i32,
170-
"Element type of `rotation_deltas` input should be i32, but it is ",
171-
get_input_element_type(14),
172-
".");
173-
NODE_VALIDATION_CHECK(
174-
this,
175-
get_input_partial_shape(15).rank().is_dynamic() || get_input_partial_shape(15).rank().get_length() == 2,
176-
"Input `rotation_trig_lut` should either have rank 2 or be omitted, but it has rank ",
177-
get_input_partial_shape(15).rank().get_length(),
178-
".");
179-
NODE_VALIDATION_CHECK(this,
180-
get_input_element_type(15).is_dynamic() || get_input_element_type(15) == element::f32 ||
181-
get_input_element_type(15) == element::f16,
182-
"Element type of `rotation_trig_lut` input should be f32 or f16, but it is ",
183-
get_input_element_type(15),
184-
".");
102+
input_check(this, 13, "rotated_block_indices", {1}, {element::i32});
103+
input_check(this, 14, "rotation_deltas", {2}, {element::i32});
104+
input_check(this, 15, "rotation_trig_lut", {2}, {element::f16, element::f32});
185105
}
186106

187107
// value head_size may differ from key head_size

src/core/tests/type_prop/paged_attention.cpp

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,5 +126,98 @@ TEST(type_prop, paged_attention_static_16_inputs_eviction_per_token) {
126126
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{3, 4}));
127127
}
128128

129+
TEST(type_prop, paged_attention_dynamic_ranks_and_types) {
130+
using namespace ov::op;
131+
const auto dyn = PartialShape::dynamic();
132+
133+
const auto query = std::make_shared<v0::Parameter>(element::dynamic, dyn);
134+
const auto key = std::make_shared<v0::Parameter>(element::dynamic, dyn);
135+
const auto value = std::make_shared<v0::Parameter>(element::dynamic, dyn);
136+
const auto key_cache = std::make_shared<v0::Parameter>(element::dynamic, dyn);
137+
const auto value_cache = std::make_shared<v0::Parameter>(element::dynamic, dyn);
138+
const auto past_lens = std::make_shared<v0::Parameter>(element::dynamic, dyn);
139+
const auto subsequence_begins = std::make_shared<v0::Parameter>(element::dynamic, dyn);
140+
const auto block_indices = std::make_shared<v0::Parameter>(element::dynamic, dyn);
141+
const auto block_indices_begins = std::make_shared<v0::Parameter>(element::dynamic, dyn);
142+
const auto scale = std::make_shared<v0::Parameter>(element::dynamic, dyn);
143+
const auto sliding_window = std::make_shared<v0::Parameter>(element::dynamic, dyn);
144+
const auto alibi_slopes = std::make_shared<v0::Parameter>(element::dynamic, dyn);
145+
const auto max_context_len = std::make_shared<v0::Parameter>(element::dynamic, dyn);
146+
147+
ov::OutputVector args = {query,
148+
key,
149+
value,
150+
key_cache,
151+
value_cache,
152+
past_lens,
153+
subsequence_begins,
154+
block_indices,
155+
block_indices_begins,
156+
scale,
157+
sliding_window,
158+
alibi_slopes,
159+
max_context_len};
160+
161+
EXPECT_NO_THROW(std::ignore = std::make_shared<op::PagedAttentionExtension>(args));
162+
}
163+
164+
TEST(type_prop, paged_attention_invalid_rank_query) {
165+
auto query = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{3});
166+
auto key = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{3, 4});
167+
auto value = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{3, 4});
168+
auto dummy = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{3, 4});
169+
auto scalar = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{});
170+
ov::OutputVector args =
171+
{query, key, value, dummy, dummy, scalar, scalar, scalar, scalar, dummy, scalar, dummy, scalar};
172+
173+
EXPECT_THROW(std::ignore = std::make_shared<op::PagedAttentionExtension>(args), ov::NodeValidationFailure);
174+
}
175+
176+
TEST(type_prop, paged_attention_invalid_type_scale) {
177+
auto scale = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{});
178+
auto dummy2D = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{3, 4});
179+
auto dummy1D = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{3});
180+
auto dummyScalar = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{});
181+
182+
ov::OutputVector args = {dummy2D,
183+
dummy2D,
184+
dummy2D,
185+
dummy2D,
186+
dummy2D,
187+
dummy1D,
188+
dummy1D,
189+
dummy1D,
190+
dummy1D,
191+
scale,
192+
dummyScalar,
193+
dummy2D,
194+
dummyScalar};
195+
196+
EXPECT_THROW(std::ignore = std::make_shared<op::PagedAttentionExtension>(args), ov::NodeValidationFailure);
197+
}
198+
199+
TEST(type_prop, paged_attention_invalid_rank_key_cache) {
200+
auto key_cache = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{1});
201+
auto dummy = std::make_shared<op::v0::Parameter>(element::f32, PartialShape{3, 4});
202+
auto dummy1D = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{3});
203+
auto dummyScalar = std::make_shared<op::v0::Parameter>(element::i32, PartialShape{});
204+
205+
ov::OutputVector args = {dummy,
206+
dummy,
207+
dummy,
208+
key_cache,
209+
dummy,
210+
dummy1D,
211+
dummy1D,
212+
dummy1D,
213+
dummy1D,
214+
dummyScalar,
215+
dummyScalar,
216+
dummy,
217+
dummyScalar};
218+
219+
EXPECT_THROW(std::ignore = std::make_shared<op::PagedAttentionExtension>(args), ov::NodeValidationFailure);
220+
}
221+
129222
} // namespace testing
130223
} // namespace ov

0 commit comments

Comments
 (0)