6
6
7
7
#include " dimension_util.hpp"
8
8
#include " itt.hpp"
9
+ #include " openvino/core/validation_util.hpp"
9
10
#include " openvino/op/op.hpp"
10
11
12
+ namespace {
13
+
14
+ // Validates input rank and type for a node input.
15
+ // We consider that dynamic rank/type are always valid case.
16
+ // Empty {} means any rank/type
17
+ inline void input_check (const ov::Node* node,
18
+ size_t idx,
19
+ const std::string_view input_name,
20
+ std::initializer_list<ov::Rank>&& allowed_ranks,
21
+ const std::vector<ov::element::Type>& allowed_types) {
22
+ using namespace ov ;
23
+ using namespace ov ::util;
24
+ using namespace ov ::element;
25
+
26
+ const auto & rank = node->get_input_partial_shape (idx).rank ();
27
+ const auto & tp = node->get_input_element_type (idx);
28
+
29
+ auto rank_check = [&](const Rank& rank) {
30
+ return rank.is_dynamic () || empty (allowed_ranks) || is_rank_compatible_any_of (rank.get_length (), allowed_ranks);
31
+ };
32
+
33
+ auto type_check = [&](const Type& type) {
34
+ auto it = std::find (allowed_types.begin (), allowed_types.end (), tp);
35
+ return type.is_dynamic () || allowed_types.empty () || it != allowed_types.end ();
36
+ };
37
+
38
+ NODE_VALIDATION_CHECK (node,
39
+ rank_check (rank),
40
+ " Rank of `" ,
41
+ input_name,
42
+ " ` input should be in [dynamic, " ,
43
+ join (allowed_ranks),
44
+ " ] list, but it is " ,
45
+ rank,
46
+ " ." );
47
+
48
+ NODE_VALIDATION_CHECK (node,
49
+ type_check (tp),
50
+ " Element type of `" ,
51
+ input_name,
52
+ " ` input should be in [dynamic, " ,
53
+ join (allowed_types),
54
+ " ] list, but it is " ,
55
+ tp,
56
+ " ." );
57
+ }
58
+
59
+ std::vector<ov::element::Type> get_real_types () {
60
+ std::vector<ov::element::Type> real_types;
61
+ for (const auto & type : ov::element::Type::get_known_types ()) {
62
+ if (type->is_real ()) {
63
+ real_types.push_back (*type);
64
+ }
65
+ }
66
+ return real_types;
67
+ }
68
+
69
+ } // namespace
70
+
11
71
namespace ov {
12
72
namespace op {
13
73
@@ -23,165 +83,25 @@ void PagedAttentionExtension::validate_and_infer_types() {
23
83
" PagedAttensionExtension expects 13 or 16 inputs, but it has " ,
24
84
get_input_size ());
25
85
26
- NODE_VALIDATION_CHECK (
27
- this ,
28
- get_input_partial_shape (0 ).rank ().is_dynamic () || get_input_partial_shape (0 ).rank ().get_length () == 2 ,
29
- " Rank of `query` input should be 2, but it is " ,
30
- get_input_partial_shape (0 ).rank ().get_length (),
31
- " ." );
32
- NODE_VALIDATION_CHECK (
33
- this ,
34
- get_input_partial_shape (1 ).rank ().is_dynamic () || get_input_partial_shape (1 ).rank ().get_length () == 2 ,
35
- " Rank of `key` input should be 2, but it is " ,
36
- get_input_partial_shape (1 ).rank ().get_length (),
37
- " ." );
38
- NODE_VALIDATION_CHECK (
39
- this ,
40
- get_input_partial_shape (2 ).rank ().is_dynamic () || get_input_partial_shape (2 ).rank ().get_length () == 2 ,
41
- " Rank of `value` input should be 2, but it is " ,
42
- get_input_partial_shape (2 ).rank ().get_length (),
43
- " ." );
44
-
45
- NODE_VALIDATION_CHECK (
46
- this ,
47
- get_input_partial_shape (3 ).rank ().is_dynamic () || get_input_partial_shape (3 ).rank ().get_length () >= 2 ,
48
- " Rank of `key_cache` input should be at least 2, but it is " ,
49
- get_input_partial_shape (3 ).rank ().get_length (),
50
- " ." );
51
- NODE_VALIDATION_CHECK (
52
- this ,
53
- get_input_partial_shape (4 ).rank ().is_dynamic () || get_input_partial_shape (4 ).rank ().get_length () >= 2 ,
54
- " Rank of `value_cache` input should be at least 2, but it is " ,
55
- get_input_partial_shape (4 ).rank ().get_length (),
56
- " ." );
57
-
58
- NODE_VALIDATION_CHECK (
59
- this ,
60
- get_input_partial_shape (5 ).rank ().is_dynamic () || get_input_partial_shape (5 ).rank ().get_length () == 1 ,
61
- " Rank of `past_lens` input should be 1, but it is " ,
62
- get_input_partial_shape (5 ).rank ().get_length (),
63
- " ." );
64
- NODE_VALIDATION_CHECK (this ,
65
- get_input_element_type (5 ).is_dynamic () || get_input_element_type (5 ) == element::i32 ,
66
- " Element type of `past_lens` input should be i32, but it is " ,
67
- get_input_element_type (5 ),
68
- " ." );
69
- NODE_VALIDATION_CHECK (
70
- this ,
71
- get_input_partial_shape (6 ).rank ().is_dynamic () || get_input_partial_shape (6 ).rank ().get_length () == 1 ,
72
- " Rank of `subsequence_begins` input should be 1, but it is " ,
73
- get_input_partial_shape (6 ).rank ().get_length (),
74
- " ." );
75
- NODE_VALIDATION_CHECK (this ,
76
- get_input_element_type (6 ).is_dynamic () || get_input_element_type (6 ) == element::i32 ,
77
- " Element type of `subsequence_begins` input should be i32, but it is " ,
78
- get_input_element_type (6 ),
79
- " ." );
80
-
81
- NODE_VALIDATION_CHECK (
82
- this ,
83
- get_input_partial_shape (7 ).rank ().is_dynamic () || get_input_partial_shape (7 ).rank ().get_length () == 1 ,
84
- " Rank of `block_indices` input should be 1, but it is " ,
85
- get_input_partial_shape (7 ).rank ().get_length (),
86
- " ." );
87
- NODE_VALIDATION_CHECK (this ,
88
- get_input_element_type (7 ).is_dynamic () || get_input_element_type (7 ) == element::i32 ,
89
- " Element type of `block_indices` input should be i32, but it is " ,
90
- get_input_element_type (7 ),
91
- " ." );
92
- NODE_VALIDATION_CHECK (
93
- this ,
94
- get_input_partial_shape (8 ).rank ().is_dynamic () || get_input_partial_shape (8 ).rank ().get_length () == 1 ,
95
- " Rank of `block_indices_begins` input should be 1, but it is " ,
96
- get_input_partial_shape (8 ).rank ().get_length (),
97
- " ." );
98
- NODE_VALIDATION_CHECK (this ,
99
- get_input_element_type (8 ).is_dynamic () || get_input_element_type (8 ) == element::i32 ,
100
- " Element type of `block_indices_begins` input should be i32, but it is " ,
101
- get_input_element_type (8 ),
102
- " ." );
103
-
104
- NODE_VALIDATION_CHECK (
105
- this ,
106
- get_input_partial_shape (9 ).rank ().is_dynamic () || get_input_partial_shape (9 ).rank ().get_length () == 0 ,
107
- " Input `scale` should be a scalar but it has rank " ,
108
- get_input_partial_shape (9 ).rank ().get_length (),
109
- " ." );
110
- NODE_VALIDATION_CHECK (this ,
111
- get_input_element_type (9 ).is_dynamic () || get_input_element_type (9 ).is_real (),
112
- " Element type of `scale` input should be a floating type, but it is " ,
113
- get_input_element_type (9 ),
114
- " ." );
115
- NODE_VALIDATION_CHECK (
116
- this ,
117
- get_input_partial_shape (10 ).rank ().is_dynamic () || get_input_partial_shape (10 ).rank ().get_length () == 0 ,
118
- " Input `sliding_window` should be a scalar but it has rank " ,
119
- get_input_partial_shape (10 ).rank ().get_length (),
120
- " ." );
121
- NODE_VALIDATION_CHECK (this ,
122
- get_input_element_type (10 ).is_dynamic () || get_input_element_type (10 ) == element::i32 ,
123
- " Element type of `sliding_window` input should be i32, but it is " ,
124
- get_input_element_type (10 ),
125
- " ." );
126
-
127
- NODE_VALIDATION_CHECK (
128
- this ,
129
- get_input_partial_shape (11 ).rank ().is_dynamic () || get_input_partial_shape (11 ).rank ().get_length () == 1 ,
130
- " Rank of `alibi_slopes` input should be 1, but it is " ,
131
- get_input_partial_shape (11 ).rank ().get_length (),
132
- " ." );
133
- NODE_VALIDATION_CHECK (this ,
134
- get_input_element_type (11 ).is_dynamic () || get_input_element_type (11 ).is_real (),
135
- " Element type of `alibi_slopes` input should be a floating type, but it is " ,
136
- get_input_element_type (11 ),
137
- " ." );
138
- NODE_VALIDATION_CHECK (
139
- this ,
140
- get_input_partial_shape (12 ).rank ().is_dynamic () || get_input_partial_shape (12 ).rank ().get_length () == 0 ,
141
- " Input `max_context_len` should be a scalar but it has rank " ,
142
- get_input_partial_shape (12 ).rank ().get_length (),
143
- " ." );
144
- NODE_VALIDATION_CHECK (this ,
145
- get_input_element_type (12 ).is_dynamic () || get_input_element_type (12 ) == element::i32 ,
146
- " Element type of `max_context_len` input should be i32, but it is " ,
147
- get_input_element_type (12 ),
148
- " ." );
86
+ // format: Node*, input_idx, name, {rank_list}, {type_list}
87
+ input_check (this , 0 , " query" , {2 }, {});
88
+ input_check (this , 1 , " key" , {2 }, {});
89
+ input_check (this , 2 , " value" , {2 }, {});
90
+ input_check (this , 3 , " key_cache" , {2 , 3 , 4 , 5 }, {});
91
+ input_check (this , 4 , " value_cache" , {2 , 3 , 4 , 5 }, {});
92
+ input_check (this , 5 , " past_lens" , {1 }, {element::i32 });
93
+ input_check (this , 6 , " subsequence_begins" , {1 }, {element::i32 });
94
+ input_check (this , 7 , " block_indices" , {1 }, {element::i32 });
95
+ input_check (this , 8 , " block_indices_begins" , {1 }, {element::i32 });
96
+ input_check (this , 9 , " scale" , {0 }, get_real_types ());
97
+ input_check (this , 10 , " sliding_window" , {0 }, {element::i32 });
98
+ input_check (this , 11 , " alibi_slopes" , {1 }, get_real_types ());
99
+ input_check (this , 12 , " max_context_len" , {0 }, {element::i32 });
149
100
150
101
if (get_input_size () == 16 ) {
151
- NODE_VALIDATION_CHECK (
152
- this ,
153
- get_input_partial_shape (13 ).rank ().is_dynamic () || get_input_partial_shape (13 ).rank ().get_length () == 1 ,
154
- " Input `rotated_block_indices` should either have rank 1 or be omitted, but it has rank " ,
155
- get_input_partial_shape (13 ).rank ().get_length (),
156
- " ." );
157
- NODE_VALIDATION_CHECK (this ,
158
- get_input_element_type (13 ).is_dynamic () || get_input_element_type (13 ) == element::i32 ,
159
- " Element type of `rotated_block_indices` input should be i32, but it is " ,
160
- get_input_element_type (13 ),
161
- " ." );
162
- NODE_VALIDATION_CHECK (
163
- this ,
164
- get_input_partial_shape (14 ).rank ().is_dynamic () || get_input_partial_shape (14 ).rank ().get_length () == 2 ,
165
- " Input `rotation_deltas` should either have rank 2 or be omitted, but it has rank " ,
166
- get_input_partial_shape (14 ).rank ().get_length (),
167
- " ." );
168
- NODE_VALIDATION_CHECK (this ,
169
- get_input_element_type (14 ).is_dynamic () || get_input_element_type (14 ) == element::i32 ,
170
- " Element type of `rotation_deltas` input should be i32, but it is " ,
171
- get_input_element_type (14 ),
172
- " ." );
173
- NODE_VALIDATION_CHECK (
174
- this ,
175
- get_input_partial_shape (15 ).rank ().is_dynamic () || get_input_partial_shape (15 ).rank ().get_length () == 2 ,
176
- " Input `rotation_trig_lut` should either have rank 2 or be omitted, but it has rank " ,
177
- get_input_partial_shape (15 ).rank ().get_length (),
178
- " ." );
179
- NODE_VALIDATION_CHECK (this ,
180
- get_input_element_type (15 ).is_dynamic () || get_input_element_type (15 ) == element::f32 ||
181
- get_input_element_type (15 ) == element::f16 ,
182
- " Element type of `rotation_trig_lut` input should be f32 or f16, but it is " ,
183
- get_input_element_type (15 ),
184
- " ." );
102
+ input_check (this , 13 , " rotated_block_indices" , {1 }, {element::i32 });
103
+ input_check (this , 14 , " rotation_deltas" , {2 }, {element::i32 });
104
+ input_check (this , 15 , " rotation_trig_lut" , {2 }, {element::f16 , element::f32 });
185
105
}
186
106
187
107
// value head_size may be not same with key
0 commit comments