diff --git a/backends/xnnpack/partition/config/xnnpack_config.py b/backends/xnnpack/partition/config/xnnpack_config.py index ddbe8edc42d..817f9d1cf50 100644 --- a/backends/xnnpack/partition/config/xnnpack_config.py +++ b/backends/xnnpack/partition/config/xnnpack_config.py @@ -10,6 +10,11 @@ from typing import List, Optional import torch +from executorch.backends.xnnpack.utils.quant_utils import ( + is_dequant, + is_qparam, + is_quant, +) from executorch.exir.backend.canonical_partitioners.config_partitioner import ( format_target_name, PartitionerConfig, @@ -223,9 +228,18 @@ def _check_node_has_valid_dtype(self, node): valid_dtypes = { torch.float32, torch.float16, - torch.int8, - torch.qint8, } + # Only allow int8 and quant dtypes for quant operations + if is_quant(node) or is_dequant(node) or is_qparam(node): + valid_dtypes.update( + { + torch.qint32, + torch.qint8, + torch.quint8, + torch.int8, + } + ) + if ( node.op != "placeholder" and node.op != "call_function"