66
77#include < gtest/gtest.h>
88
9+ #include < climits>
910#include < memory>
1011
1112#include " common_test_utils/ov_test_utils.hpp"
3233#include " openvino/op/subtract.hpp"
3334#include " openvino/op/transpose.hpp"
3435#include " openvino/op/unsqueeze.hpp"
36+ #include " transformations/convert_precision.hpp"
3537
3638using namespace testing ;
3739using namespace ov ;
@@ -42,23 +44,13 @@ std::shared_ptr<ov::Model> build_looped_gdn(int32_t batch,
4244 int32_t seq_len,
4345 int32_t qk_head_num,
4446 int32_t v_head_num,
45- int32_t head_size) {
47+ int32_t qk_head_size,
48+ int32_t v_head_size) {
4649 const auto dtype = ov::element::f32 ;
47- const ov::Shape qk_shape{static_cast <size_t >(batch),
48- static_cast <size_t >(seq_len),
49- static_cast <size_t >(qk_head_num),
50- static_cast <size_t >(head_size)};
51- const ov::Shape v_tensor_shape{static_cast <size_t >(batch),
52- static_cast <size_t >(seq_len),
53- static_cast <size_t >(v_head_num),
54- static_cast <size_t >(head_size)};
55- const ov::Shape gv_shape{static_cast <size_t >(batch),
56- static_cast <size_t >(seq_len),
57- static_cast <size_t >(qk_head_num)};
58- const ov::Shape h_shape{static_cast <size_t >(batch),
59- static_cast <size_t >(qk_head_num),
60- static_cast <size_t >(head_size),
61- static_cast <size_t >(head_size)};
50+ const ov::PartialShape qk_shape{batch, seq_len, qk_head_num, qk_head_size};
51+ const ov::PartialShape v_tensor_shape{batch, seq_len, v_head_num, v_head_size};
52+ const ov::PartialShape gv_shape{batch, seq_len, qk_head_num};
53+ const ov::PartialShape h_shape{batch, qk_head_num, qk_head_size, v_head_size};
6254
6355 auto q = std::make_shared<ov::op::v0::Parameter>(dtype, qk_shape);
6456 auto k = std::make_shared<ov::op::v0::Parameter>(dtype, qk_shape);
@@ -178,16 +170,15 @@ std::shared_ptr<ov::Model> build_looped_gdn(int32_t batch,
178170 auto reduce_axis0 = ov::op::v0::Constant::create (ov::element::i64 , {1 }, {0 });
179171 auto core_numel = std::make_shared<ov::op::v1::ReduceProd>(core_shape, reduce_axis0, true );
180172 auto state_shape = std::make_shared<ov::op::v3::ShapeOf>(h0);
181- auto state_numel = std::make_shared<ov::op::v1::ReduceProd>(state_shape, reduce_axis0, true );
182- auto state_slice_end = std::make_shared<ov::op::v1::Add>(core_numel, state_numel);
183173 auto slice_start = ov::op::v0::Constant::create (ov::element::i64 , {1 }, {0 });
184174 auto slice_step = ov::op::v0::Constant::create (ov::element::i64 , {1 }, {1 });
185175 auto slice_axis = ov::op::v0::Constant::create (ov::element::i64 , {1 }, {0 });
176+ auto slice_end_inf = ov::op::v0::Constant::create (ov::element::i64 , {1 }, {LLONG_MAX});
186177
187178 auto core_slice =
188179 std::make_shared<ov::op::v8::Slice>(packed_loop_outputs, slice_start, core_numel, slice_step, slice_axis);
189180 auto state_slice =
190- std::make_shared<ov::op::v8::Slice>(packed_loop_outputs, core_numel, state_slice_end , slice_step, slice_axis);
181+ std::make_shared<ov::op::v8::Slice>(packed_loop_outputs, core_numel, slice_end_inf , slice_step, slice_axis);
191182
192183 auto core_restored = std::make_shared<ov::op::v1::Reshape>(core_slice, core_shape, false );
193184 auto state_restored = std::make_shared<ov::op::v1::Reshape>(state_slice, state_shape, false );
@@ -204,23 +195,13 @@ std::shared_ptr<ov::Model> build_fused_gdn_ref(int32_t batch,
204195 int32_t seq_len,
205196 int32_t qk_head_num,
206197 int32_t v_head_num,
207- int32_t head_size) {
208- const auto dtype = ov::element::f32 ;
209- const ov::Shape qk_shape{static_cast <size_t >(batch),
210- static_cast <size_t >(seq_len),
211- static_cast <size_t >(qk_head_num),
212- static_cast <size_t >(head_size)};
213- const ov::Shape v_tensor_shape{static_cast <size_t >(batch),
214- static_cast <size_t >(seq_len),
215- static_cast <size_t >(v_head_num),
216- static_cast <size_t >(head_size)};
217- const ov::Shape gv_shape{static_cast <size_t >(batch),
218- static_cast <size_t >(seq_len),
219- static_cast <size_t >(qk_head_num)};
220- const ov::Shape h_shape{static_cast <size_t >(batch),
221- static_cast <size_t >(qk_head_num),
222- static_cast <size_t >(head_size),
223- static_cast <size_t >(head_size)};
198+ int32_t qk_head_size,
199+ int32_t v_head_size,
200+ ov::element::Type dtype = ov::element::f32 ) {
201+ const ov::PartialShape qk_shape{batch, seq_len, qk_head_num, qk_head_size};
202+ const ov::PartialShape v_tensor_shape{batch, seq_len, v_head_num, v_head_size};
203+ const ov::PartialShape gv_shape{batch, seq_len, qk_head_num};
204+ const ov::PartialShape h_shape{batch, qk_head_num, qk_head_size, v_head_size};
224205
225206 auto q = std::make_shared<ov::op::v0::Parameter>(dtype, qk_shape);
226207 auto k = std::make_shared<ov::op::v0::Parameter>(dtype, qk_shape);
@@ -233,7 +214,8 @@ std::shared_ptr<ov::Model> build_fused_gdn_ref(int32_t batch,
233214 ov::op::GatedDeltaNet::Config cfg;
234215 cfg.fuse_qk_l2norm = true ;
235216 cfg.fuse_q_scale = true ;
236- cfg.l2_norm_eps = 1e-6F ;
217+ cfg.q_l2_norm_eps = 1e-6F ;
218+ cfg.k_l2_norm_eps = 1e-6F ;
237219 gdn->set_config (cfg);
238220
239221 return std::make_shared<ov::Model>(ov::OutputVector{gdn->output (0 ), gdn->output (1 )},
@@ -245,14 +227,36 @@ std::shared_ptr<ov::Model> build_fused_gdn_ref(int32_t batch,
245227TEST_F (TransformationTestsF, GatedDeltaNetFusion_BuildLoopedGDNMode) {
246228 disable_rt_info_check ();
247229 disable_result_friendly_names_check ();
248- constexpr int32_t batch = 2 ;
249- constexpr int32_t seq_len = 5 ;
230+ constexpr int32_t batch = - 1 ;
231+ constexpr int32_t seq_len = - 1 ;
250232 constexpr int32_t qk_head_num = 4 ;
251233 constexpr int32_t v_head_num = 4 ;
252- constexpr int32_t head_size = 8 ;
234+ constexpr int32_t qk_head_size = 8 ;
235+ constexpr int32_t v_head_size = 16 ;
253236
254- model = build_looped_gdn (batch, seq_len, qk_head_num, v_head_num, head_size );
237+ model = build_looped_gdn (batch, seq_len, qk_head_num, v_head_num, qk_head_size, v_head_size );
255238 manager.register_pass <ov::pass::GatedDeltaNetFusion>();
239+ model_ref = build_fused_gdn_ref (batch, seq_len, qk_head_num, v_head_num, qk_head_size, v_head_size);
240+ }
256241
257- model_ref = build_fused_gdn_ref (batch, seq_len, qk_head_num, v_head_num, head_size);
242+ TEST_F (TransformationTestsF, GatedDeltaNetFusion_BuildLoopedGDNMode_F16) {
243+ disable_rt_info_check ();
244+ disable_result_friendly_names_check ();
245+ constexpr int32_t batch = -1 ;
246+ constexpr int32_t seq_len = -1 ;
247+ constexpr int32_t qk_head_num = 4 ;
248+ constexpr int32_t v_head_num = 4 ;
249+ constexpr int32_t qk_head_size = 8 ;
250+ constexpr int32_t v_head_size = 16 ;
251+
252+ model = build_looped_gdn (batch, seq_len, qk_head_num, v_head_num, qk_head_size, v_head_size);
253+ manager.register_pass <pass::ConvertPrecision>(ov::element::f32 ,
254+ ov::element::f16 ,
255+ type_to_fuse_map{},
256+ true ,
257+ true ,
258+ false );
259+ manager.register_pass <ov::pass::GatedDeltaNetFusion>();
260+ model_ref =
261+ build_fused_gdn_ref (batch, seq_len, qk_head_num, v_head_num, qk_head_size, v_head_size, ov::element::f16 );
258262}
0 commit comments