1- # Copyright 2023-2024 Arm Limited and/or its affiliates.
1+ # Copyright 2023-2025 Arm Limited and/or its affiliates.
22#
33# This source code is licensed under the BSD-style license found in the
44# LICENSE file in the root directory of this source tree.
1010from typing import Callable , final , List , Optional , Tuple
1111
1212import torch
13- from executorch .backends .arm .arm_backend import ArmBackend # usort: skip
14- from executorch .backends .arm ._passes .tag_io_quant_pass import TagIOQuantPass
13+ from executorch .backends .arm .arm_backend import (
14+ ArmBackend ,
15+ is_quantize_io ,
16+ ) # usort: skip
1517from executorch .backends .arm .operator_support .tosa_supported_operators import (
1618 TOSASupportedOperators ,
1719)
2325 PartitionResult ,
2426)
2527from executorch .exir .backend .utils import tag_constant_data
26- from executorch .exir .passes import PassManager
28+ from executorch .exir .dialects . _ops import ops as exir_ops
2729from torch .export .exported_program import ExportedProgram
2830from torch .fx .passes .infra .partitioner import CapabilityBasedPartitioner
2931
3537 logger .setLevel (logging .INFO )
3638
3739
40+ def is_quant_node (node : torch .fx .node .Node ) -> bool :
41+ return node .target in {
42+ exir_ops .edge .quantized_decomposed .quantize_per_channel .default ,
43+ exir_ops .edge .quantized_decomposed .quantize_per_tensor .default ,
44+ exir_ops .edge .quantized_decomposed .quantize_per_tensor .tensor ,
45+ }
46+
47+
48+ def is_dequant_node (node : torch .fx .node .Node ) -> bool :
49+ return node .target in {
50+ exir_ops .edge .quantized_decomposed .dequantize_per_channel .default ,
51+ exir_ops .edge .quantized_decomposed .dequantize_per_tensor .default ,
52+ exir_ops .edge .quantized_decomposed .dequantize_per_tensor .tensor ,
53+ }
54+
55+
3856@final
3957class ArmPartitioner (Partitioner ):
4058 def __init__ (self , compile_spec : List [CompileSpec ]) -> None :
@@ -43,6 +61,7 @@ def __init__(self, compile_spec: List[CompileSpec]) -> None:
4361 def partition (self , exported_program : ExportedProgram ) -> PartitionResult :
4462 # Run the CapabilityBasedPartitioner to return the largest possible
4563 # subgraphs containing the nodes with the tags
64+
4665 logger .info ("ArmPartitioner::partition" )
4766 partition_tags = {}
4867
@@ -52,28 +71,42 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
5271
5372 logger .info (f"Partitioning for { tosa_spec } " )
5473
55- for spec in self .delegation_spec .compile_specs :
56- if spec .key == "quantize_io" and spec .value .decode () == "True" :
57- # Exclude IO quantization from the partition
58- passes = PassManager (
59- passes = [
60- TagIOQuantPass (),
61- ]
62- )
63- passes (exported_program .graph_module )
64-
6574 capability_partitioner = CapabilityBasedPartitioner (
6675 exported_program .graph_module ,
6776 TOSASupportedOperators (tosa_spec ),
6877 allows_single_node_partition = True ,
6978 )
7079 partition_list = capability_partitioner .propose_partitions ()
7180 for partition in partition_list :
81+ tag = f"tag{ partition .id } "
82+
83+ def is_partitioned (node : torch .fx .Node , tag = tag ) -> bool :
84+ return (
85+ "delegation_tag" in node .meta and node .meta ["delegation_tag" ] == tag
86+ )
87+
7288 for node in partition .nodes :
73- tag = f"tag{ partition .id } "
7489 node .meta ["delegation_tag" ] = tag
7590 partition_tags [tag ] = self .delegation_spec
7691
92+ if not is_quantize_io (self .delegation_spec .compile_specs ):
93+ continue
94+
95+ # De-tag outmost q-nodes upwards and dq-nodes downwards.
96+ # De-tag if at least one input/ output is not part of partition.
97+ for node in partition .nodes :
98+ if is_quant_node (node ):
99+ for input in node .all_input_nodes :
100+ if not is_partitioned (input ):
101+ del node .meta ["delegation_tag" ]
102+ break
103+
104+ if is_dequant_node (node ):
105+ for user in node .users :
106+ if not is_partitioned (user ):
107+ del node .meta ["delegation_tag" ]
108+ break
109+
77110 tag_constant_data (exported_program )
78111
79112 return PartitionResult (
0 commit comments