Skip to content

Commit 60a0acf

Browse files
committed
fix test and op spec
1 parent a3d1eff commit 60a0acf

File tree

4 files changed

+220
-72
lines changed

4 files changed

+220
-72
lines changed

src/bindings/python/src/openvino/_pyopenvino/op/__init__.pyi

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,7 @@ class _PagedAttentionExtension(openvino._pyopenvino.Node):
182182
Experimental extention for PagedAttention operation. Use with care: no backward compatibility is guaranteed in future releases.
183183
"""
184184
def __init__(self, arg0: collections.abc.Sequence[openvino._pyopenvino.Output]) -> None:
185-
...
186-
185+
...
187186
class _GatedDeltaNet(openvino._pyopenvino.Node):
188187
"""
189188
Experimental extention for GatedDeltaNet operation. Use with care: no backward compatibility is guaranteed in future releases.

src/core/src/op/gated_delta_net.cpp

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,71 @@ void GatedDeltaNet::validate_and_infer_types() {
7777
input_check(this, 4, "gate", {3}, {});
7878
input_check(this, 5, "beta", {3}, {});
7979

80-
// value head_size may be not same with key
80+
// batch, seq_len, head_num, head_size
81+
const auto& query_ps = get_input_partial_shape(0);
82+
const auto& key_ps = get_input_partial_shape(1);
83+
const auto& value_ps = get_input_partial_shape(2);
84+
const auto& state_ps = get_input_partial_shape(3);
85+
const auto& gate_ps = get_input_partial_shape(4);
86+
const auto& beta_ps = get_input_partial_shape(5);
87+
88+
const auto q_head_num = query_ps[2];
89+
const auto k_head_num = key_ps[2];
90+
const auto v_head_num = value_ps[2];
91+
92+
const auto k_head_size = key_ps[3];
93+
const auto v_head_size = value_ps[3];
94+
95+
NODE_VALIDATION_CHECK(this, q_head_num.is_static() && k_head_num.is_static() && q_head_num.get_length() == k_head_num.get_length(),
96+
"The number of heads in query and key should be the same, but got ",
97+
q_head_num,
98+
" and ",
99+
k_head_num,
100+
".");
101+
102+
NODE_VALIDATION_CHECK(this, k_head_size.is_static() && v_head_size.is_static() && k_head_size.get_length() == v_head_size.get_length(),
103+
"The head size in key and value should be the same, but got ",
104+
k_head_size,
105+
" and ",
106+
v_head_size,
107+
".");
108+
109+
const auto gate_head_num = gate_ps[2];
110+
const auto beta_head_num = beta_ps[2];
111+
112+
NODE_VALIDATION_CHECK(this, gate_head_num.is_static() && beta_head_num.is_static() && gate_head_num.get_length() == beta_head_num.get_length(),
113+
"The number of heads in gate and beta should be the same, but got ",
114+
gate_head_num,
115+
" and ",
116+
beta_head_num,
117+
".");
118+
119+
// [batch, v_head_nums, v_head_size, k_head_size]
120+
const auto state_head_num = state_ps[1];
121+
const auto state_hidden_size_0 = state_ps[2];
122+
const auto state_hidden_size_1 = state_ps[3];
123+
NODE_VALIDATION_CHECK(this, state_head_num.is_static() && state_head_num.get_length() == v_head_num.get_length(),
124+
"The number of heads in recurrent_state and value should be the same, but got ",
125+
state_head_num,
126+
" and ",
127+
v_head_num,
128+
".");
129+
NODE_VALIDATION_CHECK(this, state_hidden_size_0.is_static() && state_hidden_size_0.get_length() == v_head_size.get_length(),
130+
"The [-2] dim in shape of recurrent_state and value should be the same, but got ",
131+
state_hidden_size_0,
132+
" and ",
133+
v_head_size,
134+
".");
135+
NODE_VALIDATION_CHECK(this, state_hidden_size_1.is_static() && state_hidden_size_1.get_length() == k_head_size.get_length(),
136+
"The [-1] dim in shape of recurrent_state and key should be the same, but got ",
137+
state_hidden_size_1,
138+
" and ",
139+
k_head_size,
140+
".");
141+
// output has the same shape and type as input value, output state has the same shape and type as input recurrent_state
81142
auto out_ps = get_input_partial_shape(2);
82143
const auto& h_ps = get_input_partial_shape(3);
83-
set_output_type(0, get_input_element_type(0), out_ps);
144+
set_output_type(0, get_input_element_type(2), out_ps);
84145
set_output_type(1, get_input_element_type(3), h_ps);
85146
}
86147

src/tests/functional/plugin/shared/include/subgraph_tests/gated_delta_net.hpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,26 @@
99
namespace ov {
1010
namespace test {
1111

12+
inline void CheckNumberOfNodesWithType(std::shared_ptr<const ov::Model> function,
13+
const std::unordered_set<std::string>& nodeTypes,
14+
size_t expectedCount) {
15+
ASSERT_NE(nullptr, function);
16+
int num_ops = 0;
17+
for (const auto& node : function->get_ordered_ops()) {
18+
const auto& rt_info = node->get_rt_info();
19+
const auto layer_type = rt_info.find("layerType")->second.as<std::string>();
20+
if (nodeTypes.count(layer_type)) {
21+
num_ops++;
22+
}
23+
}
24+
ASSERT_EQ(num_ops, expectedCount);
25+
}
26+
1227
TEST_P(GatedDeltaNet, CompareWithRefs) {
1328
SKIP_IF_CURRENT_TEST_IS_DISABLED();
1429
run();
1530
auto function = compiledModel.get_runtime_model();
31+
CheckNumberOfNodesWithType(function, {"GatedDeltaNet"}, 1);
1632
};
1733
} // namespace test
1834
} // namespace ov

0 commit comments

Comments
 (0)