Skip to content

Commit 6311e35

Browse files
authored
[ARM plugin] Fix PRelu broadcast (#493)
1 parent f3df82e commit 6311e35

File tree

4 files changed

+8
-60
lines changed

4 files changed

+8
-60
lines changed

modules/arm_plugin/src/transformations/arm_optimizations.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "transformations/init_node_info.hpp"
88
#include "transformations/decompose_variadic_split.hpp"
99
#include "transformations/common_optimizations/softplus_fusion.hpp"
10+
#include "transformations/common_optimizations/reshape_prelu.hpp"
1011
#include "transformations/op_conversions/convert_reduce_to_pooling.hpp"
1112
#include "transformations/op_conversions/convert_broadcast3.hpp"
1213
#include "transformations/op_conversions/convert_broadcast_to_tiles.hpp"
@@ -59,7 +60,6 @@
5960
#include "convert_shuffle_channels.hpp"
6061
#include "convert_tile_to_concats.hpp"
6162
#include "convert_transpose_arm.hpp"
62-
#include "convert_prelu.hpp"
6363
#include "convert_gather_arm.hpp"
6464
#include "convert_mvn_arm.hpp"
6565
#include "convert_reduce_multi_axis.hpp"
@@ -180,6 +180,7 @@ bool ArmPlugin::pass::ArmOptimizations::run_on_model(const std::shared_ptr<ov::M
180180

181181
// Run common optimizations
182182
manager.register_pass<ov::pass::CommonOptimizations>();
183+
manager.register_pass<ov::pass::ReshapePRelu>();
183184
manager.get_pass_config()->disable<ov::pass::ConvertCompressedOnlyToLegacy>();
184185
manager.get_pass_config()->disable<ov::pass::HSwishDecomposition>();
185186
manager.get_pass_config()->disable<ov::pass::LogSoftmaxDecomposition>();
@@ -273,7 +274,6 @@ bool ArmPlugin::pass::ArmOptimizations::run_on_model(const std::shared_ptr<ov::M
273274
manager.register_pass<ov::pass::GraphRewrite>()->add_matcher<ov::pass::ConvertReduceSumToPooling>();
274275
manager.register_pass<ngraph::pass::ConstantFolding>();
275276
manager.register_pass<ov::pass::GraphRewrite>()->add_matcher<pass::DecomposeMish>();
276-
manager.register_pass<ov::pass::GraphRewrite>()->add_matcher<pass::BroadcastPRelu>();
277277
manager.register_pass<ov::pass::GraphRewrite>()->add_matcher<pass::ConvertLogical>();
278278
manager.register_pass<ov::pass::GraphRewrite>()->add_matcher<pass::ConvertComparison>();
279279
manager.register_pass<ov::pass::GraphRewrite>()->add_matcher<pass::ConvertTranspose>();

modules/arm_plugin/src/transformations/convert_prelu.cpp

Lines changed: 0 additions & 41 deletions
This file was deleted.

modules/arm_plugin/src/transformations/convert_prelu.hpp

Lines changed: 0 additions & 17 deletions
This file was deleted.

modules/arm_plugin/tests/functional/shared_tests_instances/single_layer_tests/activation.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,12 @@ std::map<std::vector<size_t>, std::vector<std::vector<size_t>>> basic = {
6464
std::map<std::vector<size_t>, std::vector<std::vector<size_t>>> preluBasic = {
6565
{{1, 50}, {{1}, {50}}},
6666
{{1, 128}, {{1}, {128}}},
67+
68+
// Broadcast check
69+
{{3, 2}, {{1}, {2}, {3, 2}}},
70+
{{3, 2, 5}, {{1}, {2}, {5}, {2, 5}, {3, 1, 5}, {1, 2, 1}, {1, 1, 5}, {3, 1, 1}, {3, 2, 5}}},
71+
{{2, 1, 2}, {{2}, {2, 1, 1}}},
72+
{{3, 2, 5, 7}, {{1}, {7}, {2}, {5, 7}, {2, 5, 7}, {2, 1, 1}, {1, 2, 1, 1}, {3, 2, 1, 1}, {3, 2, 5, 7}}},
6773
};
6874

6975
const auto basicCases = ::testing::Combine(

0 commit comments

Comments
 (0)