Skip to content

Commit ae1f96d

Browse files
authored
[ARM plugin] Activation fixes (#490)
1 parent fa868c8 commit ae1f96d

File tree

3 files changed

+45
-7
lines changed

3 files changed

+45
-7
lines changed

modules/arm_plugin/src/arm_converter/arm_converter_activation.cpp

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,9 @@
66
#include <arm_compute/runtime/NEON/functions/NEElementwiseUnaryLayer.h>
77
#include <arm_compute/runtime/NEON/functions/NEFloor.h>
88
#include <arm_compute/runtime/NEON/functions/NEPReluLayer.h>
9+
#include <ngraph/runtime/reference/abs.hpp>
10+
#include <ngraph/runtime/reference/clamp.hpp>
11+
#include <ngraph/runtime/reference/floor.hpp>
912
#include <ngraph/runtime/reference/hsigmoid.hpp>
1013
#include <ngraph/runtime/reference/hard_sigmoid.hpp>
1114
#include <ngraph/runtime/reference/selu.hpp>
@@ -40,13 +43,34 @@ template<> Converter::Conversion::Ptr Converter::Convert(const opset::PRelu& nod
4043
}
4144

4245
template<> Converter::Conversion::Ptr Converter::Convert(const opset::Abs& node) {
43-
arm_compute::ActivationLayerInfo info(arm_compute::ActivationLayerInfo::ActivationFunction::ABS);
44-
return ConvertActivation(node, info, this);
46+
if (node.input(0).get_element_type() == ngraph::element::f32 ||
47+
node.input(0).get_element_type() == ngraph::element::f16) {
48+
arm_compute::ActivationLayerInfo info(arm_compute::ActivationLayerInfo::ActivationFunction::ABS);
49+
return ConvertActivation(node, info, this);
50+
} else {
51+
auto make = [&] (auto refFunction) {
52+
return this->MakeConversion(refFunction, node.input(0), node.output(0), ngraph::shape_size(node.get_output_shape(0)));
53+
};
54+
return CallSwitch(
55+
AP_WRAP(make, ngraph::runtime::reference::abs),
56+
node.input(0), intTypes);
57+
}
4558
}
4659

4760
template<> Converter::Conversion::Ptr Converter::Convert(const opset::Clamp& node) {
48-
arm_compute::ActivationLayerInfo info(arm_compute::ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, node.get_max(), node.get_min());
49-
return ConvertActivation(node, info, this);
61+
if (node.input(0).get_element_type() == ngraph::element::f32 ||
62+
node.input(0).get_element_type() == ngraph::element::f16) {
63+
arm_compute::ActivationLayerInfo info(arm_compute::ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, node.get_max(), node.get_min());
64+
return ConvertActivation(node, info, this);
65+
} else {
66+
auto make = [&] (auto refFunction) {
67+
return this->MakeConversion(refFunction, node.input(0), node.output(0),
68+
static_cast<std::int32_t>(node.get_min()), static_cast<std::int32_t>(node.get_max()), ngraph::shape_size(node.get_input_shape(0)));
69+
};
70+
return CallSwitch(
71+
AP_WRAP(make, ngraph::runtime::reference::clamp),
72+
node.input(0), std::tuple<std::int32_t>{});
73+
}
5074
}
5175

5276
template<> Converter::Conversion::Ptr Converter::Convert(const opset::Sqrt& node) {
@@ -68,7 +92,17 @@ template<> Converter::Conversion::Ptr Converter::Convert(const opset::Exp& node)
6892
}
6993

7094
template<> Converter::Conversion::Ptr Converter::Convert(const opset::Floor& node) {
71-
return MakeConversion<arm_compute::NEFloor>(node.input(0), node.output(0));
95+
if (node.input(0).get_element_type() == ngraph::element::f32 ||
96+
node.input(0).get_element_type() == ngraph::element::f16) {
97+
return MakeConversion<arm_compute::NEFloor>(node.input(0), node.output(0));
98+
} else {
99+
auto make = [&] (auto refFunction) {
100+
return this->MakeConversion(refFunction, node.input(0), node.output(0), ngraph::shape_size(node.get_output_shape(0)));
101+
};
102+
return CallSwitch(
103+
AP_WRAP(make, ngraph::runtime::reference::floor),
104+
node.input(0), allTypes);
105+
}
72106
}
73107

74108
template<> Converter::Conversion::Ptr Converter::Convert(const opset::HSwish& node) {
@@ -88,7 +122,7 @@ template<> Converter::Conversion::Ptr Converter::Convert(const opset::Gelu& node
88122

89123
template<> Converter::Conversion::Ptr Converter::Convert(const opset::Swish& node) {
90124
float beta = 1.0;
91-
if (ov::get_constant_from_source(node.input_value(1)) != nullptr) {
125+
if (node.get_input_size() > 1 && ov::get_constant_from_source(node.input_value(1)) != nullptr) {
92126
beta = ov::get_constant_from_source(node.input_value(1))->cast_vector<float>()[0];
93127
}
94128
arm_compute::ActivationLayerInfo info(arm_compute::ActivationLayerInfo::ActivationFunction::SWISH, beta);

modules/arm_plugin/src/arm_converter/arm_converter_convert.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,8 @@ template <> Converter::Conversion::Ptr Converter::Convert(const opset::Convert&
131131
return make(ngraph::runtime::reference::convert<float, std::uint16_t>);
132132
case ngraph::element::Type_t::u32 :
133133
return make(ngraph::runtime::reference::convert<float, std::uint16_t>);
134+
case ngraph::element::Type_t::i8 :
135+
return make(ngraph::runtime::reference::convert<float, std::int8_t>);
134136
case ngraph::element::Type_t::i16 :
135137
return make(ngraph::runtime::reference::convert<float, std::int16_t>);
136138
default:

modules/arm_plugin/src/arm_converter/arm_converter_pool.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,9 @@ static void FillLayerInfo(const Pool& node, arm_compute::PoolingLayerInfo& pool_
3636
}
3737

3838
template<> Converter::Conversion::Ptr Converter::Convert(const opset::MaxPool& node) {
39-
if (node.get_input_shape(0).size() == 4) {
39+
if (node.get_input_shape(0).size() == 4 &&
40+
(node.input(0).get_element_type() == ngraph::element::f32 ||
41+
node.input(0).get_element_type() == ngraph::element::f16)) {
4042
arm_compute::PoolingLayerInfo pool_info;
4143
FillLayerInfo(node, pool_info);
4244
pool_info.pool_type = arm_compute::PoolingType::MAX;

0 commit comments

Comments
 (0)