diff --git a/backends/xnnpack/runtime/XNNCompiler.cpp b/backends/xnnpack/runtime/XNNCompiler.cpp index b886bba2857..73578043f63 100644 --- a/backends/xnnpack/runtime/XNNCompiler.cpp +++ b/backends/xnnpack/runtime/XNNCompiler.cpp @@ -609,9 +609,6 @@ bool isQP8(const fb_xnnpack::XNNGraph* graph, const NodePtr node) { auto cvt_output_id = graph_node->output_id(); auto check_dtype = [graph](uint32_t id, DataType dtype) -> bool { - assert( - dtype == DataType::xnn_datatype_qdint8 || - dtype == DataType::xnn_datatype_qbint4); for (auto value : *graph->xvalues()) { if (value->xvalue_union_type() != fb_xnnpack::XValueUnion::XNNQuantizedTensorValue) { @@ -631,6 +628,10 @@ bool isQP8(const fb_xnnpack::XNNGraph* graph, const NodePtr node) { return false; } + std::vector supported_filter_dtypes{ + DataType::xnn_datatype_qcint8, + DataType::xnn_datatype_qcint4, + DataType::xnn_datatype_qbint4}; // Find if the convert output is going to the right linear node. // Assuming if we can find one valid linear node, then we can use QP8 // for all the linear nodes consuming this convert output. @@ -638,9 +639,10 @@ bool isQP8(const fb_xnnpack::XNNGraph* graph, const NodePtr node) { if (node->xnode_union_type() == fb_xnnpack::XNodeUnion::XNNFullyConnected) { auto linear_node = node->xnode_union_as_XNNFullyConnected(); if (linear_node->input1_id() == cvt_output_id) { - if (check_dtype( - linear_node->filter_id(), DataType::xnn_datatype_qbint4)) { - return true; + for (auto supported_filter : supported_filter_dtypes) { + if (check_dtype(linear_node->filter_id(), supported_filter)) { + return true; + } } } }