Skip to content

Commit 09ad09d

Browse files
committed
[KleidiAI] Always attempt activation packing
ghstack-source-id: c33a625 ghstack-comment-id: 3169012590 Pull Request resolved: #13232
1 parent 0c1acb3 commit 09ad09d

File tree

1 file changed

+4
-62
lines changed

1 file changed

+4
-62
lines changed

backends/xnnpack/runtime/XNNCompiler.cpp

Lines changed: 4 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,12 @@ xnn_datatype getDataType(const DataType& data_type) {
118118
return xnn_datatype::xnn_datatype_qcint32;
119119
case DataType::xnn_datatype_qcint4:
120120
return xnn_datatype::xnn_datatype_qcint4;
121-
case DataType::xnn_datatype_qdint8:
122-
return xnn_datatype::xnn_datatype_qdint8;
123121
case DataType::xnn_datatype_qbint4:
124122
return xnn_datatype::xnn_datatype_qbint4;
123+
case DataType::xnn_datatype_qdint8:
124+
#if !defined(ENABLE_XNNPACK_KLEIDI) || ENABLE_XNNPACK_KLEIDI == 0
125+
return xnn_datatype::xnn_datatype_qdint8;
126+
#endif
125127
case DataType::xnn_datatype_qpint8:
126128
return xnn_datatype::xnn_datatype_qpint8;
127129
case DataType::xnn_datatype_int32:
@@ -602,53 +604,6 @@ Error defineTensor(
602604

603605
#define MAYBE_UNUSED(x) (void)(x)
604606

605-
#ifdef ENABLE_XNNPACK_KLEIDI
606-
bool isQP8(const fb_xnnpack::XNNGraph* graph, const NodePtr node) {
607-
assert(node->xnode_union_type() == fb_xnnpack::XNodeUnion::XNNConvert);
608-
auto graph_node = node->xnode_union_as_XNNConvert();
609-
auto cvt_output_id = graph_node->output_id();
610-
611-
auto check_dtype = [graph](uint32_t id, DataType dtype) -> bool {
612-
assert(
613-
dtype == DataType::xnn_datatype_qdint8 ||
614-
dtype == DataType::xnn_datatype_qbint4);
615-
for (auto value : *graph->xvalues()) {
616-
if (value->xvalue_union_type() !=
617-
fb_xnnpack::XValueUnion::XNNQuantizedTensorValue) {
618-
continue;
619-
}
620-
auto tensor =
621-
value->xvalue_union_as_XNNQuantizedTensorValue()->tensor_value();
622-
if (tensor->id_out() == id) {
623-
return tensor->datatype() == dtype;
624-
}
625-
}
626-
return false;
627-
};
628-
629-
// Check if the output tensor is qint8 else bail early.
630-
if (!check_dtype(cvt_output_id, DataType::xnn_datatype_qdint8)) {
631-
return false;
632-
}
633-
634-
// Find if the convert output is going to the right linear node.
635-
// Assuming if we can find one valid linear node, then we can use QP8
636-
// for all the linear nodes consuming this convert output.
637-
for (auto node : *graph->xnodes()) {
638-
if (node->xnode_union_type() == fb_xnnpack::XNodeUnion::XNNFullyConnected) {
639-
auto linear_node = node->xnode_union_as_XNNFullyConnected();
640-
if (linear_node->input1_id() == cvt_output_id) {
641-
if (check_dtype(
642-
linear_node->filter_id(), DataType::xnn_datatype_qbint4)) {
643-
return true;
644-
}
645-
}
646-
}
647-
}
648-
return false;
649-
}
650-
#endif // ENABLE_XNNPACK_KLEIDI
651-
652607
/*
653608
Define Convert operator Node into the subgraph
654609
*/
@@ -661,19 +616,6 @@ Error defineConvertNode(
661616
auto graph_node = node->xnode_union_as_XNNConvert();
662617

663618
int32_t flags = graph_node->flags();
664-
#ifdef ENABLE_XNNPACK_KLEIDI
665-
// This is not currently exposed at include/xnnpack.h yet once it is
666-
// we can remove this runtime logic and do this ahead-of-time
667-
#define XNN_FLAG_MAYBE_PACK_FOR_QB4W_GEMM 0x00000100;
668-
if (isQP8(flatbuffer_graph, node)) {
669-
flags |= XNN_FLAG_MAYBE_PACK_FOR_QB4W_GEMM;
670-
ET_LOG(
671-
Debug,
672-
"Setting XNN_FLAG_MAYBE_PACK_FOR_QB4W_GEMM flag for convert node %i",
673-
node->debug_handle());
674-
}
675-
#endif
676-
677619
xnn_status status = xnn_define_convert(
678620
subgraph_ptr,
679621
remapped_ids.at(graph_node->input_id()),

0 commit comments

Comments
 (0)