66# pyre-unsafe
77
88import logging
9- import operator
109import os
11- from typing import Callable , cast , final , List , Optional , Tuple
10+ from typing import Callable , final , List , Optional , Tuple
1211
1312import torch
1413from executorch .backends .arm .arm_backend import ArmBackend # usort: skip
1514from executorch .backends .arm ._passes .tag_io_quant_pass import TagIOQuantPass
15+ from executorch .backends .arm .operator_support .tosa_supported_operators import (
16+ TOSASupportedOperators ,
17+ )
18+ from executorch .backends .arm .tosa_specification import TosaSpecification
1619from executorch .exir .backend .compile_spec_schema import CompileSpec
1720from executorch .exir .backend .partitioner import (
1821 DelegationSpec ,
1922 Partitioner ,
2023 PartitionResult ,
2124)
2225from executorch .exir .backend .utils import tag_constant_data
23- from executorch .exir .dialects ._ops import ops as exir_ops
2426from executorch .exir .passes import PassManager
2527from torch .export .exported_program import ExportedProgram
2628from torch .fx .passes .infra .partitioner import CapabilityBasedPartitioner
2729
28- from torch .fx .passes .operator_support import OperatorSupportBase
29-
3030logger = logging .getLogger (__name__ )
3131logger .setLevel (logging .WARNING )
3232TOSA_DBG_VERBOSE = os .environ .get ("TOSA_DBG_VERBOSE" ) == "1"
3535 logger .setLevel (logging .INFO )
3636
3737
38- class TOSASupportedOperators (OperatorSupportBase ):
39- def is_node_supported (self , submodules , node : torch .fx .Node ) -> bool :
40- supported = node .op == "call_function" and node .target in [
41- exir_ops .edge .aten .add .Tensor ,
42- exir_ops .edge .aten .expand_copy .default ,
43- exir_ops .edge .aten .cat .default ,
44- exir_ops .edge .aten .bmm .default ,
45- exir_ops .edge .aten .permute_copy .default ,
46- exir_ops .edge .aten .hardtanh .default ,
47- exir_ops .edge .aten .convolution .default ,
48- exir_ops .edge .aten .div .Tensor ,
49- exir_ops .edge .aten .exp .default ,
50- exir_ops .edge .aten .log .default ,
51- exir_ops .edge .aten .linear .default ,
52- exir_ops .edge .aten .split_with_sizes_copy .default ,
53- exir_ops .edge .aten .full .default ,
54- exir_ops .edge .aten .mul .Tensor ,
55- exir_ops .edge .aten ._native_batch_norm_legit_no_training .default ,
56- exir_ops .edge .aten .native_layer_norm .default ,
57- exir_ops .edge .aten .avg_pool2d .default ,
58- exir_ops .edge .aten .max_pool2d_with_indices .default ,
59- exir_ops .edge .aten .sigmoid .default ,
60- exir_ops .edge .aten .mm .default ,
61- exir_ops .edge .aten .repeat .default ,
62- exir_ops .edge .aten .reciprocal .default ,
63- exir_ops .edge .aten .relu .default ,
64- exir_ops .edge .aten .rsqrt .default ,
65- exir_ops .edge .aten ._softmax .default ,
66- exir_ops .edge .aten .select_copy .int ,
67- exir_ops .edge .aten ._log_softmax .default ,
68- exir_ops .edge .aten .slice_copy .Tensor ,
69- exir_ops .edge .aten .sub .Tensor ,
70- exir_ops .edge .aten .sum .dim_IntList ,
71- exir_ops .edge .aten .tanh .default ,
72- exir_ops .edge .aten .upsample_nearest2d .vec ,
73- exir_ops .edge .aten .view_copy .default ,
74- exir_ops .edge .aten .clone .default ,
75- exir_ops .edge .aten .mean .dim ,
76- exir_ops .edge .aten .var .correction ,
77- exir_ops .edge .aten .unsqueeze_copy .default ,
78- exir_ops .edge .aten .squeeze_copy .dims ,
79- operator .getitem ,
80- exir_ops .edge .quantized_decomposed .quantize_per_tensor .default ,
81- exir_ops .edge .quantized_decomposed .dequantize_per_tensor .default ,
82- ]
83-
84- supported &= self .is_node_supported_custom (node )
85-
86- # Override partitioning based on pre partition passes
87- if "arm_override_partition" in node .meta :
88- supported = supported & node .meta ["arm_override_partition" ]
89- node .meta .pop ("arm_override_partition" )
90-
91- return supported
92-
93- def is_node_supported_custom (self , node : torch .fx .Node ) -> bool :
94- if node .target == exir_ops .edge .aten .mean .dim :
95- keep_dim = node .args [2 ] if len (node .args ) > 2 else False
96- return cast (bool , keep_dim )
97- if node .target == exir_ops .edge .aten .var .correction :
98- keep_dim = node .kwargs .get ("keepdim" , False )
99- return cast (bool , keep_dim )
100- return True
101-
102-
10338@final
10439class ArmPartitioner (Partitioner ):
10540 def __init__ (self , compile_spec : List [CompileSpec ]) -> None :
@@ -111,6 +46,12 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
11146 logger .info ("ArmPartitioner::partition" )
11247 partition_tags = {}
11348
49+ tosa_spec = TosaSpecification .create_from_compilespecs (
50+ self .delegation_spec .compile_specs
51+ )
52+
53+ logger .info (f"Partitioning for { tosa_spec } " )
54+
11455 for spec in self .delegation_spec .compile_specs :
11556 if spec .key == "quantize_io" and spec .value .decode () == "True" :
11657 # Exclude IO quantization from the partition
@@ -123,7 +64,7 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult:
12364
12465 capability_partitioner = CapabilityBasedPartitioner (
12566 exported_program .graph_module ,
126- TOSASupportedOperators (),
67+ TOSASupportedOperators (tosa_spec ),
12768 allows_single_node_partition = True ,
12869 )
12970 partition_list = capability_partitioner .propose_partitions ()
0 commit comments