Skip to content

Commit 9d70fdd

Browse files
committed
Only support int8 and quant dtypes for quant operators
1 parent 8cfa858 commit 9d70fdd

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

backends/xnnpack/partition/config/xnnpack_config.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,17 @@
1010
from typing import List, Optional
1111

1212
import torch
13+
from torch.export import ExportedProgram
14+
from executorch.backends.xnnpack.utils.quant_utils import (
15+
is_quant,
16+
is_dequant,
17+
is_qparam,
18+
)
1319
from executorch.exir.backend.canonical_partitioners.config_partitioner import (
1420
format_target_name,
1521
PartitionerConfig,
1622
)
1723
from executorch.exir.backend.utils import WhyNoPartition
18-
from torch.export import ExportedProgram
1924

2025
logger = logging.getLogger(__name__)
2126
why = WhyNoPartition(logger=logger)
@@ -220,9 +225,18 @@ def _check_node_has_valid_dtype(self, node):
220225
valid_dtypes = {
221226
torch.float32,
222227
torch.float16,
223-
torch.int8,
224-
torch.qint8,
225228
}
229+
# Only allow int8 and quant dtypes for quant operations
230+
if is_quant(node) or is_dequant(node) or is_qparam(node):
231+
valid_dtypes.update(
232+
{
233+
torch.qint32,
234+
torch.qint8,
235+
torch.quint8,
236+
torch.int8,
237+
}
238+
)
239+
226240
if (
227241
node.op != "placeholder"
228242
and node.op != "call_function"

0 commit comments

Comments
 (0)