@@ -118,10 +118,12 @@ xnn_datatype getDataType(const DataType& data_type) {
118118 return xnn_datatype::xnn_datatype_qcint32;
119119 case DataType::xnn_datatype_qcint4:
120120 return xnn_datatype::xnn_datatype_qcint4;
121- case DataType::xnn_datatype_qdint8:
122- return xnn_datatype::xnn_datatype_qdint8;
123121 case DataType::xnn_datatype_qbint4:
124122 return xnn_datatype::xnn_datatype_qbint4;
123+ case DataType::xnn_datatype_qdint8:
124+ #if !defined(ENABLE_XNNPACK_KLEIDI) || ENABLE_XNNPACK_KLEIDI == 0
125+ return xnn_datatype::xnn_datatype_qdint8;
126+ #endif
125127 case DataType::xnn_datatype_qpint8:
126128 return xnn_datatype::xnn_datatype_qpint8;
127129 case DataType::xnn_datatype_int32:
@@ -609,9 +611,6 @@ bool isQP8(const fb_xnnpack::XNNGraph* graph, const NodePtr node) {
609611 auto cvt_output_id = graph_node->output_id ();
610612
611613 auto check_dtype = [graph](uint32_t id, DataType dtype) -> bool {
612- assert (
613- dtype == DataType::xnn_datatype_qdint8 ||
614- dtype == DataType::xnn_datatype_qbint4);
615614 for (auto value : *graph->xvalues ()) {
616615 if (value->xvalue_union_type () !=
617616 fb_xnnpack::XValueUnion::XNNQuantizedTensorValue) {
@@ -631,16 +630,21 @@ bool isQP8(const fb_xnnpack::XNNGraph* graph, const NodePtr node) {
631630 return false ;
632631 }
633632
633+ std::vector<DataType> supported_filter_dtypes{
634+ DataType::xnn_datatype_qcint8,
635+ DataType::xnn_datatype_qcint4,
636+ DataType::xnn_datatype_qbint4};
634637 // Find if the convert output is going to the right linear node.
635638 // Assuming if we can find one valid linear node, then we can use QP8
636639 // for all the linear nodes consuming this convert output.
637640 for (auto node : *graph->xnodes ()) {
638641 if (node->xnode_union_type () == fb_xnnpack::XNodeUnion::XNNFullyConnected) {
639642 auto linear_node = node->xnode_union_as_XNNFullyConnected ();
640643 if (linear_node->input1_id () == cvt_output_id) {
641- if (check_dtype (
642- linear_node->filter_id (), DataType::xnn_datatype_qbint4)) {
643- return true ;
644+ for (auto supported_filter : supported_filter_dtypes) {
645+ if (check_dtype (linear_node->filter_id (), supported_filter)) {
646+ return true ;
647+ }
644648 }
645649 }
646650 }
@@ -661,19 +665,6 @@ Error defineConvertNode(
661665 auto graph_node = node->xnode_union_as_XNNConvert ();
662666
663667 int32_t flags = graph_node->flags ();
664- #ifdef ENABLE_XNNPACK_KLEIDI
665- // This is not currently exposed at include/xnnpack.h yet once it is
666- // we can remove this runtime logic and do this ahead-of-time
667- #define XNN_FLAG_MAYBE_PACK_FOR_QB4W_GEMM 0x00000100 ;
668- if (isQP8 (flatbuffer_graph, node)) {
669- flags |= XNN_FLAG_MAYBE_PACK_FOR_QB4W_GEMM;
670- ET_LOG (
671- Debug,
672- " Setting XNN_FLAG_MAYBE_PACK_FOR_QB4W_GEMM flag for convert node %i" ,
673- node->debug_handle ());
674- }
675- #endif
676-
677668 xnn_status status = xnn_define_convert (
678669 subgraph_ptr,
679670 remapped_ids.at (graph_node->input_id ()),
0 commit comments