@@ -609,9 +609,6 @@ bool isQP8(const fb_xnnpack::XNNGraph* graph, const NodePtr node) {
609
609
auto cvt_output_id = graph_node->output_id ();
610
610
611
611
auto check_dtype = [graph](uint32_t id, DataType dtype) -> bool {
612
- assert (
613
- dtype == DataType::xnn_datatype_qdint8 ||
614
- dtype == DataType::xnn_datatype_qbint4);
615
612
for (auto value : *graph->xvalues ()) {
616
613
if (value->xvalue_union_type () !=
617
614
fb_xnnpack::XValueUnion::XNNQuantizedTensorValue) {
@@ -631,16 +628,23 @@ bool isQP8(const fb_xnnpack::XNNGraph* graph, const NodePtr node) {
631
628
return false ;
632
629
}
633
630
631
+ // XNNPACK dtypes which have qp8 support.
632
+ const std::vector<DataType> supported_filter_dtypes = {
633
+ DataType::xnn_datatype_qbint4,
634
+ DataType::xnn_datatype_qcint4,
635
+ DataType::xnn_datatype_qcint8};
636
+
634
637
// Find if the convert output is going to the right linear node.
635
638
// Assumption: if we can find one valid linear node, then we can use QP8
636
639
// for all the linear nodes consuming this convert output.
637
640
for (auto node : *graph->xnodes ()) {
638
641
if (node->xnode_union_type () == fb_xnnpack::XNodeUnion::XNNFullyConnected) {
639
642
auto linear_node = node->xnode_union_as_XNNFullyConnected ();
640
643
if (linear_node->input1_id () == cvt_output_id) {
641
- if (check_dtype (
642
- linear_node->filter_id (), DataType::xnn_datatype_qbint4)) {
643
- return true ;
644
+ for (auto supported_filter_dtype : supported_filter_dtypes) {
645
+ if (check_dtype (linear_node->filter_id (), supported_filter_dtype)) {
646
+ return true ;
647
+ }
644
648
}
645
649
}
646
650
}
0 commit comments