Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions backends/xnnpack/runtime/XNNCompiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -609,9 +609,6 @@ bool isQP8(const fb_xnnpack::XNNGraph* graph, const NodePtr node) {
auto cvt_output_id = graph_node->output_id();

auto check_dtype = [graph](uint32_t id, DataType dtype) -> bool {
assert(
dtype == DataType::xnn_datatype_qdint8 ||
dtype == DataType::xnn_datatype_qbint4);
for (auto value : *graph->xvalues()) {
if (value->xvalue_union_type() !=
fb_xnnpack::XValueUnion::XNNQuantizedTensorValue) {
Expand All @@ -631,16 +628,21 @@ bool isQP8(const fb_xnnpack::XNNGraph* graph, const NodePtr node) {
return false;
}

std::vector<DataType> supported_filter_dtypes{
DataType::xnn_datatype_qcint8,
DataType::xnn_datatype_qcint4,
DataType::xnn_datatype_qbint4};
// Find if the convert output is going to the right linear node.
// Assuming if we can find one valid linear node, then we can use QP8
// for all the linear nodes consuming this convert output.
for (auto node : *graph->xnodes()) {
if (node->xnode_union_type() == fb_xnnpack::XNodeUnion::XNNFullyConnected) {
auto linear_node = node->xnode_union_as_XNNFullyConnected();
if (linear_node->input1_id() == cvt_output_id) {
if (check_dtype(
linear_node->filter_id(), DataType::xnn_datatype_qbint4)) {
return true;
for (auto supported_filter : supported_filter_dtypes) {
if (check_dtype(linear_node->filter_id(), supported_filter)) {
return true;
}
}
}
}
Expand Down
Loading