Skip to content

Commit f3df82e

Browse files
authored
[ARM plugin] Update transformation set (#496)
* Part 1 123 * Part 2 * Part 3
1 parent 69212bf commit f3df82e

File tree

4 files changed

+57
-41
lines changed

4 files changed

+57
-41
lines changed

modules/arm_plugin/src/arm_converter/arm_converter.cpp

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -64,9 +64,11 @@ Converter::Converter(const std::shared_ptr<const ov::Model> model, const Configu
6464
Register<opset::Clamp>();
6565
Register<opset::Sqrt>();
6666
Register<opset::Elu>();
67+
Register<ngraph::op::v0::Gelu>();
6768
Register<opset::Gelu>();
6869
Register<opset::ArmTranspose>();
6970
Register<opset::Softmax>();
71+
Register<opset::LogSoftmax>();
7072
Register<opset::ArmSplit>();
7173
Register<opset::LRN>();
7274
Register<opset::Minimum>();

modules/arm_plugin/src/arm_converter/arm_converter_activation.cpp

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -120,6 +120,11 @@ template<> Converter::Conversion::Ptr Converter::Convert(const opset::Gelu& node
120120
return ConvertActivation(node, info, this);
121121
}
122122

123+
// Conversion for the v0 Gelu operation (added by this commit).
// Builds an arm_compute::ActivationLayerInfo with the GELU activation function
// and hands the node to ConvertActivation, the same way the opset::Gelu
// specialization directly above does.
// NOTE(review): presumably v0 Gelu has no approximation-mode attribute, so a
// plain GELU activation is sufficient — confirm against the op specification.
template<> Converter::Conversion::Ptr Converter::Convert(const ngraph::op::v0::Gelu& node) {
124+
arm_compute::ActivationLayerInfo info(arm_compute::ActivationLayerInfo::ActivationFunction::GELU);
125+
return ConvertActivation(node, info, this);
126+
}
127+
123128
template<> Converter::Conversion::Ptr Converter::Convert(const opset::Swish& node) {
124129
float beta = 1.0;
125130
if (node.get_input_size() > 1 && ov::get_constant_from_source(node.input_value(1)) != nullptr) {

modules/arm_plugin/src/arm_converter/arm_converter_softmax.cpp

Lines changed: 13 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -2,23 +2,22 @@
22
// SPDX-License-Identifier: Apache-2.0
33

44
#include <arm_compute/runtime/NEON/functions/NESoftmaxLayer.h>
5-
#include <ngraph/runtime/reference/softmax.hpp>
65
#include "arm_converter/arm_converter.hpp"
76

87
namespace ArmPlugin {
98
// Conversion for opset::Softmax, shown here as a diff.
// Lines marked '-' are the old body: an `if (true)` wrapper around the
// NESoftmaxLayer path plus an `else` branch that dispatched to
// ngraph::runtime::reference::softmax through CallSwitch. Because the condition
// was the literal `true`, that else branch could never execute.
// Lines marked '+' are the new body: only the NESoftmaxLayer conversion remains,
// taking input 0, output 0, the constant 1.0f and the node's axis passed through
// AxisCast together with the shape rank, then cast to int32_t.
// NOTE(review): 1.0f is presumably the softmax beta (scaling) factor — confirm
// against the NESoftmaxLayer configure signature.
template<> Converter::Conversion::Ptr Converter::Convert(const opset::Softmax& node) {
10-
if (true) {
11-
return MakeConversion<arm_compute::NESoftmaxLayer>(node.input(0),
12-
node.output(0),
13-
1.0f,
14-
static_cast<int32_t>(AxisCast(node.get_axis(), node.get_shape().size())));
15-
} else {
16-
auto make = [&] (auto refFunction) {
17-
return this->MakeConversion(refFunction, node.input(0), node.output(0), node.get_shape(), ngraph::AxisSet{node.get_axis()});
18-
};
19-
return CallSwitch(
20-
AP_WRAP(make, ngraph::runtime::reference::softmax),
21-
node.input(0), floatTypes);
22-
}
9+
return MakeConversion<arm_compute::NESoftmaxLayer>(node.input(0),
10+
node.output(0),
11+
1.0f,
12+
static_cast<int32_t>(AxisCast(node.get_axis(), node.get_shape().size())));
13+
}
14+
15+
// Conversion for opset::LogSoftmax (added by this commit).
// Reads the node's axis and, when it is negative, wraps it by adding the rank
// of the node's shape so that it counts from the front. The wrapped axis is then
// passed through AxisCast together with the rank, cast to int32_t, and used to
// build an NELogSoftmaxLayer conversion over input 0 and output 0.
// NOTE(review): `axis += node.get_shape().size()` appears to mix a signed axis
// with an unsigned size; the result looks correct via modular arithmetic, but an
// explicit cast of the size to the axis type would make that intent clear — confirm.
// NOTE(review): 1.0f is presumably the beta (scaling) factor — confirm against
// the NELogSoftmaxLayer configure signature.
template<> Converter::Conversion::Ptr Converter::Convert(const opset::LogSoftmax& node) {
16+
auto axis = node.get_axis();
17+
if (axis < 0) { axis += node.get_shape().size(); }
18+
return MakeConversion<arm_compute::NELogSoftmaxLayer>(node.input(0),
19+
node.output(0),
20+
1.0f,
21+
static_cast<int32_t>(AxisCast(axis, node.get_shape().size())));
2322
}
2423
} // namespace ArmPlugin

modules/arm_plugin/src/transformations/arm_optimizations.cpp

Lines changed: 37 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -3,13 +3,10 @@
33

44

55
#include "transformations/common_optimizations/nop_elimination.hpp"
6-
#include "transformations/common_optimizations/conv_mul_fusion.hpp"
76
#include "transformations/convert_precision.hpp"
87
#include "transformations/init_node_info.hpp"
98
#include "transformations/decompose_variadic_split.hpp"
109
#include "transformations/common_optimizations/softplus_fusion.hpp"
11-
#include "transformations/op_conversions/convert_mod.hpp"
12-
#include "transformations/op_conversions/convert_negative.hpp"
1310
#include "transformations/op_conversions/convert_reduce_to_pooling.hpp"
1411
#include "transformations/op_conversions/convert_broadcast3.hpp"
1512
#include "transformations/op_conversions/convert_broadcast_to_tiles.hpp"
@@ -18,22 +15,27 @@
1815
#include "transformations/op_conversions/rnn_cell_decomposition.hpp"
1916
#include "transformations/op_conversions/lstm_cell_decomposition.hpp"
2017
#include "transformations/op_conversions/gru_cell_decomposition.hpp"
21-
#include "transformations/common_optimizations/lin_op_sequence_fusion.hpp"
22-
#include "transformations/op_conversions/reduce_l1_decomposition.hpp"
23-
#include "transformations/op_conversions/reduce_l2_decomposition.hpp"
2418
#include "transformations/op_conversions/log_softmax_decomposition.hpp"
2519
#include "transformations/common_optimizations/remove_filtering_boxes_by_size.hpp"
2620
#include "transformations/common_optimizations/hswish_fusion.hpp"
27-
#include "transformations/op_conversions/convert_interpolate1_to_interpolate4.hpp"
2821
#include "transformations/op_conversions/convert_mvn1_to_mvn6.hpp"
2922
#include "transformations/op_conversions/convert_gelu.hpp"
3023
#include "transformations/op_conversions/convert_ti_to_sequences.hpp"
31-
#include "transformations/common_optimizations/weights_dequantize_to_fake_quantize.hpp"
3224
#include "transformations/common_optimizations/convert_quantize_dequantize.hpp"
3325
#include "transformations/op_conversions/convert_subtract.hpp"
3426
#include "transformations/op_conversions/convert_maxpool_downgrade.hpp"
3527
#include "transformations/op_conversions/convert_previous_nms_to_nms_9.hpp"
3628
#include "transformations/common_optimizations/common_optimizations.hpp"
29+
#include "transformations/common_optimizations/convert_compression_only_to_legacy.hpp"
30+
#include "transformations/op_conversions/hswish_decomposition.hpp"
31+
#include "transformations/op_conversions/convert_minimum_to_power_and_max.hpp"
32+
#include "transformations/op_conversions/convert_divide.hpp"
33+
#include "transformations/op_conversions/convert_depth_to_space.hpp"
34+
#include "transformations/op_conversions/convert_space_to_depth.hpp"
35+
#include "transformations/op_conversions/batch_norm_decomposition.hpp"
36+
#include "transformations/op_conversions/mvn6_decomposition.hpp"
37+
#include <transformations/op_conversions/normalize_l2_decomposition.hpp>
38+
#include <transformations/op_conversions/softmax_decomposition.hpp>
3739

3840
#include "conv_bias_fusion.hpp"
3941
#include "convert_eltwise.hpp"
@@ -155,7 +157,7 @@ bool ArmPlugin::pass::ArmOptimizations::run_on_model(const std::shared_ptr<ov::M
155157
// This pass must be called first in pipeline
156158
manager.register_pass<ov::pass::InitNodeInfo>();
157159
manager.register_pass<pass::StoreResultName>();
158-
manager.register_pass<ov::pass::CommonOptimizations>();
160+
159161
// Resolves dynamism (replaces NonZero); ConstantFolding (CF) is needed
160162
manager.register_pass<ov::pass::GraphRewrite>()->add_matcher<ov::pass::RemoveFilteringBoxesBySize>();
161163
manager.register_pass<ngraph::pass::ConstantFolding>();
@@ -167,29 +169,41 @@ bool ArmPlugin::pass::ArmOptimizations::run_on_model(const std::shared_ptr<ov::M
167169
manager.register_pass<ov::pass::GraphRewrite>()->add_matcher<ov::pass::HSwishFusion>();
168170

169171
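// NOTE(review): LinOpSequenceFusion must be executed after all decompositions; it is no longer registered explicitly here and presumably runs inside CommonOptimizations — confirm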
170-
manager.register_pass<ov::pass::GraphRewrite>()->add_matcher<ov::pass::LinOpSequenceFusion>();
172+
manager.register_pass<ngraph::pass::ConstantFolding>();
173+
manager.register_pass<ov::pass::GraphRewrite>()->add_matcher<ov::pass::ConvertTensorIteratorToGRUSequence>();
174+
manager.register_pass<ov::pass::GraphRewrite>()->add_matcher<ov::pass::ConvertTensorIteratorToLSTMSequence>();
175+
manager.register_pass<ov::pass::GraphRewrite>()->add_matcher<ov::pass::ConvertTensorIteratorToRNNSequence>();
176+
manager.register_pass<ngraph::pass::ConstantFolding>();
171177
manager.register_pass<ov::pass::GraphRewrite>()->add_matcher<ov::pass::RNNCellDecomposition>();
172178
manager.register_pass<ov::pass::GraphRewrite>()->add_matcher<ov::pass::LSTMCellDecomposition>();
173179
manager.register_pass<ov::pass::GraphRewrite>()->add_matcher<ov::pass::GRUCellDecomposition>();
174-
manager.register_pass<ov::pass::GraphRewrite>()->add_matcher<ov::pass::ConvertGELU>();
180+
181+
// Run common optimizations
182+
manager.register_pass<ov::pass::CommonOptimizations>();
183+
manager.get_pass_config()->disable<ov::pass::ConvertCompressedOnlyToLegacy>();
184+
manager.get_pass_config()->disable<ov::pass::HSwishDecomposition>();
185+
manager.get_pass_config()->disable<ov::pass::LogSoftmaxDecomposition>();
186+
#ifdef __aarch64__
187+
manager.get_pass_config()->disable<ov::pass::ConvertGELU>();
188+
#endif /* __aarch64__ */
189+
manager.get_pass_config()->disable<ov::pass::ConvertBroadcastToTiles>();
190+
manager.get_pass_config()->disable<ov::pass::ConvertMinimum>();
191+
manager.get_pass_config()->disable<ov::pass::ConvertSubtract>();
192+
manager.get_pass_config()->disable<ov::pass::ConvertDivide>();
193+
manager.get_pass_config()->disable<ov::pass::ConvertDepthToSpace>();
194+
manager.get_pass_config()->disable<ov::pass::ConvertSpaceToDepth>();
195+
manager.get_pass_config()->disable<ov::pass::BatchNormDecomposition>();
196+
// MVN6Decomposition doesn't work with ARM native ReduceMean operation
197+
manager.get_pass_config()->disable<ov::pass::MVN6Decomposition>();
198+
manager.get_pass_config()->disable<ov::pass::NormalizeL2Decomposition>();
199+
manager.get_pass_config()->disable<ov::pass::SoftmaxDecomposition>();
200+
175201
manager.register_pass<ngraph::pass::ConstantFolding>();
176202
manager.register_pass<ov::pass::GraphRewrite>()->add_matcher<pass::ConvertConv1D>();
177203
manager.register_pass<ov::pass::GraphRewrite>()->add_matcher<pass::ConvertGroupConv1D>();
178204
manager.register_pass<ov::pass::GraphRewrite>()->add_matcher<pass::ConvertGroupConvolution>();
179205
manager.register_pass<ngraph::pass::ConstantFolding>();
180-
manager.register_pass<ov::pass::GraphRewrite>()->add_matcher<ov::pass::ConvolutionMultiplyFusion>();
181-
manager.register_pass<ov::pass::GraphRewrite>()->add_matcher<ov::pass::GroupConvolutionMultiplyFusion>();
182-
manager.register_pass<ov::pass::GraphRewrite>()->add_matcher<ov::pass::ConvolutionBackpropDataMultiplyFusion>();
183-
manager.register_pass<ov::pass::GraphRewrite>()->add_matcher<ov::pass::GroupConvolutionBackpropDataMultiplyFusion>();
184-
manager.register_pass<ov::pass::GraphRewrite>()->add_matcher<ov::pass::ConvertTensorIteratorToGRUSequence>();
185-
manager.register_pass<ov::pass::GraphRewrite>()->add_matcher<ov::pass::ConvertTensorIteratorToLSTMSequence>();
186-
manager.register_pass<ov::pass::GraphRewrite>()->add_matcher<ov::pass::ConvertTensorIteratorToRNNSequence>();
187-
manager.register_pass<ngraph::pass::ConstantFolding>();
188-
189-
190-
manager.register_pass<ov::pass::GraphRewrite>()->add_matcher<ov::pass::ConvertInterpolate1ToInterpolate4>();
191206
manager.register_pass<ov::pass::GraphRewrite>()->add_matcher<ov::pass::ConvertMVN1ToMVN6>();
192-
manager.register_pass<ov::pass::GraphRewrite>()->add_matcher<ov::pass::ConvertQuantizeDequantize>();
193207
#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
194208
manager.register_pass<ov::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);
195209
#endif
@@ -249,18 +263,14 @@ bool ArmPlugin::pass::ArmOptimizations::run_on_model(const std::shared_ptr<ov::M
249263
{
250264
Dump(m, "before_arm_specific_transformations");
251265
ov::pass::Manager manager;
252-
manager.register_pass<ov::pass::GraphRewrite>()->add_matcher<ov::pass::LogSoftmaxDecomposition>();
253266
manager.register_pass<ov::pass::GraphRewrite>()->add_matcher<pass::ConvertGRN>();
254267
manager.register_pass<ov::pass::GraphRewrite>()->add_matcher<pass::NormalizeL2Fusion>();
255268
manager.register_pass<ov::pass::GraphRewrite>()->add_matcher<pass::DecomposeNormalizeL2Add>();
256269
manager.register_pass<ov::pass::GraphRewrite>()->add_matcher<pass::ConvertNormalizeL2ToArm>();
257270
manager.register_pass<ov::pass::GraphRewrite>()->add_matcher<pass::ConvertReduceMultiAxis>();
258-
manager.register_pass<ov::pass::GraphRewrite>()->add_matcher<ov::pass::ReduceL1Decomposition>();
259-
manager.register_pass<ov::pass::GraphRewrite>()->add_matcher<ov::pass::ReduceL2Decomposition>();
260271
manager.register_pass<ov::pass::GraphRewrite>()->add_matcher<ov::pass::ConvertReduceMeanToPooling>();
261272
manager.register_pass<ov::pass::GraphRewrite>()->add_matcher<ov::pass::ConvertReduceMaxToPooling>();
262273
manager.register_pass<ov::pass::GraphRewrite>()->add_matcher<ov::pass::ConvertReduceSumToPooling>();
263-
manager.register_pass<ov::pass::GraphRewrite>()->add_matcher<ov::pass::ConvertMod>();
264274
manager.register_pass<ngraph::pass::ConstantFolding>();
265275
manager.register_pass<ov::pass::GraphRewrite>()->add_matcher<pass::DecomposeMish>();
266276
manager.register_pass<ov::pass::GraphRewrite>()->add_matcher<pass::BroadcastPRelu>();

0 commit comments

Comments
 (0)