|
| 1 | +// Copyright (C) 2018-2025 Intel Corporation |
| 2 | +// SPDX-License-Identifier: Apache-2.0 |
| 3 | +// |
| 4 | + |
| 5 | +#include <gtest/gtest.h> |
| 6 | + |
| 7 | +#include "openvino/op/result.hpp" |
| 8 | +#include "snippets/op/scalar.hpp" |
| 9 | +#include "snippets/op/reshape.hpp" |
| 10 | +#include "snippets/op/online_softmax.hpp" |
| 11 | +#include "openvino/op/parameter.hpp" |
| 12 | + |
| 13 | +#include "snippets/lowered/linear_ir.hpp" |
| 14 | +#include "snippets/lowered/port_descriptor.hpp" |
| 15 | + |
| 16 | +#include "lir_test_utils.hpp" |
| 17 | + |
| 18 | +namespace ov { |
| 19 | +namespace test { |
| 20 | +namespace snippets { |
| 21 | + |
| 22 | +using namespace ov::snippets; |
| 23 | + |
| 24 | +TEST(LinearIRReplaceWithNode, PreservesPerOutputDescriptors) { |
| 25 | + const auto param = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 6}); |
| 26 | + const auto op = std::make_shared<ov::snippets::op::OnlineSoftmax>(param); |
| 27 | + const auto result_0 = std::make_shared<ov::op::v0::Result>(op->output(0)); |
| 28 | + const auto result_1 = std::make_shared<ov::op::v0::Result>(op->output(1)); |
| 29 | + const auto model = std::make_shared<ov::Model>(ov::OutputVector{result_0, result_1}, ov::ParameterVector{param}); |
| 30 | + |
| 31 | + auto factory = std::make_shared<ov::snippets::IShapeInferSnippetsFactory>(); |
| 32 | + ov::snippets::lowered::LinearIR linear_ir(model, factory); |
| 33 | + |
| 34 | + const auto& online_sm_expr = linear_ir.get_expr_by_node(op); |
| 35 | + ASSERT_NE(nullptr, online_sm_expr); |
| 36 | + |
| 37 | + const std::vector<VectorDims> subtensors = {get_default_subtensor(2), VectorDims{1, 3}, VectorDims{1, 6}}; |
| 38 | + const std::vector<VectorDims> layouts = {VectorDims{}, VectorDims{0, 1}, VectorDims{1, 0}}; |
| 39 | + init_expr_descriptors(online_sm_expr, subtensors, layouts); |
| 40 | + |
| 41 | + const auto expected_desc_0 = online_sm_expr->get_output_port_descriptor(0)->clone(); |
| 42 | + const auto expected_desc_1 = online_sm_expr->get_output_port_descriptor(1)->clone(); |
| 43 | + |
| 44 | + ASSERT_NE(expected_desc_0->get_subtensor(), expected_desc_1->get_subtensor()); |
| 45 | + ASSERT_NE(expected_desc_0->get_layout(), expected_desc_1->get_layout()); |
| 46 | + |
| 47 | + const auto new_node = std::make_shared<ov::snippets::op::OnlineSoftmax>(param); |
| 48 | + linear_ir.replace_with_node({online_sm_expr}, new_node); |
| 49 | + |
| 50 | + const auto& new_expr = linear_ir.get_expr_by_node(new_node); |
| 51 | + ASSERT_NE(nullptr, new_expr); |
| 52 | + |
| 53 | + const auto& new_desc_0 = new_expr->get_output_port_descriptor(0); |
| 54 | + const auto& new_desc_1 = new_expr->get_output_port_descriptor(1); |
| 55 | + |
| 56 | + EXPECT_EQ(new_desc_0->get_subtensor(), expected_desc_0->get_subtensor()); |
| 57 | + EXPECT_EQ(new_desc_1->get_subtensor(), expected_desc_1->get_subtensor()); |
| 58 | + EXPECT_EQ(new_desc_0->get_layout(), expected_desc_0->get_layout()); |
| 59 | + EXPECT_EQ(new_desc_1->get_layout(), expected_desc_1->get_layout()); |
| 60 | +} |
| 61 | + |
| 62 | +} // namespace snippets |
| 63 | +} // namespace test |
| 64 | +} // namespace ov |
0 commit comments