@@ -609,9 +609,6 @@ bool isQP8(const fb_xnnpack::XNNGraph* graph, const NodePtr node) {
609609 auto cvt_output_id = graph_node->output_id ();
610610
611611 auto check_dtype = [graph](uint32_t id, DataType dtype) -> bool {
612- assert (
613- dtype == DataType::xnn_datatype_qdint8 ||
614- dtype == DataType::xnn_datatype_qbint4);
615612 for (auto value : *graph->xvalues ()) {
616613 if (value->xvalue_union_type () !=
617614 fb_xnnpack::XValueUnion::XNNQuantizedTensorValue) {
@@ -631,16 +628,24 @@ bool isQP8(const fb_xnnpack::XNNGraph* graph, const NodePtr node) {
631628 return false ;
632629 }
633630
631+ // XNNPACK dtypes which have qp8 support.
632+ const std::vector<DataType> supported_filter_dtypes = {
633+ DataType::xnn_datatype_qbint4,
634+ DataType::xnn_datatype_qcint4,
635+ DataType::xnn_datatype_qcint8
636+ };
637+
634638 // Find if the convert output is going to the right linear node.
635639 // Assuming if we can find one valid linear node, then we can use QP8
636640 // for all the linear nodes consuming this convert output.
637641 for (auto node : *graph->xnodes ()) {
638642 if (node->xnode_union_type () == fb_xnnpack::XNodeUnion::XNNFullyConnected) {
639643 auto linear_node = node->xnode_union_as_XNNFullyConnected ();
640644 if (linear_node->input1_id () == cvt_output_id) {
641- if (check_dtype (
642- linear_node->filter_id (), DataType::xnn_datatype_qbint4)) {
643- return true ;
645+ for (auto supported_filter_dtype : supported_filter_dtypes) {
646+ if (check_dtype (linear_node->filter_id (), supported_filter_dtype)) {
647+ return true ;
648+ }
644649 }
645650 }
646651 }
0 commit comments