@@ -610,6 +610,9 @@ bool isQP8(const fb_xnnpack::XNNGraph* graph, const NodePtr node) {
610610 auto cvt_output_id = graph_node->output_id ();
611611
612612 auto check_dtype = [graph](uint32_t id, DataType dtype) -> bool {
613+ assert (
614+ dtype == DataType::xnn_datatype_qdint8 ||
615+ dtype == DataType::xnn_datatype_qbint4);
613616 for (auto value : *graph->xvalues ()) {
614617 if (value->xvalue_union_type () !=
615618 fb_xnnpack::XValueUnion::XNNQuantizedTensorValue) {
@@ -629,23 +632,16 @@ bool isQP8(const fb_xnnpack::XNNGraph* graph, const NodePtr node) {
629632 return false ;
630633 }
631634
632- // XNNPACK dtypes which have qp8 support.
633- const std::vector<DataType> supported_filter_dtypes = {
634- DataType::xnn_datatype_qbint4,
635- DataType::xnn_datatype_qcint4,
636- DataType::xnn_datatype_qcint8};
637-
638635 // Check whether the convert output feeds a suitable linear node.
639636 // Assumption: if at least one valid linear node is found, QP8 can be
640637 // used for all the linear nodes consuming this convert output.
641638 for (auto node : *graph->xnodes ()) {
642639 if (node->xnode_union_type () == fb_xnnpack::XNodeUnion::XNNFullyConnected) {
643640 auto linear_node = node->xnode_union_as_XNNFullyConnected ();
644641 if (linear_node->input1_id () == cvt_output_id) {
645- for (auto supported_filter_dtype : supported_filter_dtypes) {
646- if (check_dtype (linear_node->filter_id (), supported_filter_dtype)) {
647- return true ;
648- }
642+ if (check_dtype (
643+ linear_node->filter_id (), DataType::xnn_datatype_qbint4)) {
644+ return true ;
649645 }
650646 }
651647 }
0 commit comments