Skip to content

Commit 5b680ae

Browse files
authored
[Snippets][CPU] Fix LoopManager::get_common_outer_loops indices, remove extra function overloads (#32291)
### Details: Initialize common loop prefix counter with the first expression’s loop depth so shared outer loops are preserved when scanning multiple expressions ### Tickets: - N/A
1 parent 13d6f45 commit 5b680ae

File tree

5 files changed

+65
-20
lines changed

5 files changed

+65
-20
lines changed

src/common/snippets/include/snippets/lowered/loop_manager.hpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,6 @@ class LoopManager {
6565
* @return vector of outer loop IDs
6666
*/
6767
static std::vector<size_t> get_outer_expr_loops(const ExpressionPtr& expr, size_t loop_id);
68-
/**
69-
* @brief Get loop IDs of expression that are common outer (upper) than `loop_id`
70-
* @param lhs the first expression
71-
* @param rhs the second expression
72-
* @return vector of common outer loop IDs
73-
*/
74-
static std::vector<size_t> get_common_outer_loops(const ExpressionPtr& lhs, const ExpressionPtr& rhs);
7568
/**
7669
* @brief Get common outer loop IDs of expression set
7770
* @param exprs vector of expressions

src/common/snippets/src/lowered/loop_manager.cpp

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -67,16 +67,6 @@ std::vector<size_t> LoopManager::get_outer_expr_loops(const ExpressionPtr& expr,
6767
return {loop_ids.cbegin(), it};
6868
}
6969

70-
std::vector<size_t> LoopManager::get_common_outer_loops(const ExpressionPtr& lhs, const ExpressionPtr& rhs) {
71-
const auto& rhs_ids = rhs->get_loop_ids();
72-
const auto& lhs_ids = lhs->get_loop_ids();
73-
size_t idx = 0;
74-
while (idx < std::min(rhs_ids.size(), lhs_ids.size()) && rhs_ids[idx] == lhs_ids[idx]) {
75-
idx++;
76-
}
77-
return {rhs_ids.cbegin(), rhs_ids.cbegin() + idx};
78-
}
79-
8070
std::vector<size_t> LoopManager::get_common_outer_loops(const std::vector<ExpressionPtr>& exprs) {
8171
OPENVINO_ASSERT(!exprs.empty(), "Failed to find common outer loops for set of expressions: there no expressions");
8272

@@ -89,7 +79,7 @@ std::vector<size_t> LoopManager::get_common_outer_loops(const std::vector<Expres
8979
};
9080

9181
const auto& first_loop_ids = exprs.front()->get_loop_ids();
92-
size_t common_idx = 0;
82+
size_t common_idx = first_loop_ids.size();
9383
for (size_t i = 1; i < exprs.size(); ++i) {
9484
common_idx = std::min(common_idx, get_first_diff_id_idx(first_loop_ids, exprs[i]->get_loop_ids()));
9585
}
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
// Copyright (C) 2018-2025 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "common_test_utils/ov_test_utils.hpp"
6+
7+
#include <gmock/gmock.h>
8+
9+
#include "snippets/lowered/expression.hpp"
10+
#include "snippets/lowered/loop_manager.hpp"
11+
12+
namespace ov {
13+
namespace test {
14+
namespace snippets {
15+
namespace {
16+
17+
using ov::snippets::lowered::Expression;
18+
using ov::snippets::lowered::ExpressionPtr;
19+
using ov::snippets::lowered::LoopManager;
20+
using ::testing::ElementsAre;
21+
using ::testing::IsEmpty;
22+
23+
ExpressionPtr make_expression_with_loops(const std::vector<size_t>& loop_ids) {
24+
auto expr = std::make_shared<Expression>();
25+
expr->set_loop_ids(loop_ids);
26+
return expr;
27+
}
28+
29+
TEST(LoopManagerTests, SingleExpressionKeepsAllLoops) {
30+
auto expr = make_expression_with_loops({1, 2, 3});
31+
32+
LoopManager manager;
33+
const auto loops = manager.get_common_outer_loops({expr});
34+
35+
EXPECT_THAT(loops, ElementsAre(1, 2, 3));
36+
}
37+
38+
TEST(LoopManagerTests, MultipleExpressionsShrinkToCommonPrefix) {
39+
auto expr0 = make_expression_with_loops({0, 1, 2});
40+
auto expr1 = make_expression_with_loops({0, 1, 3});
41+
auto expr2 = make_expression_with_loops({0, 4});
42+
43+
LoopManager manager;
44+
const auto loops = manager.get_common_outer_loops({expr0, expr1, expr2});
45+
46+
EXPECT_THAT(loops, ElementsAre(0));
47+
}
48+
49+
TEST(LoopManagerTests, ExpressionsWithoutCommonLoops) {
50+
auto expr0 = make_expression_with_loops({5, 6});
51+
auto expr1 = make_expression_with_loops({7, 8});
52+
53+
LoopManager manager;
54+
const auto loops = manager.get_common_outer_loops({expr0, expr1});
55+
56+
EXPECT_THAT(loops, IsEmpty());
57+
}
58+
59+
} // namespace
60+
} // namespace snippets
61+
} // namespace test
62+
} // namespace ov

src/plugins/intel_cpu/src/transformations/snippets/aarch64/pass/lowered/insert_gemm_copy_buffers.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ bool InsertGemmCopyBuffers::run(LinearIR& linear_ir, LinearIR::constExprIt begin
4040
BufferExpressionPtr buffer_expr =
4141
factory->build<ov::intel_cpu::aarch64::RepackedWeightsBufferExpression>(buffer_op, {copy_b_out});
4242
linear_ir.insert_expr(buffer_expr,
43-
LoopManager::get_common_outer_loops(copy_b_expr, copy_b_consumers.begin()->get_expr()),
43+
LoopManager::get_common_outer_loops({copy_b_expr, copy_b_consumers.begin()->get_expr()}),
4444
true,
4545
insertion_pos,
4646
{copy_b_consumers});

src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/insert_brgemm_copy_buffers.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ bool InsertBrgemmCopyBuffers::run(LinearIR& linear_ir, LinearIR::constExprIt beg
5151
}
5252
return linear_ir.insert_expr(
5353
buffer_expr,
54-
LoopManager::get_common_outer_loops(copy_b_expr, copy_b_consumers.begin()->get_expr()),
54+
LoopManager::get_common_outer_loops({copy_b_expr, copy_b_consumers.begin()->get_expr()}),
5555
true,
5656
insertion_pos,
5757
{copy_b_consumers});

0 commit comments

Comments
 (0)