Skip to content

Commit 55feb6c

Browse files
authored
[ARM CPU plugin] Ensure mathematical correctness of converting floating point to boolean (#497)
1 parent 6311e35 commit 55feb6c

File tree

1 file changed

+41
-6
lines changed

1 file changed

+41
-6
lines changed

modules/arm_plugin/src/transformations/arm_optimizations.cpp

Lines changed: 41 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -142,6 +142,29 @@ void ArmPlugin::pass::ArmOptimizations::Dump(const std::shared_ptr<ov::Model>& m
142142
}
143143
}
144144

145+
/// Precision-fusing callback for Convert nodes; it is registered in the `type_to_fuse_map` that is
/// handed to ov::pass::ConvertPrecision.
///
/// @param node Candidate node; anything that is not an ArmPlugin::opset::Convert is left untouched.
/// @param to   Replacement precision for the Convert's destination type.
/// @param idx  Not used by this callback; kept because the callback signature requires it.
/// @return true when `node` is a Convert (its destination type was retargeted or the node was
///         replaced), false otherwise.
static bool fuse_type_to_convert(const std::shared_ptr<ngraph::Node>& node, ov::element::Type to, size_t idx) {
    const auto convert = ov::as_type_ptr<ArmPlugin::opset::Convert>(node);
    if (!convert) {
        return false;
    }

    const bool real_to_boolean = convert->input(0).get_element_type().is_real() &&
                                 convert->get_convert_element_type() == ngraph::element::boolean &&
                                 to.is_integral_number();
    if (!real_to_boolean) {
        // Plain case: only the destination type of the existing Convert has to change.
        convert->set_convert_element_type(to);
        return true;
    }

    // Converting floating point to boolean must map every non-zero value to 1, but the boolean output
    // is being replaced by an integral type here, and a plain float -> integer Convert would turn
    // e.g. 0.01 into 0. Abs + Ceiling lifts every non-zero magnitude to at least 1, and Clamp(0, 1)
    // then pins the result to exactly 0 or 1, so that large inputs (e.g. 2.5 -> 3, or values beyond
    // the range of the integral type) still yield a canonical boolean after the final Convert.
    //
    // The Output handle is forwarded as-is (instead of going through get_node_shared_ptr()) so that
    // the producer's output index is preserved when the Convert reads from a multi-output node.
    const auto data = convert->input_value(0);
    const auto abs = std::make_shared<ArmPlugin::opset::Abs>(data);
    const auto ceil = std::make_shared<ArmPlugin::opset::Ceiling>(abs);
    const auto clamp = std::make_shared<ArmPlugin::opset::Clamp>(ceil, 0.0, 1.0);
    const auto new_convert = std::make_shared<ArmPlugin::opset::Convert>(clamp, to);

    new_convert->set_friendly_name(convert->get_friendly_name());
    ov::copy_runtime_info(convert, {abs, ceil, clamp, new_convert});
    ov::replace_node(convert, new_convert);
    return true;
}
167+
145168
bool ArmPlugin::pass::ArmOptimizations::run_on_model(const std::shared_ptr<ov::Model> &m) {
146169
auto quantized = _lpt && ngraph::pass::low_precision::LowPrecision::isFunctionQuantized(m);
147170
{
@@ -205,9 +228,6 @@ bool ArmPlugin::pass::ArmOptimizations::run_on_model(const std::shared_ptr<ov::M
205228
manager.register_pass<ov::pass::GraphRewrite>()->add_matcher<pass::ConvertGroupConvolution>();
206229
manager.register_pass<ngraph::pass::ConstantFolding>();
207230
manager.register_pass<ov::pass::GraphRewrite>()->add_matcher<ov::pass::ConvertMVN1ToMVN6>();
208-
#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
209-
manager.register_pass<ov::pass::ConvertPrecision>(ngraph::element::f16, ngraph::element::f32);
210-
#endif
211231

212232
auto pass_config = manager.get_pass_config();
213233

@@ -261,6 +281,23 @@ bool ArmPlugin::pass::ArmOptimizations::run_on_model(const std::shared_ptr<ov::M
261281
lptManager.run_passes(m);
262282
}
263283

284+
auto get_convert_precisions = []() {
285+
precisions_array array = {
286+
{ngraph::element::i64, ngraph::element::i32},
287+
{ngraph::element::u64, ngraph::element::i32},
288+
{ngraph::element::f64, ngraph::element::f32},
289+
{ngraph::element::boolean, ngraph::element::u8},
290+
{ngraph::element::i4, ngraph::element::i8},
291+
{ngraph::element::u4, ngraph::element::u8}
292+
};
293+
#ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
294+
array.push_back({ngraph::element::f16, ngraph::element::f32});
295+
#endif
296+
return array;
297+
};
298+
static const auto precisions = get_convert_precisions();
299+
type_to_fuse_map type_to_fuse = {{ArmPlugin::opset::Convert::get_type_info_static(), fuse_type_to_convert}};
300+
264301
{
265302
Dump(m, "before_arm_specific_transformations");
266303
ov::pass::Manager manager;
@@ -317,9 +354,7 @@ bool ArmPlugin::pass::ArmOptimizations::run_on_model(const std::shared_ptr<ov::M
317354
manager.register_pass<pass::FinalizeTrailingNodes>();
318355
manager.register_pass<pass::StoreResultName>();
319356
manager.register_pass<ngraph::pass::ConstantFolding>();
320-
manager.register_pass<ov::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::u8);
321-
manager.register_pass<ov::pass::ConvertPrecision>(ngraph::element::i64, ngraph::element::i32);
322-
manager.register_pass<ov::pass::ConvertPrecision>(ngraph::element::u64, ngraph::element::i32);
357+
manager.register_pass<ov::pass::ConvertPrecision>(precisions, type_to_fuse);
323358
manager.register_pass<ov::pass::GraphRewrite>()->add_matcher<pass::AlignNodePrecision>();
324359
manager.register_pass<pass::ConvertPrecisionFP16ToFP32>();
325360
manager.register_pass<ov::pass::GraphRewrite>()->add_matcher<pass::ConvertArmConvert>();

0 commit comments

Comments
 (0)