File tree Expand file tree Collapse file tree 1 file changed +17
-3
lines changed
backends/xnnpack/partition/config Expand file tree Collapse file tree 1 file changed +17
-3
lines changed Original file line number Diff line number Diff line change 1010from typing import List , Optional
1111
1212import torch
13+ from torch .export import ExportedProgram
14+ from executorch .backends .xnnpack .utils .quant_utils import (
15+ is_quant ,
16+ is_dequant ,
17+ is_qparam ,
18+ )
1319from executorch .exir .backend .canonical_partitioners .config_partitioner import (
1420 format_target_name ,
1521 PartitionerConfig ,
1622)
1723from executorch .exir .backend .utils import WhyNoPartition
18- from torch .export import ExportedProgram
1924
2025logger = logging .getLogger (__name__ )
2126why = WhyNoPartition (logger = logger )
@@ -220,9 +225,18 @@ def _check_node_has_valid_dtype(self, node):
220225 valid_dtypes = {
221226 torch .float32 ,
222227 torch .float16 ,
223- torch .int8 ,
224- torch .qint8 ,
225228 }
229+ # Only allow int8 and quant dtypes for quant operations
230+ if is_quant (node ) or is_dequant (node ) or is_qparam (node ):
231+ valid_dtypes .update (
232+ {
233+ torch .qint32 ,
234+ torch .qint8 ,
235+ torch .quint8 ,
236+ torch .int8 ,
237+ }
238+ )
239+
226240 if (
227241 node .op != "placeholder"
228242 and node .op != "call_function"
You can’t perform that action at this time.
0 commit comments