Skip to content

Commit 27c5aef

Browse files
committed
[KleidiAI] Always attempt activation packing
ghstack-source-id: 52488f3 ghstack-comment-id: 3169012590 Pull Request resolved: #13232
1 parent c2ec0cd commit 27c5aef

File tree

1 file changed

+19
-18
lines changed

1 file changed

+19
-18
lines changed

backends/xnnpack/runtime/XNNCompiler.cpp

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -609,9 +609,6 @@ bool isQP8(const fb_xnnpack::XNNGraph* graph, const NodePtr node) {
609609
auto cvt_output_id = graph_node->output_id();
610610

611611
auto check_dtype = [graph](uint32_t id, DataType dtype) -> bool {
612-
assert(
613-
dtype == DataType::xnn_datatype_qdint8 ||
614-
dtype == DataType::xnn_datatype_qbint4);
615612
for (auto value : *graph->xvalues()) {
616613
if (value->xvalue_union_type() !=
617614
fb_xnnpack::XValueUnion::XNNQuantizedTensorValue) {
@@ -631,16 +628,21 @@ bool isQP8(const fb_xnnpack::XNNGraph* graph, const NodePtr node) {
631628
return false;
632629
}
633630

631+
std::vector<DataType> supported_filter_dtypes{
632+
DataType::xnn_datatype_qcint8,
633+
DataType::xnn_datatype_qcint4,
634+
DataType::xnn_datatype_qbint4};
634635
// Find if the convert output is going to the right linear node.
635636
// Assuming if we can find one valid linear node, then we can use QP8
636637
// for all the linear nodes consuming this convert output.
637638
for (auto node : *graph->xnodes()) {
638639
if (node->xnode_union_type() == fb_xnnpack::XNodeUnion::XNNFullyConnected) {
639640
auto linear_node = node->xnode_union_as_XNNFullyConnected();
640641
if (linear_node->input1_id() == cvt_output_id) {
641-
if (check_dtype(
642-
linear_node->filter_id(), DataType::xnn_datatype_qbint4)) {
643-
return true;
642+
for (auto supported_filter : supported_filter_dtypes) {
643+
if (check_dtype(linear_node->filter_id(), supported_filter)) {
644+
return true;
645+
}
644646
}
645647
}
646648
}
@@ -659,21 +661,20 @@ Error defineConvertNode(
659661
const fb_xnnpack::XNNGraph* flatbuffer_graph) noexcept {
660662
MAYBE_UNUSED(flatbuffer_graph);
661663
auto graph_node = node->xnode_union_as_XNNConvert();
662-
663-
int32_t flags = graph_node->flags();
664664
#ifdef ENABLE_XNNPACK_KLEIDI
665-
// This is not currently exposed at include/xnnpack.h yet once it is
666-
// we can remove this runtime logic and do this ahead-of-time
667-
#define XNN_FLAG_MAYBE_PACK_FOR_QB4W_GEMM 0x00000100;
668-
if (isQP8(flatbuffer_graph, node)) {
669-
flags |= XNN_FLAG_MAYBE_PACK_FOR_QB4W_GEMM;
670-
ET_LOG(
671-
Debug,
672-
"Setting XNN_FLAG_MAYBE_PACK_FOR_QB4W_GEMM flag for convert node %i",
673-
node->debug_handle());
674-
}
665+
// This is not currently exposed at include/xnnpack.h yet once it is
666+
// we can remove this runtime logic and do this ahead-of-time
667+
#define XNN_FLAG_MAYBE_PACK_FOR_QB4W_GEMM 0x00000100;
668+
if (isQP8(flatbuffer_graph, node)) {
669+
flags |= XNN_FLAG_MAYBE_PACK_FOR_QB4W_GEMM;
670+
ET_LOG(
671+
Debug,
672+
"Setting XNN_FLAG_MAYBE_PACK_FOR_QB4W_GEMM flag for convert node %i",
673+
node->debug_handle());
674+
}
675675
#endif
676676

677+
int32_t flags = graph_node->flags();
677678
xnn_status status = xnn_define_convert(
678679
subgraph_ptr,
679680
remapped_ids.at(graph_node->input_id()),

0 commit comments

Comments
 (0)