2828
2929import torch .fx
3030import torch .utils ._pytree as pytree
31-
3231from executorch .backends .arm ._passes .arm_pass_manager import ArmPassManager
3332
3433from executorch .backends .arm .common .arm_compile_spec import ArmCompileSpec
35- from executorch .backends .arm .ethosu import EthosUCompileSpec , EthosUPartitioner
36- from executorch .backends .arm .quantizer import (
37- EthosUQuantizer ,
38- get_symmetric_quantization_config ,
39- TOSAQuantizer ,
40- VgfQuantizer ,
41- )
34+ from executorch .backends .arm .ethosu import EthosUCompileSpec
35+ from executorch .backends .arm .quantizer import get_symmetric_quantization_config
4236from executorch .backends .arm .test .runner_utils import (
4337 dbg_tosa_fb_to_json ,
4438 get_output_quantization_params ,
5347from executorch .backends .arm .tosa import TosaSpecification
5448from executorch .backends .arm .tosa .compile_spec import TosaCompileSpec
5549from executorch .backends .arm .tosa .mapping import extract_tensor_meta
56- from executorch .backends .arm .tosa .partitioner import TOSAPartitioner
5750
58- from executorch .backends .arm .vgf import VgfCompileSpec , VgfPartitioner
51+ from executorch .backends .arm .util ._factory import (
52+ create_partitioner ,
53+ create_quantizer ,
54+ parse_compile_spec ,
55+ )
56+ from executorch .backends .arm .vgf import VgfCompileSpec
5957
6058from executorch .backends .test .harness .error_statistics import ErrorStatistics
6159from executorch .backends .test .harness .stages import Stage , StageType
8381 _copy_module ,
8482 _update_exported_program_graph_module ,
8583)
86-
8784from tabulate import tabulate
8885
8986from torch .export .graph_signature import ExportGraphSignature , InputSpec , OutputSpec
@@ -103,26 +100,20 @@ def _dump_lowered_modules_artifact(
103100 artifact .exported_program ().graph_signature
104101 )
105102
106- def get_output_format (lowered_module ) -> str | None :
107- for spec in lowered_module .compile_specs :
108- if spec .key == "output_format" :
109- return spec .value .decode ()
110- return None
111-
112103 for node in graph_module .graph .nodes :
113104 if node .op == "get_attr" and node .name .startswith ("lowered_module_" ):
114105 lowered_module = getattr (graph_module , node .name )
115106 assert isinstance (
116107 lowered_module , LoweredBackendModule
117108 ), f"Attribute { node .name } must be of type LoweredBackendModule."
118109
119- output_format = get_output_format (lowered_module )
120- if output_format == "tosa" :
110+ compile_spec = parse_compile_spec (lowered_module . compile_specs )
111+ if isinstance ( compile_spec , TosaCompileSpec ) :
121112 tosa_fb = lowered_module .processed_bytes
122113 to_print = dbg_tosa_fb_to_json (tosa_fb )
123114 to_print = pformat (to_print , compact = True , indent = 1 )
124115 output += f"\n TOSA deserialized { node .name } : \n { to_print } \n "
125- elif output_format == EthosUCompileSpec . get_output_format ( ):
116+ elif isinstance ( compile_spec , EthosUCompileSpec ):
126117 vela_cmd_stream = lowered_module .processed_bytes
127118 output += f"\n Vela command stream { node .name } : \n { vela_cmd_stream } \n "
128119 else :
@@ -284,13 +275,7 @@ def quantize(
284275 quantize_stage : Optional [tester .Quantize ] = None ,
285276 ):
286277 if quantize_stage is None :
287- quantizer = None
288- if isinstance (self .compile_spec , TosaCompileSpec ):
289- quantizer = TOSAQuantizer (self .compile_spec )
290- elif isinstance (self .compile_spec , EthosUCompileSpec ):
291- quantizer = EthosUQuantizer (self .compile_spec )
292- elif isinstance (self .compile_spec , VgfCompileSpec ):
293- quantizer = VgfQuantizer (self .compile_spec )
278+ quantizer = create_quantizer (self .compile_spec )
294279 quantize_stage = tester .Quantize (
295280 quantizer ,
296281 get_symmetric_quantization_config (),
@@ -312,14 +297,7 @@ def to_edge(
312297
313298 def partition (self , partition_stage : Optional [Partition ] = None ):
314299 if partition_stage is None :
315- if isinstance (self .compile_spec , TosaCompileSpec ):
316- arm_partitioner = TOSAPartitioner (self .compile_spec )
317- elif isinstance (self .compile_spec , EthosUCompileSpec ):
318- arm_partitioner = EthosUPartitioner (self .compile_spec )
319- elif isinstance (self .compile_spec , VgfCompileSpec ):
320- arm_partitioner = VgfPartitioner (self .compile_spec )
321- else :
322- raise ValueError ("compile spec doesn't target any Arm Partitioner" )
300+ arm_partitioner = create_partitioner (self .compile_spec )
323301 partition_stage = Partition (arm_partitioner )
324302 return super ().partition (partition_stage )
325303
@@ -329,7 +307,7 @@ def to_edge_transform_and_lower(
329307 partitioners : Optional [List [Partitioner ]] = None ,
330308 edge_compile_config : Optional [EdgeCompileConfig ] = None ,
331309 additional_checks : Optional [
332- List [Union [ DontPartition | DontPartitionModule | DontPartitionName ] ]
310+ List [DontPartition | DontPartitionModule | DontPartitionName ]
333311 ] = None ,
334312 transform_passes : Optional [
335313 Union [Sequence [PassType ], Dict [str , Sequence [PassType ]]]
@@ -343,20 +321,9 @@ def to_edge_transform_and_lower(
343321
344322 if to_edge_and_lower_stage is None :
345323 if partitioners is None :
346- if isinstance (self .compile_spec , TosaCompileSpec ):
347- arm_partitioner = TOSAPartitioner (
348- self .compile_spec , additional_checks
349- )
350- elif isinstance (self .compile_spec , EthosUCompileSpec ):
351- arm_partitioner = EthosUPartitioner (
352- self .compile_spec , additional_checks
353- )
354- elif isinstance (self .compile_spec , VgfCompileSpec ):
355- arm_partitioner = VgfPartitioner (
356- self .compile_spec , additional_checks
357- )
358- else :
359- raise ValueError ("compile spec doesn't target any Arm Partitioner" )
324+ arm_partitioner = create_partitioner (
325+ self .compile_spec , additional_checks
326+ )
360327 partitioners = [arm_partitioner ]
361328 to_edge_and_lower_stage = ToEdgeTransformAndLower (
362329 partitioners ,
@@ -743,22 +710,19 @@ def _get_tosa_operator_distribution(
743710 op_list = []
744711 id = 0
745712 while lowered_module := getattr (graph_module , f"lowered_module_{ id } " , None ):
746- for spec in lowered_module .compile_specs :
747- if spec .key != "output_format" :
748- continue
749- if spec .value == b"tosa" :
750- tosa_fb = lowered_module .processed_bytes
751- tosa_json = dbg_tosa_fb_to_json (tosa_fb )
752- for region in tosa_json ["regions" ]:
753- for block in region ["blocks" ]:
754- op_list .extend (
755- [operator ["op" ] for operator in block ["operators" ]]
756- )
757- break
758- elif spec .value == EthosUCompileSpec .get_output_format ().encode ():
759- return "Can not get operator distribution for Vela command stream."
760- else :
761- return f"Unknown output format '{ spec .value } '."
713+ compile_spec = parse_compile_spec (lowered_module .compile_specs )
714+ if isinstance (compile_spec , TosaCompileSpec ):
715+ tosa_fb = lowered_module .processed_bytes
716+ tosa_json = dbg_tosa_fb_to_json (tosa_fb )
717+ for region in tosa_json ["regions" ]:
718+ for block in region ["blocks" ]:
719+ op_list .extend ([operator ["op" ] for operator in block ["operators" ]])
720+ elif isinstance (compile_spec , EthosUCompileSpec ):
721+ return "Can not get operator distribution for Vela command stream."
722+ elif isinstance (compile_spec , VgfCompileSpec ):
723+ return "Can not get operator distribution for VGF."
724+ else :
725+ return f"Unknown output format '{ compile_spec .get_output_format ()} '."
762726 id += 1
763727 if id == 0 :
764728 return "No delegate with name 'lowered_module_0 found in graph module."
0 commit comments