Skip to content

Commit 4085562

Browse files
authored
[Snippets] Fix port descriptor index in LinearIR::replace_with_node (#32292)
### Details: Ensure LinearIR::replace_with_node clones correct port descriptor instead of reusing the descriptor under the index 0 unconditionally ### Tickets: - N/A
1 parent 93f9e80 commit 4085562

File tree

2 files changed

+65
-1
lines changed

2 files changed

+65
-1
lines changed

src/common/snippets/src/lowered/linear_ir.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,7 @@ LinearIR::exprIt LinearIR::replace_with_node(const std::vector<ExpressionPtr>& o
438438
for (size_t i = 0; i < new_node->get_output_size(); ++i) {
439439
snippets::lowered::PortDescriptorUtils::set_port_descriptor_ptr(
440440
new_node->output(i),
441-
last_old_expr->get_output_port_descriptor(0)->clone());
441+
last_old_expr->get_output_port_descriptor(i)->clone());
442442
}
443443

444444
const auto new_expr = create_expression(new_node, new_inputs, loop_ids, false);
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
// Copyright (C) 2018-2025 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include <gtest/gtest.h>
6+
7+
#include "openvino/op/result.hpp"
8+
#include "snippets/op/scalar.hpp"
9+
#include "snippets/op/reshape.hpp"
10+
#include "snippets/op/online_softmax.hpp"
11+
#include "openvino/op/parameter.hpp"
12+
13+
#include "snippets/lowered/linear_ir.hpp"
14+
#include "snippets/lowered/port_descriptor.hpp"
15+
16+
#include "lir_test_utils.hpp"
17+
18+
namespace ov {
19+
namespace test {
20+
namespace snippets {
21+
22+
using namespace ov::snippets;
23+
24+
TEST(LinearIRReplaceWithNode, PreservesPerOutputDescriptors) {
25+
const auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 6});
26+
const auto op = std::make_shared<ov::snippets::op::OnlineSoftmax>(param);
27+
const auto result_0 = std::make_shared<ov::op::v0::Result>(op->output(0));
28+
const auto result_1 = std::make_shared<ov::op::v0::Result>(op->output(1));
29+
const auto model = std::make_shared<ov::Model>(ov::OutputVector{result_0, result_1}, ov::ParameterVector{param});
30+
31+
auto factory = std::make_shared<ov::snippets::IShapeInferSnippetsFactory>();
32+
ov::snippets::lowered::LinearIR linear_ir(model, factory);
33+
34+
const auto& online_sm_expr = linear_ir.get_expr_by_node(op);
35+
ASSERT_NE(nullptr, online_sm_expr);
36+
37+
const std::vector<VectorDims> subtensors = {get_default_subtensor(2), VectorDims{1, 3}, VectorDims{1, 6}};
38+
const std::vector<VectorDims> layouts = {VectorDims{}, VectorDims{0, 1}, VectorDims{1, 0}};
39+
init_expr_descriptors(online_sm_expr, subtensors, layouts);
40+
41+
const auto expected_desc_0 = online_sm_expr->get_output_port_descriptor(0)->clone();
42+
const auto expected_desc_1 = online_sm_expr->get_output_port_descriptor(1)->clone();
43+
44+
ASSERT_NE(expected_desc_0->get_subtensor(), expected_desc_1->get_subtensor());
45+
ASSERT_NE(expected_desc_0->get_layout(), expected_desc_1->get_layout());
46+
47+
const auto new_node = std::make_shared<ov::snippets::op::OnlineSoftmax>(param);
48+
linear_ir.replace_with_node({online_sm_expr}, new_node);
49+
50+
const auto& new_expr = linear_ir.get_expr_by_node(new_node);
51+
ASSERT_NE(nullptr, new_expr);
52+
53+
const auto& new_desc_0 = new_expr->get_output_port_descriptor(0);
54+
const auto& new_desc_1 = new_expr->get_output_port_descriptor(1);
55+
56+
EXPECT_EQ(new_desc_0->get_subtensor(), expected_desc_0->get_subtensor());
57+
EXPECT_EQ(new_desc_1->get_subtensor(), expected_desc_1->get_subtensor());
58+
EXPECT_EQ(new_desc_0->get_layout(), expected_desc_0->get_layout());
59+
EXPECT_EQ(new_desc_1->get_layout(), expected_desc_1->get_layout());
60+
}
61+
62+
} // namespace snippets
63+
} // namespace test
64+
} // namespace ov

0 commit comments

Comments
 (0)