@@ -118,10 +118,9 @@ xnn_datatype getDataType(const DataType& data_type) {
118118 return xnn_datatype::xnn_datatype_qcint32;
119119 case DataType::xnn_datatype_qcint4:
120120 return xnn_datatype::xnn_datatype_qcint4;
121- case DataType::xnn_datatype_qdint8:
122- return xnn_datatype::xnn_datatype_qdint8;
123121 case DataType::xnn_datatype_qbint4:
124122 return xnn_datatype::xnn_datatype_qbint4;
123+ case DataType::xnn_datatype_qdint8: // always try to use kleidi
125124 case DataType::xnn_datatype_qpint8:
126125 return xnn_datatype::xnn_datatype_qpint8;
127126 case DataType::xnn_datatype_int32:
@@ -600,54 +599,6 @@ Error defineTensor(
600599 return Error::Ok;
601600};
602601
603- #define MAYBE_UNUSED (x ) (void )(x)
604-
605- #ifdef ENABLE_XNNPACK_KLEIDI
606- bool isQP8 (const fb_xnnpack::XNNGraph* graph, const NodePtr node) { // Returns true when this Convert node's output is qdint8 and feeds a FullyConnected node whose filter is qbint4.
607- assert (node->xnode_union_type () == fb_xnnpack::XNodeUnion::XNNConvert);
608- auto graph_node = node->xnode_union_as_XNNConvert ();
609- auto cvt_output_id = graph_node->output_id ();
610-
611- auto check_dtype = [graph](uint32_t id, DataType dtype) -> bool { // True iff the first quantized tensor value whose id_out equals `id` has datatype `dtype`; false when no such value exists.
612- assert (
613- dtype == DataType::xnn_datatype_qdint8 ||
614- dtype == DataType::xnn_datatype_qbint4);
615- for (auto value : *graph->xvalues ()) {
616- if (value->xvalue_union_type () !=
617- fb_xnnpack::XValueUnion::XNNQuantizedTensorValue) {
618- continue ; // only quantized tensor values are inspected
619- }
620- auto tensor =
621- value->xvalue_union_as_XNNQuantizedTensorValue ()->tensor_value ();
622- if (tensor->id_out () == id) {
623- return tensor->datatype () == dtype;
624- }
625- }
626- return false ;
627- };
628-
629- // Check if the convert output tensor is qdint8, else bail early.
630- if (!check_dtype (cvt_output_id, DataType::xnn_datatype_qdint8)) {
631- return false ;
632- }
633-
634- // Find if the convert output is going to the right linear node.
635- // Assuming if we can find one valid linear node, then we can use QP8
636- // for all the linear nodes consuming this convert output.
637- for (auto node : *graph->xnodes ()) { // NOTE: this loop variable shadows the `node` parameter.
638- if (node->xnode_union_type () == fb_xnnpack::XNodeUnion::XNNFullyConnected) {
639- auto linear_node = node->xnode_union_as_XNNFullyConnected ();
640- if (linear_node->input1_id () == cvt_output_id) {
641- if (check_dtype (
642- linear_node->filter_id (), DataType::xnn_datatype_qbint4)) {
643- return true ;
644- }
645- }
646- }
647- }
648- return false ; // no FullyConnected node with a qbint4 filter consumes the convert output
649- }
650- #endif // ENABLE_XNNPACK_KLEIDI
651602
652603/*
653604Define Convert operator Node into the subgraph
@@ -661,19 +612,6 @@ Error defineConvertNode(
661612 auto graph_node = node->xnode_union_as_XNNConvert ();
662613
663614 int32_t flags = graph_node->flags ();
664- #ifdef ENABLE_XNNPACK_KLEIDI
665- // This is not currently exposed in include/xnnpack.h yet; once it is,
666- // we can remove this runtime logic and do this ahead-of-time.
667- #define XNN_FLAG_MAYBE_PACK_FOR_QB4W_GEMM 0x00000100 ;
668- if (isQP8 (flatbuffer_graph, node)) {
669- flags |= XNN_FLAG_MAYBE_PACK_FOR_QB4W_GEMM;
670- ET_LOG (
671- Debug,
672- " Setting XNN_FLAG_MAYBE_PACK_FOR_QB4W_GEMM flag for convert node %i" ,
673- node->debug_handle ());
674- }
675- #endif
676-
677615 xnn_status status = xnn_define_convert (
678616 subgraph_ptr,
679617 remapped_ids.at (graph_node->input_id ()),
0 commit comments