Skip to content

Commit 2d55899

Browse files
committed
fix clang & apply review comments
1 parent fe9157d commit 2d55899

File tree

5 files changed

+16
-5
lines changed

5 files changed

+16
-5
lines changed

src/common/transformations/include/transformations/common_optimizations/fuse_gated_delta_net.hpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,14 @@ class TRANSFORMATIONS_API RemoveConcatSliceAfterLoop : public ov::pass::MatcherP
2222
/**
2323
* @ingroup ov_transformation_common_api
2424
* @brief Fuses a loop-based gated delta net sub-graph into an internal GatedDeltaNet operation.
25+
*
26+
* Expected Loop body semantics per step:
27+
* 1) Decay recurrent state: `h_decay = h * exp(g)`
28+
* 2) Compute projection and delta: `delta = v - reduce_sum(h_decay * k, axis=-2)`
29+
* 3) Apply beta and update state: `h_new = h_decay + k * unsqueeze(delta * beta, axis=-2)`
30+
* 4) Compute per-step output: `o = reduce_sum(h_new * q, axis=-2, keep_dims=true)` and scatter it into the output at the current time index
31+
*
32+
* The matcher validates that the Loop body matches this shape and operation pattern before replacing the Loop with `GatedDeltaNet`.
2533
*/
2634

2735
class TRANSFORMATIONS_API FuseGDNLoop : public ov::pass::MatcherPass {

src/common/transformations/src/transformations/common_optimizations/fuse_gated_delta_net.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ bool matches_linear_attention_loop(const std::shared_ptr<ov::Node>& node) {
7171
auto step_index = any_input();
7272

7373
auto step_index_unsqueeze = wrap_type<v0::Unsqueeze>({step_index, 0});
74-
// auto recurrent_state_f32 = pattern::optional<v0::Convert>({recurrent_state});
7574
auto gate_f32 = pattern::optional<v0::Convert>({gate});
7675

7776
auto exp_gate = wrap_type<v0::Exp>({gate_f32});
@@ -169,8 +168,11 @@ ov::pass::RemoveConcatSliceAfterLoop::RemoveConcatSliceAfterLoop() {
169168
auto slice_attn = pattern::wrap_type<ov::op::v8::Slice>({concat_loop, {0}, any_input(), {1}, {0}});
170169
auto reshape_attn = pattern::wrap_type<v1::Reshape>({slice_attn, any_input()},
171170
pattern::shape_matches("[?, head_num, ?, v_head_size]"));
172-
auto slice_state = pattern::wrap_type<ov::op::v8::Slice>({concat_loop, any_input(), {LLONG_MAX}, {1}, {0}});
173-
auto reshape_state = pattern::wrap_type<v1::Reshape>({slice_state, any_input()});
171+
auto state_end = pattern::wrap_type<v0::Constant>(value_matches("LLONG_MAX") || value_matches("-1"));
172+
auto slice_state = pattern::wrap_type<ov::op::v8::Slice>({concat_loop, any_input(), state_end, {1}, {0}});
173+
auto reshape_state =
174+
pattern::wrap_type<v1::Reshape>({slice_state, any_input()},
175+
pattern::shape_matches("[?, head_num, k_head_size, v_head_size]"));
174176
matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](ov::pass::pattern::Matcher& m) {
175177
const auto& pattern_map = m.get_pattern_value_map();
176178
bool changed = false;

src/core/src/op/gated_delta_net.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ void GatedDeltaNet::validate_and_infer_types() {
104104
"The head size in key and query should be the same, but got ",
105105
k_head_size,
106106
" and ",
107-
v_head_size,
107+
q_head_size,
108108
".");
109109

110110
const auto gate_head_num = gate_ps[2];

src/plugins/intel_cpu/src/cpu_parallel.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class CpuParallel {
3838
: parallel_get_max_threads() * m_multiplier;
3939
return num;
4040
}
41-
[[nodiscard]] int get_num_worker_threads() const {
41+
[[nodiscard]] static int get_num_worker_threads() {
4242
return parallel_get_max_threads();
4343
}
4444
void activate() const {

src/plugins/intel_cpu/src/nodes/gated_delta_net.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "graph_context.h"
1515
#include "kernels/linear_attn/recurrent_linear_attn.hpp"
1616
#include "memory_desc/cpu_blocked_memory_desc.h"
17+
#include "memory_desc/cpu_memory_desc.h"
1718
#include "node.h"
1819
#include "onednn/iml_type_mapper.h"
1920
#include "openvino/core/except.hpp"

0 commit comments

Comments
 (0)