Skip to content

Commit 418f33e

Browse files
committed
add shared test
1 parent 5025b41 commit 418f33e

File tree

8 files changed

+381
-10
lines changed

8 files changed

+381
-10
lines changed

src/core/dev_api/openvino/op/gated_delta_net.hpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,24 @@ class OPENVINO_API GatedDeltaNet : public ov::op::Op {
1515
OPENVINO_OP("GatedDeltaNet");
1616

1717
GatedDeltaNet() = default;
18-
18+
struct Config {
19+
bool fuse_qk_l2norm = false;
20+
bool fuse_q_scale = false;
21+
};
1922
GatedDeltaNet(const ov::OutputVector& args);
2023
void validate_and_infer_types() override;
2124
std::shared_ptr<ov::Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;
22-
25+
const Config& get_config() const {
26+
return m_config;
27+
}
28+
void set_config(const Config& config) {
29+
m_config = config;
30+
}
2331
void set_out_type(int index, const ov::element::Type& output_type);
2432

2533
protected:
2634
std::vector<ov::element::Type> m_output_type = {ov::element::dynamic, ov::element::dynamic, ov::element::dynamic};
35+
Config m_config;
2736
};
2837

2938
} // namespace op

src/core/src/op/gated_delta_net.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,15 +76,15 @@ void GatedDeltaNet::validate_and_infer_types() {
7676
input_check(this, 0, "query", {4}, {});
7777
input_check(this, 1, "key", {4}, {});
7878
input_check(this, 2, "value", {4}, {});
79-
input_check(this, 5, "recurrent_state", {4}, {});
80-
input_check(this, 3, "gate", {3}, {});
81-
input_check(this, 4, "beta", {3}, {});
79+
input_check(this, 3, "recurrent_state", {4}, {});
80+
input_check(this, 4, "gate", {3}, {});
81+
input_check(this, 5, "beta", {3}, {});
8282

8383
// value head_size may be not same with key
8484
auto out_ps = get_input_partial_shape(2);
85-
const auto& h_ps= get_input_partial_shape(5);
85+
const auto& h_ps= get_input_partial_shape(3);
8686
set_output_type(0, get_input_element_type(0), out_ps);
87-
set_output_type(1, get_input_element_type(5), h_ps);
87+
set_output_type(1, get_input_element_type(3), h_ps);
8888
}
8989

9090
std::shared_ptr<ov::Node> GatedDeltaNet::clone_with_new_inputs(const ov::OutputVector& new_args) const {

src/plugins/intel_cpu/src/nodes/gated_delta_net.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ void GatedDeltaNet::execute([[maybe_unused]] const dnnl::stream& strm) {
7272
for (size_t i = 0; i < orginInputNumber; i++) {
7373
inputs[i] = getSrcMemoryAtPort(i);
7474
}
75-
std::vector<VectorDims> output_dims = {inputs[0]->getStaticDims(), inputs[5]->getStaticDims()};
75+
std::vector<VectorDims> output_dims = {inputs[0]->getStaticDims(), inputs[3]->getStaticDims()};
7676
redefineOutputMemory(output_dims);
7777

7878
outputs[0] = getDstMemoryAtPort(0);
@@ -85,7 +85,7 @@ void GatedDeltaNet::execute([[maybe_unused]] const dnnl::stream& strm) {
8585
PlainTensor gate(inputs[4]);
8686
PlainTensor beta(inputs[5]);
8787
PlainTensor output_attn(outputs[0]);
88-
PlainTensor output_recurrent_state(inputs[1]);
88+
PlainTensor output_recurrent_state(outputs[1]);
8989
recurrent_linear_attn(query, key, value, recurrent_state, gate, beta, output_attn, output_recurrent_state);
9090
}
9191

src/plugins/intel_cpu/src/nodes/kernels/linear_attn/recurrent_linear_attn.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,8 @@ void recurrent_linear_attn(const ov::intel_cpu::PlainTensor& query,
192192
// scale(init_state, b_g, K_HEAD_DIMS);
193193
multiply_scalar(init_state, init_state, b_g, K_HEAD_DIMS);
194194
float h_k = dot_product(init_state, b_k, K_HEAD_DIMS);
195-
float b_v = v_ptr[i_v + i * H* K_HEAD_DIMS];
195+
// B, T, H, V
196+
float b_v = v_ptr[i_v + i * H * K_HEAD_DIMS];
196197
b_v -= h_k;
197198
// b_v * b_k
198199
b_v *= b_beta;
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// Copyright (C) 2018-2026 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "subgraph_tests/gated_delta_net.hpp"
6+
7+
namespace ov {
8+
namespace test {
9+
10+
std::vector<gated_delta_net_params> test_cases = {
11+
{1, 1, 2, 2, 16, ov::element::f32, "CPU"},
12+
{2, 16, 2, 2, 16, ov::element::f32, "CPU"},
13+
{1, 16, 2, 2, 128, ov::element::f32, "CPU"},
14+
};
15+
INSTANTIATE_TEST_SUITE_P(smoke_GatedDeltaNet,
16+
GatedDeltaNet,
17+
::testing::ValuesIn(test_cases),
18+
GatedDeltaNet::getTestCaseName);
19+
} // namespace test
20+
} // namespace ov
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// Copyright (C) 2018-2026 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#pragma once
6+
7+
#include "shared_test_classes/base/ov_subgraph.hpp"
8+
9+
namespace ov {
10+
namespace test {
11+
12+
using gated_delta_net_params = std::tuple<int32_t, // B
13+
int32_t, // T
14+
int32_t, // qk_head_nums
15+
int32_t, // v_head_nums
16+
int32_t, // head_size
17+
ov::element::Type, // infer_precision
18+
std::string // device
19+
>;
20+
21+
class GatedDeltaNet : public testing::WithParamInterface<gated_delta_net_params>, public ov::test::SubgraphBaseTest {
22+
private:
23+
std::shared_ptr<ov::Model> buildLoopedGDN(int32_t batch,
24+
int32_t seq_len,
25+
int32_t qk_head_num,
26+
int32_t v_head_num,
27+
int32_t head_size);
28+
29+
public:
30+
void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override;
31+
static std::string getTestCaseName(const testing::TestParamInfo<gated_delta_net_params>& obj);
32+
33+
protected:
34+
void compare(const std::vector<ov::Tensor>& expected, const std::vector<ov::Tensor>& actual) override;
35+
void SetUp() override;
36+
void TearDown() override;
37+
};
38+
39+
} // namespace test
40+
} // namespace ov
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
// Copyright (C) 2018-2026 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#pragma once
6+
7+
#include "shared_test_classes/subgraph/gated_delta_net.hpp"
8+
9+
namespace ov {
10+
namespace test {
11+
12+
TEST_P(GatedDeltaNet, CompareWithRefs) {
13+
SKIP_IF_CURRENT_TEST_IS_DISABLED();
14+
run();
15+
auto function = compiledModel.get_runtime_model();
16+
};
17+
} // namespace test
18+
} // namespace ov

0 commit comments

Comments
 (0)