Skip to content

Commit 932818c

Browse files
authored
XNNPACK: Kleidi QP8 and SME2 (#13887)
1 parent ad19cb8 commit 932818c

File tree

2 files changed

+14
-7
lines changed

2 files changed

+14
-7
lines changed

backends/xnnpack/cmake/Dependencies.cmake

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,10 @@ set(XNNPACK_ENABLE_AVX512VNNIGFNI
4242
OFF
4343
CACHE BOOL ""
4444
)
45-
45+
set(XNNPACK_ENABLE_ARM_SME2
46+
ON
47+
CACHE BOOL ""
48+
)
4649
if(EXECUTORCH_XNNPACK_ENABLE_KLEIDI)
4750
set(XNNPACK_ENABLE_KLEIDIAI
4851
ON

backends/xnnpack/runtime/XNNCompiler.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -609,9 +609,6 @@ bool isQP8(const fb_xnnpack::XNNGraph* graph, const NodePtr node) {
609609
auto cvt_output_id = graph_node->output_id();
610610

611611
auto check_dtype = [graph](uint32_t id, DataType dtype) -> bool {
612-
assert(
613-
dtype == DataType::xnn_datatype_qdint8 ||
614-
dtype == DataType::xnn_datatype_qbint4);
615612
for (auto value : *graph->xvalues()) {
616613
if (value->xvalue_union_type() !=
617614
fb_xnnpack::XValueUnion::XNNQuantizedTensorValue) {
@@ -631,16 +628,23 @@ bool isQP8(const fb_xnnpack::XNNGraph* graph, const NodePtr node) {
631628
return false;
632629
}
633630

631+
// XNNPACK dtypes which have qp8 support.
632+
const std::vector<DataType> supported_filter_dtypes = {
633+
DataType::xnn_datatype_qbint4,
634+
DataType::xnn_datatype_qcint4,
635+
DataType::xnn_datatype_qcint8};
636+
634637
// Find if the convert output is going to the right linear node.
635638
// Assuming if we can find one valid linear node, then we can use QP8
636639
// for all the linear nodes consuming this convert output.
637640
for (auto node : *graph->xnodes()) {
638641
if (node->xnode_union_type() == fb_xnnpack::XNodeUnion::XNNFullyConnected) {
639642
auto linear_node = node->xnode_union_as_XNNFullyConnected();
640643
if (linear_node->input1_id() == cvt_output_id) {
641-
if (check_dtype(
642-
linear_node->filter_id(), DataType::xnn_datatype_qbint4)) {
643-
return true;
644+
for (auto supported_filter_dtype : supported_filter_dtypes) {
645+
if (check_dtype(linear_node->filter_id(), supported_filter_dtype)) {
646+
return true;
647+
}
644648
}
645649
}
646650
}

0 commit comments

Comments
 (0)