2828
2929import torch .fx
3030import torch .utils ._pytree as pytree
31-
3231from executorch .backends .arm ._passes .arm_pass_manager import ArmPassManager
3332
3433from executorch .backends .arm .common .arm_compile_spec import ArmCompileSpec
35- from executorch .backends .arm .ethosu import EthosUCompileSpec , EthosUPartitioner
36- from executorch .backends .arm .quantizer import (
37- EthosUQuantizer ,
38- get_symmetric_quantization_config ,
39- TOSAQuantizer ,
40- VgfQuantizer ,
41- )
34+ from executorch .backends .arm .ethosu import EthosUCompileSpec
35+ from executorch .backends .arm .quantizer import get_symmetric_quantization_config
4236from executorch .backends .arm .test .runner_utils import (
4337 dbg_tosa_fb_to_json ,
4438 get_output_quantization_params ,
5347from executorch .backends .arm .tosa import TosaSpecification
5448from executorch .backends .arm .tosa .compile_spec import TosaCompileSpec
5549from executorch .backends .arm .tosa .mapping import extract_tensor_meta
56- from executorch .backends .arm .tosa .partitioner import TOSAPartitioner
5750
58- from executorch .backends .arm .vgf import VgfCompileSpec , VgfPartitioner
51+ from executorch .backends .arm .util ._factory import (
52+ create_partitioner ,
53+ create_quantizer ,
54+ parse_compile_spec ,
55+ )
56+ from executorch .backends .arm .vgf import VgfCompileSpec
5957
6058from executorch .backends .test .harness .stages import Stage , StageType
6159from executorch .backends .xnnpack .test .tester import Tester
8280 _copy_module ,
8381 _update_exported_program_graph_module ,
8482)
85-
8683from tabulate import tabulate
8784
8885from torch .export .graph_signature import ExportGraphSignature , InputSpec , OutputSpec
@@ -102,26 +99,20 @@ def _dump_lowered_modules_artifact(
10299 artifact .exported_program ().graph_signature
103100 )
104101
105- def get_output_format (lowered_module ) -> str | None :
106- for spec in lowered_module .compile_specs :
107- if spec .key == "output_format" :
108- return spec .value .decode ()
109- return None
110-
111102 for node in graph_module .graph .nodes :
112103 if node .op == "get_attr" and node .name .startswith ("lowered_module_" ):
113104 lowered_module = getattr (graph_module , node .name )
114105 assert isinstance (
115106 lowered_module , LoweredBackendModule
116107 ), f"Attribute { node .name } must be of type LoweredBackendModule."
117108
118- output_format = get_output_format (lowered_module )
119- if output_format == "tosa" :
109+ compile_spec = parse_compile_spec (lowered_module . compile_specs )
110+ if isinstance ( compile_spec , TosaCompileSpec ) :
120111 tosa_fb = lowered_module .processed_bytes
121112 to_print = dbg_tosa_fb_to_json (tosa_fb )
122113 to_print = pformat (to_print , compact = True , indent = 1 )
123114 output += f"\n TOSA deserialized { node .name } : \n { to_print } \n "
124- elif output_format == EthosUCompileSpec . get_output_format ( ):
115+ elif isinstance ( compile_spec , EthosUCompileSpec ):
125116 vela_cmd_stream = lowered_module .processed_bytes
126117 output += f"\n Vela command stream { node .name } : \n { vela_cmd_stream } \n "
127118 else :
@@ -283,13 +274,7 @@ def quantize(
283274 quantize_stage : Optional [tester .Quantize ] = None ,
284275 ):
285276 if quantize_stage is None :
286- quantizer = None
287- if isinstance (self .compile_spec , TosaCompileSpec ):
288- quantizer = TOSAQuantizer (self .compile_spec )
289- elif isinstance (self .compile_spec , EthosUCompileSpec ):
290- quantizer = EthosUQuantizer (self .compile_spec )
291- elif isinstance (self .compile_spec , VgfCompileSpec ):
292- quantizer = VgfQuantizer (self .compile_spec )
277+ quantizer = create_quantizer (self .compile_spec )
293278 quantize_stage = tester .Quantize (
294279 quantizer ,
295280 get_symmetric_quantization_config (),
@@ -311,14 +296,7 @@ def to_edge(
311296
312297 def partition (self , partition_stage : Optional [Partition ] = None ):
313298 if partition_stage is None :
314- if isinstance (self .compile_spec , TosaCompileSpec ):
315- arm_partitioner = TOSAPartitioner (self .compile_spec )
316- elif isinstance (self .compile_spec , EthosUCompileSpec ):
317- arm_partitioner = EthosUPartitioner (self .compile_spec )
318- elif isinstance (self .compile_spec , VgfCompileSpec ):
319- arm_partitioner = VgfPartitioner (self .compile_spec )
320- else :
321- raise ValueError ("compile spec doesn't target any Arm Partitioner" )
299+ arm_partitioner = create_partitioner (self .compile_spec )
322300 partition_stage = Partition (arm_partitioner )
323301 return super ().partition (partition_stage )
324302
@@ -328,7 +306,7 @@ def to_edge_transform_and_lower(
328306 partitioners : Optional [List [Partitioner ]] = None ,
329307 edge_compile_config : Optional [EdgeCompileConfig ] = None ,
330308 additional_checks : Optional [
331- List [Union [ DontPartition | DontPartitionModule | DontPartitionName ] ]
309+ List [DontPartition | DontPartitionModule | DontPartitionName ]
332310 ] = None ,
333311 transform_passes : Optional [
334312 Union [Sequence [PassType ], Dict [str , Sequence [PassType ]]]
@@ -341,20 +319,9 @@ def to_edge_transform_and_lower(
341319
342320 if to_edge_and_lower_stage is None :
343321 if partitioners is None :
344- if isinstance (self .compile_spec , TosaCompileSpec ):
345- arm_partitioner = TOSAPartitioner (
346- self .compile_spec , additional_checks
347- )
348- elif isinstance (self .compile_spec , EthosUCompileSpec ):
349- arm_partitioner = EthosUPartitioner (
350- self .compile_spec , additional_checks
351- )
352- elif isinstance (self .compile_spec , VgfCompileSpec ):
353- arm_partitioner = VgfPartitioner (
354- self .compile_spec , additional_checks
355- )
356- else :
357- raise ValueError ("compile spec doesn't target any Arm Partitioner" )
322+ arm_partitioner = create_partitioner (
323+ self .compile_spec , additional_checks
324+ )
358325 partitioners = [arm_partitioner ]
359326 to_edge_and_lower_stage = ToEdgeTransformAndLower (
360327 partitioners ,
@@ -731,22 +698,19 @@ def _get_tosa_operator_distribution(
731698 op_list = []
732699 id = 0
733700 while lowered_module := getattr (graph_module , f"lowered_module_{ id } " , None ):
734- for spec in lowered_module .compile_specs :
735- if spec .key != "output_format" :
736- continue
737- if spec .value == b"tosa" :
738- tosa_fb = lowered_module .processed_bytes
739- tosa_json = dbg_tosa_fb_to_json (tosa_fb )
740- for region in tosa_json ["regions" ]:
741- for block in region ["blocks" ]:
742- op_list .extend (
743- [operator ["op" ] for operator in block ["operators" ]]
744- )
745- break
746- elif spec .value == EthosUCompileSpec .get_output_format ().encode ():
747- return "Can not get operator distribution for Vela command stream."
748- else :
749- return f"Unknown output format '{ spec .value } '."
701+ compile_spec = parse_compile_spec (lowered_module .compile_specs )
702+ if isinstance (compile_spec , TosaCompileSpec ):
703+ tosa_fb = lowered_module .processed_bytes
704+ tosa_json = dbg_tosa_fb_to_json (tosa_fb )
705+ for region in tosa_json ["regions" ]:
706+ for block in region ["blocks" ]:
707+ op_list .extend ([operator ["op" ] for operator in block ["operators" ]])
708+ elif isinstance (compile_spec , EthosUCompileSpec ):
709+ return "Can not get operator distribution for Vela command stream."
710+ elif isinstance (compile_spec , VgfCompileSpec ):
711+ return "Can not get operator distribution for VGF."
712+ else :
713+ return f"Unknown output format '{ compile_spec .get_output_format ()} '."
750714 id += 1
751715 if id == 0 :
752716 return "No delegate with name 'lowered_module_0 found in graph module."
0 commit comments