Skip to content

Commit bd6fe37

Browse files
committed
[CPU]use state layout of B, H, V, K
1 parent ad86b8d commit bd6fe37

File tree

12 files changed

+218
-202
lines changed

12 files changed

+218
-202
lines changed

src/bindings/python/src/openvino/_pyopenvino/op/__init__.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ class _GatedDeltaNet(openvino._pyopenvino.Node):
189189
Experimental extention for GatedDeltaNet operation. Use with care: no backward compatibility is guaranteed in future releases.
190190
"""
191191
def __init__(self, arg0: collections.abc.Sequence[openvino._pyopenvino.Output]) -> None:
192+
...
192193
class assign(openvino._pyopenvino.Node):
193194
"""
194195
openvino.op.assign wraps ov::op::v6::Assign

src/bindings/python/src/pyopenvino/graph/ops/gated_detla_net.cpp renamed to src/bindings/python/src/pyopenvino/graph/ops/gated_delta_net.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (C) 2018-2025 Intel Corporation
1+
// Copyright (C) 2018-2026 Intel Corporation
22
// SPDX-License-Identifier: Apache-2.0
33
//
44

src/bindings/python/src/pyopenvino/graph/ops/gated_delta_net.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (C) 2018-2025 Intel Corporation
1+
// Copyright (C) 2018-2026 Intel Corporation
22
// SPDX-License-Identifier: Apache-2.0
33
//
44

src/common/transformations/include/transformations/common_optimizations/fuse_gated_delta_net.hpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
// Copyright (C) 2018-2026 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
14
#pragma once
25

36
#include "openvino/pass/graph_rewrite.hpp"

src/common/transformations/src/transformations/common_optimizations/fuse_gated_delta_net.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
// Copyright (C) 2018-2026 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
15
#include "transformations/common_optimizations/fuse_gated_delta_net.hpp"
26

37
#include <cstddef>

src/core/dev_api/openvino/op/gated_delta_net.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (C) 2018-2025 Intel Corporation
1+
// Copyright (C) 2018-2026 Intel Corporation
22
// SPDX-License-Identifier: Apache-2.0
33
//
44
#pragma once

src/core/src/op/gated_delta_net.cpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (C) 2018-2025 Intel Corporation
1+
// Copyright (C) 2018-2026 Intel Corporation
22
// SPDX-License-Identifier: Apache-2.0
33
//
44

@@ -65,12 +65,9 @@ GatedDeltaNet::GatedDeltaNet(const ov::OutputVector& args) : ov::op::Op(args) {
6565
}
6666

6767
void GatedDeltaNet::validate_and_infer_types() {
68-
OV_OP_SCOPE(LinearAttention_validate_and_infer_types);
68+
OV_OP_SCOPE(GatedDeltaNet_validate_and_infer_types);
6969

70-
NODE_VALIDATION_CHECK(this,
71-
get_input_size() == 6,
72-
"GatedDeltaNet expects 6 inputs, but it has ",
73-
get_input_size());
70+
NODE_VALIDATION_CHECK(this, get_input_size() == 6, "GatedDeltaNet expects 6 inputs, but it has ", get_input_size());
7471

7572
// format: Node*, input_idx, name, {rank_list}, {type_list}
7673
input_check(this, 0, "query", {4}, {});
@@ -82,13 +79,15 @@ void GatedDeltaNet::validate_and_infer_types() {
8279

8380
// value head_size may be not same with key
8481
auto out_ps = get_input_partial_shape(2);
85-
const auto& h_ps= get_input_partial_shape(3);
82+
const auto& h_ps = get_input_partial_shape(3);
8683
set_output_type(0, get_input_element_type(0), out_ps);
8784
set_output_type(1, get_input_element_type(3), h_ps);
8885
}
8986

9087
std::shared_ptr<ov::Node> GatedDeltaNet::clone_with_new_inputs(const ov::OutputVector& new_args) const {
91-
return std::make_shared<GatedDeltaNet>(new_args);
88+
auto cloned = std::make_shared<GatedDeltaNet>(new_args);
89+
cloned->m_config = m_config;
90+
return cloned;
9291
}
9392

9493
void GatedDeltaNet::set_out_type(int index, const ov::element::Type& output_type) {

src/plugins/intel_cpu/src/nodes/gated_delta_net.cpp

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (C) 2018-2025 Intel Corporation
1+
// Copyright (C) 2018-2026 Intel Corporation
22
// SPDX-License-Identifier: Apache-2.0
33
//
44

@@ -16,6 +16,7 @@
1616
#include "cpu_memory.h"
1717
#include "cpu_types.h"
1818
#include "graph_context.h"
19+
#include "kernels/linear_attn/recurrent_linear_attn.hpp"
1920
#include "memory_desc/cpu_memory_desc.h"
2021
#include "node.h"
2122
#include "nodes/common/blocked_desc_creator.h"
@@ -30,7 +31,6 @@
3031
#include "shape_inference/shape_inference_internal_dyn.hpp"
3132
#include "transformations/utils/utils.hpp"
3233
#include "utils/general_utils.h"
33-
#include "kernels/linear_attn/recurrent_linear_attn.hpp"
3434

3535
using namespace ov::Extensions::Cpu;
3636
using namespace dnnl::impl;
@@ -55,8 +55,7 @@ void GatedDeltaNet::initSupportedPrimitiveDescriptors() {
5555
}
5656
std::vector<PortConfigurator> outPortConfigs = {
5757
PortConfigurator{LayoutType::ncsp, dataPrecision, getOutputShapeAtPort(0), false, -1},
58-
PortConfigurator{LayoutType::ncsp, dataPrecision, getOutputShapeAtPort(1), false, -1}
59-
};
58+
PortConfigurator{LayoutType::ncsp, dataPrecision, getOutputShapeAtPort(1), false, -1}};
6059
addSupportedPrimDesc(inPortConfigs, outPortConfigs, impl_desc_type::ref_any);
6160
}
6261

@@ -86,11 +85,27 @@ void GatedDeltaNet::execute([[maybe_unused]] const dnnl::stream& strm) {
8685
PlainTensor beta(inputs[5]);
8786
PlainTensor output_attn(outputs[0]);
8887
PlainTensor output_recurrent_state(outputs[1]);
89-
recurrent_linear_attn(query, key, value, recurrent_state, gate, beta, output_attn, output_recurrent_state);
88+
// q, k, h per (B, H, V)
89+
const auto& q_dims = inputs[0]->getStaticDims();
90+
const auto& v_dims = inputs[2]->getStaticDims();
91+
const size_t B = q_dims[0];
92+
const size_t H = q_dims[2];
93+
const size_t K = q_dims[3];
94+
const size_t V = v_dims[3];
95+
temp_buffer.resize<float>({B * H * V * 3 * K});
96+
recurrent_linear_attn(query,
97+
key,
98+
value,
99+
recurrent_state,
100+
gate,
101+
beta,
102+
output_attn,
103+
output_recurrent_state,
104+
temp_buffer);
90105
}
91106

92107
bool GatedDeltaNet::isSupportedOperation(const std::shared_ptr<const ov::Node>& op,
93-
std::string& errorMessage) noexcept {
108+
std::string& errorMessage) noexcept {
94109
return true;
95110
}
96111

src/plugins/intel_cpu/src/nodes/gated_delta_net.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (C) 2018-2025 Intel Corporation
1+
// Copyright (C) 2018-2026 Intel Corporation
22
// SPDX-License-Identifier: Apache-2.0
33
//
44

@@ -42,6 +42,9 @@ class GatedDeltaNet : public Node {
4242
void execute(const dnnl::stream& strm) override;
4343
void createPrimitive() override;
4444
static bool isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept;
45+
46+
private:
47+
PlainTensor temp_buffer;
4548
};
4649

4750
} // namespace ov::intel_cpu::node

0 commit comments

Comments
 (0)