Skip to content

Commit a0e8bd0

Browse files
committed
apply review comments
1 parent d7f9105 commit a0e8bd0

File tree

14 files changed

+405
-230
lines changed

14 files changed

+405
-230
lines changed

src/common/transformations/include/transformations/common_optimizations/fuse_gated_delta_net.hpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66
#include "openvino/pass/graph_rewrite.hpp"
77
#include "transformations_visibility.hpp"
88

9-
namespace ov {
10-
namespace pass {
9+
namespace ov::pass {
1110

1211
/**
1312
* @ingroup ov_transformation_common_api
@@ -88,5 +87,4 @@ class TRANSFORMATIONS_API GatedDeltaNetFusion : public ov::pass::ModelPass {
8887
bool run_on_model(const std::shared_ptr<ov::Model>& model) override;
8988
};
9089

91-
} // namespace pass
92-
} // namespace ov
90+
} // namespace ov::pass

src/common/transformations/src/transformations/common_optimizations/fuse_gated_delta_net.cpp

Lines changed: 133 additions & 140 deletions
Large diffs are not rendered by default.

src/common/transformations/tests/common_optimizations/fuse_gated_delta_net.cpp

Lines changed: 46 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
#include <gtest/gtest.h>
88

9+
#include <climits>
910
#include <memory>
1011

1112
#include "common_test_utils/ov_test_utils.hpp"
@@ -32,6 +33,7 @@
3233
#include "openvino/op/subtract.hpp"
3334
#include "openvino/op/transpose.hpp"
3435
#include "openvino/op/unsqueeze.hpp"
36+
#include "transformations/convert_precision.hpp"
3537

3638
using namespace testing;
3739
using namespace ov;
@@ -42,23 +44,13 @@ std::shared_ptr<ov::Model> build_looped_gdn(int32_t batch,
4244
int32_t seq_len,
4345
int32_t qk_head_num,
4446
int32_t v_head_num,
45-
int32_t head_size) {
47+
int32_t qk_head_size,
48+
int32_t v_head_size) {
4649
const auto dtype = ov::element::f32;
47-
const ov::Shape qk_shape{static_cast<size_t>(batch),
48-
static_cast<size_t>(seq_len),
49-
static_cast<size_t>(qk_head_num),
50-
static_cast<size_t>(head_size)};
51-
const ov::Shape v_tensor_shape{static_cast<size_t>(batch),
52-
static_cast<size_t>(seq_len),
53-
static_cast<size_t>(v_head_num),
54-
static_cast<size_t>(head_size)};
55-
const ov::Shape gv_shape{static_cast<size_t>(batch),
56-
static_cast<size_t>(seq_len),
57-
static_cast<size_t>(qk_head_num)};
58-
const ov::Shape h_shape{static_cast<size_t>(batch),
59-
static_cast<size_t>(qk_head_num),
60-
static_cast<size_t>(head_size),
61-
static_cast<size_t>(head_size)};
50+
const ov::PartialShape qk_shape{batch, seq_len, qk_head_num, qk_head_size};
51+
const ov::PartialShape v_tensor_shape{batch, seq_len, v_head_num, v_head_size};
52+
const ov::PartialShape gv_shape{batch, seq_len, qk_head_num};
53+
const ov::PartialShape h_shape{batch, qk_head_num, qk_head_size, v_head_size};
6254

6355
auto q = std::make_shared<ov::op::v0::Parameter>(dtype, qk_shape);
6456
auto k = std::make_shared<ov::op::v0::Parameter>(dtype, qk_shape);
@@ -178,16 +170,15 @@ std::shared_ptr<ov::Model> build_looped_gdn(int32_t batch,
178170
auto reduce_axis0 = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
179171
auto core_numel = std::make_shared<ov::op::v1::ReduceProd>(core_shape, reduce_axis0, true);
180172
auto state_shape = std::make_shared<ov::op::v3::ShapeOf>(h0);
181-
auto state_numel = std::make_shared<ov::op::v1::ReduceProd>(state_shape, reduce_axis0, true);
182-
auto state_slice_end = std::make_shared<ov::op::v1::Add>(core_numel, state_numel);
183173
auto slice_start = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
184174
auto slice_step = ov::op::v0::Constant::create(ov::element::i64, {1}, {1});
185175
auto slice_axis = ov::op::v0::Constant::create(ov::element::i64, {1}, {0});
176+
auto slice_end_inf = ov::op::v0::Constant::create(ov::element::i64, {1}, {LLONG_MAX});
186177

187178
auto core_slice =
188179
std::make_shared<ov::op::v8::Slice>(packed_loop_outputs, slice_start, core_numel, slice_step, slice_axis);
189180
auto state_slice =
190-
std::make_shared<ov::op::v8::Slice>(packed_loop_outputs, core_numel, state_slice_end, slice_step, slice_axis);
181+
std::make_shared<ov::op::v8::Slice>(packed_loop_outputs, core_numel, slice_end_inf, slice_step, slice_axis);
191182

192183
auto core_restored = std::make_shared<ov::op::v1::Reshape>(core_slice, core_shape, false);
193184
auto state_restored = std::make_shared<ov::op::v1::Reshape>(state_slice, state_shape, false);
@@ -204,23 +195,13 @@ std::shared_ptr<ov::Model> build_fused_gdn_ref(int32_t batch,
204195
int32_t seq_len,
205196
int32_t qk_head_num,
206197
int32_t v_head_num,
207-
int32_t head_size) {
208-
const auto dtype = ov::element::f32;
209-
const ov::Shape qk_shape{static_cast<size_t>(batch),
210-
static_cast<size_t>(seq_len),
211-
static_cast<size_t>(qk_head_num),
212-
static_cast<size_t>(head_size)};
213-
const ov::Shape v_tensor_shape{static_cast<size_t>(batch),
214-
static_cast<size_t>(seq_len),
215-
static_cast<size_t>(v_head_num),
216-
static_cast<size_t>(head_size)};
217-
const ov::Shape gv_shape{static_cast<size_t>(batch),
218-
static_cast<size_t>(seq_len),
219-
static_cast<size_t>(qk_head_num)};
220-
const ov::Shape h_shape{static_cast<size_t>(batch),
221-
static_cast<size_t>(qk_head_num),
222-
static_cast<size_t>(head_size),
223-
static_cast<size_t>(head_size)};
198+
int32_t qk_head_size,
199+
int32_t v_head_size,
200+
ov::element::Type dtype = ov::element::f32) {
201+
const ov::PartialShape qk_shape{batch, seq_len, qk_head_num, qk_head_size};
202+
const ov::PartialShape v_tensor_shape{batch, seq_len, v_head_num, v_head_size};
203+
const ov::PartialShape gv_shape{batch, seq_len, qk_head_num};
204+
const ov::PartialShape h_shape{batch, qk_head_num, qk_head_size, v_head_size};
224205

225206
auto q = std::make_shared<ov::op::v0::Parameter>(dtype, qk_shape);
226207
auto k = std::make_shared<ov::op::v0::Parameter>(dtype, qk_shape);
@@ -233,7 +214,8 @@ std::shared_ptr<ov::Model> build_fused_gdn_ref(int32_t batch,
233214
ov::op::GatedDeltaNet::Config cfg;
234215
cfg.fuse_qk_l2norm = true;
235216
cfg.fuse_q_scale = true;
236-
cfg.l2_norm_eps = 1e-6F;
217+
cfg.q_l2_norm_eps = 1e-6F;
218+
cfg.k_l2_norm_eps = 1e-6F;
237219
gdn->set_config(cfg);
238220

239221
return std::make_shared<ov::Model>(ov::OutputVector{gdn->output(0), gdn->output(1)},
@@ -245,14 +227,36 @@ std::shared_ptr<ov::Model> build_fused_gdn_ref(int32_t batch,
245227
TEST_F(TransformationTestsF, GatedDeltaNetFusion_BuildLoopedGDNMode) {
246228
disable_rt_info_check();
247229
disable_result_friendly_names_check();
248-
constexpr int32_t batch = 2;
249-
constexpr int32_t seq_len = 5;
230+
constexpr int32_t batch = -1;
231+
constexpr int32_t seq_len = -1;
250232
constexpr int32_t qk_head_num = 4;
251233
constexpr int32_t v_head_num = 4;
252-
constexpr int32_t head_size = 8;
234+
constexpr int32_t qk_head_size = 8;
235+
constexpr int32_t v_head_size = 16;
253236

254-
model = build_looped_gdn(batch, seq_len, qk_head_num, v_head_num, head_size);
237+
model = build_looped_gdn(batch, seq_len, qk_head_num, v_head_num, qk_head_size, v_head_size);
255238
manager.register_pass<ov::pass::GatedDeltaNetFusion>();
239+
model_ref = build_fused_gdn_ref(batch, seq_len, qk_head_num, v_head_num, qk_head_size, v_head_size);
240+
}
256241

257-
model_ref = build_fused_gdn_ref(batch, seq_len, qk_head_num, v_head_num, head_size);
242+
TEST_F(TransformationTestsF, GatedDeltaNetFusion_BuildLoopedGDNMode_F16) {
243+
disable_rt_info_check();
244+
disable_result_friendly_names_check();
245+
constexpr int32_t batch = -1;
246+
constexpr int32_t seq_len = -1;
247+
constexpr int32_t qk_head_num = 4;
248+
constexpr int32_t v_head_num = 4;
249+
constexpr int32_t qk_head_size = 8;
250+
constexpr int32_t v_head_size = 16;
251+
252+
model = build_looped_gdn(batch, seq_len, qk_head_num, v_head_num, qk_head_size, v_head_size);
253+
manager.register_pass<pass::ConvertPrecision>(ov::element::f32,
254+
ov::element::f16,
255+
type_to_fuse_map{},
256+
true,
257+
true,
258+
false);
259+
manager.register_pass<ov::pass::GatedDeltaNetFusion>();
260+
model_ref =
261+
build_fused_gdn_ref(batch, seq_len, qk_head_num, v_head_num, qk_head_size, v_head_size, ov::element::f16);
258262
}

src/core/dev_api/openvino/op/gated_delta_net.hpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ class OPENVINO_API GatedDeltaNet : public ov::op::Op {
1818
struct Config {
1919
bool fuse_qk_l2norm = false;
2020
bool fuse_q_scale = false;
21-
float l2_norm_eps = 1e-6F;
21+
float q_l2_norm_eps = 1e-6F;
22+
float k_l2_norm_eps = 1e-6F;
2223
};
2324
GatedDeltaNet(const ov::OutputVector& args);
2425
void validate_and_infer_types() override;

src/core/src/op/gated_delta_net.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@
1212
namespace {
1313

1414
// Validates input rank and type for a node input.
15-
// We consider that dynamic rank/type are always valid case.
16-
// Empty {} means any rank/type
1715
inline void input_check(const ov::Node* node,
1816
size_t idx,
1917
const std::string_view input_name,
@@ -90,19 +88,20 @@ void GatedDeltaNet::validate_and_infer_types() {
9088
const auto v_head_num = value_ps[2];
9189

9290
const auto k_head_size = key_ps[3];
91+
const auto q_head_size = query_ps[3];
9392
const auto v_head_size = value_ps[3];
9493

9594
NODE_VALIDATION_CHECK(this,
96-
q_head_num.compatible(k_head_num),
97-
"The number of heads in query and key should be the same, but got ",
95+
q_head_num.compatible(k_head_num) && q_head_num.compatible(v_head_num),
96+
"The number of heads in query key and value should be the same, but got ",
9897
q_head_num,
9998
" and ",
10099
k_head_num,
101100
".");
102101

103102
NODE_VALIDATION_CHECK(this,
104-
k_head_size.compatible(v_head_size),
105-
"The head size in key and value should be the same, but got ",
103+
k_head_size.compatible(q_head_size),
104+
"The head size in key and query should be the same, but got ",
106105
k_head_size,
107106
" and ",
108107
v_head_size,
@@ -112,8 +111,8 @@ void GatedDeltaNet::validate_and_infer_types() {
112111
const auto beta_head_num = beta_ps[2];
113112

114113
NODE_VALIDATION_CHECK(this,
115-
gate_head_num.compatible(beta_head_num),
116-
"The number of heads in gate and beta should be the same, but got ",
114+
gate_head_num.compatible(beta_head_num) && gate_head_num.compatible(q_head_num),
115+
"The number of heads in gate, beta, and query should be the same, but got ",
117116
gate_head_num,
118117
" and ",
119118
beta_head_num,
@@ -155,7 +154,8 @@ bool GatedDeltaNet::visit_attributes(AttributeVisitor& visitor) {
155154
visitor.start_structure("config");
156155
visitor.on_attribute("fuse_qk_l2norm", m_config.fuse_qk_l2norm);
157156
visitor.on_attribute("fuse_q_scale", m_config.fuse_q_scale);
158-
visitor.on_attribute("l2_norm_eps", m_config.l2_norm_eps);
157+
visitor.on_attribute("q_l2_norm_eps", m_config.q_l2_norm_eps);
158+
visitor.on_attribute("k_l2_norm_eps", m_config.k_l2_norm_eps);
159159
visitor.finish_structure();
160160
return true;
161161
}
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
// Copyright (C) 2018-2026 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "openvino/op/gated_delta_net.hpp"
6+
7+
#include <gtest/gtest.h>
8+
9+
#include "common_test_utils/test_assertions.hpp"
10+
#include "openvino/openvino.hpp"
11+
12+
using namespace ov;
13+
using namespace testing;
14+
15+
namespace {
16+
17+
std::shared_ptr<op::GatedDeltaNet> make_gdn(const element::Type& et,
18+
const PartialShape& q,
19+
const PartialShape& k,
20+
const PartialShape& v,
21+
const PartialShape& state,
22+
const PartialShape& gate,
23+
const PartialShape& beta) {
24+
auto query = std::make_shared<op::v0::Parameter>(et, q);
25+
auto key = std::make_shared<op::v0::Parameter>(et, k);
26+
auto value = std::make_shared<op::v0::Parameter>(et, v);
27+
auto recurrent_state = std::make_shared<op::v0::Parameter>(et, state);
28+
auto gate_p = std::make_shared<op::v0::Parameter>(et, gate);
29+
auto beta_p = std::make_shared<op::v0::Parameter>(et, beta);
30+
31+
return std::make_shared<op::GatedDeltaNet>(OutputVector{query, key, value, recurrent_state, gate_p, beta_p});
32+
}
33+
34+
} // namespace
35+
36+
TEST(type_prop, gated_delta_net_static_f32) {
37+
const auto op = make_gdn(element::f32,
38+
Shape{2, 5, 4, 8},
39+
Shape{2, 5, 4, 8},
40+
Shape{2, 5, 4, 16},
41+
Shape{2, 4, 8, 16},
42+
Shape{2, 5, 4},
43+
Shape{2, 5, 4});
44+
45+
EXPECT_EQ(op->get_output_size(), 2);
46+
EXPECT_EQ(op->get_output_element_type(0), element::f32);
47+
EXPECT_EQ(op->get_output_element_type(1), element::f32);
48+
EXPECT_EQ(op->get_output_partial_shape(0), PartialShape(Shape{2, 5, 4, 16}));
49+
EXPECT_EQ(op->get_output_partial_shape(1), PartialShape(Shape{2, 4, 8, 16}));
50+
}
51+
52+
TEST(type_prop, gated_delta_net_static_f16) {
53+
const auto op = make_gdn(element::f16,
54+
Shape{2, 5, 4, 8},
55+
Shape{2, 5, 4, 8},
56+
Shape{2, 5, 4, 16},
57+
Shape{2, 4, 8, 16},
58+
Shape{2, 5, 4},
59+
Shape{2, 5, 4});
60+
61+
EXPECT_EQ(op->get_output_element_type(0), element::f16);
62+
EXPECT_EQ(op->get_output_element_type(1), element::f16);
63+
EXPECT_EQ(op->get_output_partial_shape(0), PartialShape(Shape{2, 5, 4, 16}));
64+
EXPECT_EQ(op->get_output_partial_shape(1), PartialShape(Shape{2, 4, 8, 16}));
65+
}
66+
67+
TEST(type_prop, gated_delta_net_partial_shape_infer) {
68+
const auto op = make_gdn(element::bf16,
69+
PartialShape{{1, 4}, -1, {2, 8}, 64},
70+
PartialShape{{1, 4}, -1, {2, 8}, 64},
71+
PartialShape{{1, 4}, -1, {2, 8}, {32, 128}},
72+
PartialShape{{1, 4}, {2, 8}, 64, {32, 128}},
73+
PartialShape{{1, 4}, -1, {2, 8}},
74+
PartialShape{{1, 4}, -1, {2, 8}});
75+
76+
EXPECT_EQ(op->get_output_element_type(0), element::bf16);
77+
EXPECT_EQ(op->get_output_element_type(1), element::bf16);
78+
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{{1, 4}, -1, {2, 8}, {32, 128}}));
79+
EXPECT_EQ(op->get_output_partial_shape(1), (PartialShape{{1, 4}, {2, 8}, 64, {32, 128}}));
80+
}
81+
82+
TEST(type_prop, gated_delta_net_invalid_query_rank) {
83+
OV_EXPECT_THROW(std::ignore = make_gdn(element::f32,
84+
Shape{2, 5, 8},
85+
Shape{2, 5, 4, 8},
86+
Shape{2, 5, 4, 16},
87+
Shape{2, 4, 8, 16},
88+
Shape{2, 5, 4},
89+
Shape{2, 5, 4}),
90+
NodeValidationFailure,
91+
HasSubstr("Rank of `query` input should be in [4] list"));
92+
}
93+
94+
TEST(type_prop, gated_delta_net_invalid_gate_rank) {
95+
OV_EXPECT_THROW(std::ignore = make_gdn(element::f32,
96+
Shape{2, 5, 4, 8},
97+
Shape{2, 5, 4, 8},
98+
Shape{2, 5, 4, 16},
99+
Shape{2, 4, 8, 16},
100+
Shape{2, 5, 4, 1},
101+
Shape{2, 5, 4}),
102+
NodeValidationFailure,
103+
HasSubstr("Rank of `gate` input should be in [3] list"));
104+
}
105+
106+
TEST(type_prop, gated_delta_net_invalid_type) {
107+
OV_EXPECT_THROW(std::ignore = make_gdn(element::i32,
108+
Shape{2, 5, 4, 8},
109+
Shape{2, 5, 4, 8},
110+
Shape{2, 5, 4, 16},
111+
Shape{2, 4, 8, 16},
112+
Shape{2, 5, 4},
113+
Shape{2, 5, 4}),
114+
NodeValidationFailure,
115+
HasSubstr("Element type of `query` input should be in"));
116+
}
117+
118+
TEST(type_prop, gated_delta_net_head_num_mismatch_qkv) {
119+
OV_EXPECT_THROW(std::ignore = make_gdn(element::f32,
120+
Shape{2, 5, 4, 8},
121+
Shape{2, 5, 6, 8},
122+
Shape{2, 5, 4, 16},
123+
Shape{2, 4, 8, 16},
124+
Shape{2, 5, 4},
125+
Shape{2, 5, 4}),
126+
NodeValidationFailure,
127+
HasSubstr("The number of heads in query key and value should be the same"));
128+
}
129+
130+
TEST(type_prop, gated_delta_net_head_size_mismatch_qk) {
131+
OV_EXPECT_THROW(std::ignore = make_gdn(element::f32,
132+
Shape{2, 5, 4, 8},
133+
Shape{2, 5, 4, 32},
134+
Shape{2, 5, 4, 16},
135+
Shape{2, 4, 32, 16},
136+
Shape{2, 5, 4},
137+
Shape{2, 5, 4}),
138+
NodeValidationFailure,
139+
HasSubstr("The head size in key and query should be the same"));
140+
}
141+
142+
TEST(type_prop, gated_delta_net_gate_beta_head_num_mismatch) {
143+
OV_EXPECT_THROW(std::ignore = make_gdn(element::f32,
144+
Shape{2, 5, 4, 8},
145+
Shape{2, 5, 4, 8},
146+
Shape{2, 5, 4, 16},
147+
Shape{2, 4, 8, 16},
148+
Shape{2, 5, 6},
149+
Shape{2, 5, 4}),
150+
NodeValidationFailure,
151+
HasSubstr("The number of heads in gate, beta, and query should be the same"));
152+
}
153+
154+
TEST(type_prop, gated_delta_net_state_shape_mismatch) {
155+
OV_EXPECT_THROW(std::ignore = make_gdn(element::f32,
156+
Shape{2, 5, 4, 8},
157+
Shape{2, 5, 4, 8},
158+
Shape{2, 5, 4, 16},
159+
Shape{2, 4, 8, 32},
160+
Shape{2, 5, 4},
161+
Shape{2, 5, 4}),
162+
NodeValidationFailure,
163+
HasSubstr("The [-1] dim in shape of recurrent_state and value should be the same"));
164+
}

0 commit comments

Comments
 (0)