@@ -142,6 +142,29 @@ void ArmPlugin::pass::ArmOptimizations::Dump(const std::shared_ptr<ov::Model>& m
142142 }
143143}
144144
145+ static bool fuse_type_to_convert (const std::shared_ptr<ngraph::Node>& node, ov::element::Type to, size_t idx) {
146+ if (auto convert = ov::as_type_ptr<ArmPlugin::opset::Convert>(node)) {
147+ // For Convert node, converting precision from floating point to boolean will lead to mathematical
148+ // error, because here the output precision boolean is replaced by u8. E.g. floating point value 0.01
149+ // is converted to be 1 for boolean, but 0 for u8. Thus an Abs and Ceil node should be added before the
150+ // Convert node for this scenario.
151+ if (convert->input (0 ).get_element_type ().is_real () &&
152+ convert->get_convert_element_type () == ngraph::element::boolean && to.is_integral_number ()) {
153+ auto abs = std::make_shared<ArmPlugin::opset::Abs>(convert->input_value (0 ).get_node_shared_ptr ());
154+ auto ceil = std::make_shared<ArmPlugin::opset::Ceiling>(abs);
155+ auto new_convert = std::make_shared<ArmPlugin::opset::Convert>(ceil, to);
156+ new_convert->set_friendly_name (convert->get_friendly_name ());
157+ ov::copy_runtime_info (convert, {abs, ceil, new_convert});
158+ ov::replace_node (convert, new_convert);
159+ return true ;
160+ } else {
161+ convert->set_convert_element_type (to);
162+ return true ;
163+ }
164+ }
165+ return false ;
166+ }
167+
145168bool ArmPlugin::pass::ArmOptimizations::run_on_model (const std::shared_ptr<ov::Model> &m) {
146169 auto quantized = _lpt && ngraph::pass::low_precision::LowPrecision::isFunctionQuantized (m);
147170 {
@@ -205,9 +228,6 @@ bool ArmPlugin::pass::ArmOptimizations::run_on_model(const std::shared_ptr<ov::M
205228 manager.register_pass <ov::pass::GraphRewrite>()->add_matcher <pass::ConvertGroupConvolution>();
206229 manager.register_pass <ngraph::pass::ConstantFolding>();
207230 manager.register_pass <ov::pass::GraphRewrite>()->add_matcher <ov::pass::ConvertMVN1ToMVN6>();
208- #ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
209- manager.register_pass <ov::pass::ConvertPrecision>(ngraph::element::f16 , ngraph::element::f32 );
210- #endif
211231
212232 auto pass_config = manager.get_pass_config ();
213233
@@ -261,6 +281,23 @@ bool ArmPlugin::pass::ArmOptimizations::run_on_model(const std::shared_ptr<ov::M
261281 lptManager.run_passes (m);
262282 }
263283
284+ auto get_convert_precisions = []() {
285+ precisions_array array = {
286+ {ngraph::element::i64 , ngraph::element::i32 },
287+ {ngraph::element::u64 , ngraph::element::i32 },
288+ {ngraph::element::f64 , ngraph::element::f32 },
289+ {ngraph::element::boolean, ngraph::element::u8 },
290+ {ngraph::element::i4, ngraph::element::i8 },
291+ {ngraph::element::u4, ngraph::element::u8 }
292+ };
293+ #ifndef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
294+ array.push_back ({ngraph::element::f16 , ngraph::element::f32 });
295+ #endif
296+ return array;
297+ };
298+ static const auto precisions = get_convert_precisions ();
299+ type_to_fuse_map type_to_fuse = {{ArmPlugin::opset::Convert::get_type_info_static (), fuse_type_to_convert}};
300+
264301 {
265302 Dump (m, " before_arm_specific_transformations" );
266303 ov::pass::Manager manager;
@@ -317,9 +354,7 @@ bool ArmPlugin::pass::ArmOptimizations::run_on_model(const std::shared_ptr<ov::M
317354 manager.register_pass <pass::FinalizeTrailingNodes>();
318355 manager.register_pass <pass::StoreResultName>();
319356 manager.register_pass <ngraph::pass::ConstantFolding>();
320- manager.register_pass <ov::pass::ConvertPrecision>(ngraph::element::boolean, ngraph::element::u8 );
321- manager.register_pass <ov::pass::ConvertPrecision>(ngraph::element::i64 , ngraph::element::i32 );
322- manager.register_pass <ov::pass::ConvertPrecision>(ngraph::element::u64 , ngraph::element::i32 );
357+ manager.register_pass <ov::pass::ConvertPrecision>(precisions, type_to_fuse);
323358 manager.register_pass <ov::pass::GraphRewrite>()->add_matcher <pass::AlignNodePrecision>();
324359 manager.register_pass <pass::ConvertPrecisionFP16ToFP32>();
325360 manager.register_pass <ov::pass::GraphRewrite>()->add_matcher <pass::ConvertArmConvert>();
0 commit comments