61
61
from executorch .backends .arm .tosa_partitioner import TOSAPartitioner
62
62
from executorch .backends .arm .tosa_specification import TosaSpecification
63
63
64
+ from executorch .backends .test .harness .stages import Stage , StageType
64
65
from executorch .backends .xnnpack .test .tester import Tester
65
66
from executorch .devtools .backend_debug import get_delegation_info
66
67
@@ -259,10 +260,13 @@ def wrapped_ep_pass(ep: ExportedProgram) -> ExportedProgram:
259
260
super ().run (artifact , inputs )
260
261
261
262
262
- class InitialModel (tester . Stage ):
263
+ class InitialModel (Stage ):
263
264
def __init__ (self , model : torch .nn .Module ):
264
265
self .model = model
265
266
267
+ def stage_type (self ) -> StageType :
268
+ return StageType .INITIAL_MODEL
269
+
266
270
def run (self , artifact , inputs = None ) -> None :
267
271
pass
268
272
@@ -305,13 +309,13 @@ def __init__(
305
309
self .constant_methods = constant_methods
306
310
self .compile_spec = compile_spec
307
311
super ().__init__ (model , example_inputs , dynamic_shapes )
308
- self .pipeline [self . stage_name ( InitialModel ) ] = [
309
- self . stage_name ( tester . Quantize ) ,
310
- self . stage_name ( tester . Export ) ,
312
+ self .pipeline [StageType . INITIAL_MODEL ] = [
313
+ StageType . QUANTIZE ,
314
+ StageType . EXPORT ,
311
315
]
312
316
313
317
# Initial model needs to be set as a *possible* but not yet added Stage, therefore add None entry.
314
- self .stages [self . stage_name ( InitialModel ) ] = None
318
+ self .stages [StageType . INITIAL_MODEL ] = None
315
319
self ._run_stage (InitialModel (self .original_module ))
316
320
317
321
def quantize (
@@ -413,7 +417,7 @@ def serialize(
413
417
return super ().serialize (serialize_stage )
414
418
415
419
def is_quantized (self ) -> bool :
416
- return self .stages [self . stage_name ( tester . Quantize ) ] is not None
420
+ return self .stages [StageType . QUANTIZE ] is not None
417
421
418
422
def run_method_and_compare_outputs (
419
423
self ,
@@ -442,18 +446,16 @@ def run_method_and_compare_outputs(
442
446
"""
443
447
444
448
if not run_eager_mode :
445
- edge_stage = self .stages [self . stage_name ( tester . ToEdge ) ]
449
+ edge_stage = self .stages [StageType . TO_EDGE ]
446
450
if edge_stage is None :
447
- edge_stage = self .stages [
448
- self .stage_name (tester .ToEdgeTransformAndLower )
449
- ]
451
+ edge_stage = self .stages [StageType .TO_EDGE_TRANSFORM_AND_LOWER ]
450
452
assert (
451
453
edge_stage is not None
452
454
), "To compare outputs, at least the ToEdge or ToEdgeTransformAndLower stage needs to be run."
453
455
else :
454
456
# Run models in eager mode. We do this when we want to check that the passes
455
457
# are numerically accurate and the exported graph is correct.
456
- export_stage = self .stages [self . stage_name ( tester . Export ) ]
458
+ export_stage = self .stages [StageType . EXPORT ]
457
459
assert (
458
460
export_stage is not None
459
461
), "To compare outputs in eager mode, the model must be at Export stage"
@@ -463,11 +465,11 @@ def run_method_and_compare_outputs(
463
465
is_quantized = self .is_quantized ()
464
466
465
467
if is_quantized :
466
- reference_stage = self .stages [self . stage_name ( tester . Quantize ) ]
468
+ reference_stage = self .stages [StageType . QUANTIZE ]
467
469
else :
468
- reference_stage = self .stages [self . stage_name ( InitialModel ) ]
470
+ reference_stage = self .stages [StageType . INITIAL_MODEL ]
469
471
470
- exported_program = self .stages [self . stage_name ( tester . Export ) ].artifact
472
+ exported_program = self .stages [StageType . EXPORT ].artifact
471
473
output_nodes = get_output_nodes (exported_program )
472
474
473
475
output_qparams = get_output_quantization_params (output_nodes )
@@ -477,7 +479,7 @@ def run_method_and_compare_outputs(
477
479
quantization_scales .append (getattr (output_qparams [node ], "scale" , None ))
478
480
479
481
logger .info (
480
- f"Comparing Stage '{ self . stage_name ( test_stage )} ' with Stage '{ self . stage_name ( reference_stage )} '"
482
+ f"Comparing Stage '{ test_stage . stage_type ( )} ' with Stage '{ reference_stage . stage_type ( )} '"
481
483
)
482
484
483
485
# Loop inputs and compare reference stage with the compared stage.
@@ -528,14 +530,12 @@ def get_graph(self, stage: str | None = None) -> Graph:
528
530
stage = self .cur
529
531
artifact = self .get_artifact (stage )
530
532
if (
531
- self .cur == self . stage_name ( tester . ToEdge )
532
- or self .cur == self . stage_name ( Partition )
533
- or self .cur == self . stage_name ( ToEdgeTransformAndLower )
533
+ self .cur == StageType . TO_EDGE
534
+ or self .cur == StageType . PARTITION
535
+ or self .cur == StageType . TO_EDGE_TRANSFORM_AND_LOWER
534
536
):
535
537
graph = artifact .exported_program ().graph
536
- elif self .cur == self .stage_name (tester .Export ) or self .cur == self .stage_name (
537
- tester .Quantize
538
- ):
538
+ elif self .cur == StageType .EXPORT or self .cur == StageType .QUANTIZE :
539
539
graph = artifact .graph
540
540
else :
541
541
raise RuntimeError (
@@ -556,13 +556,13 @@ def dump_operator_distribution(
556
556
Returns self for daisy-chaining.
557
557
"""
558
558
line = "#" * 10
559
- to_print = f"{ line } { self .cur . capitalize () } Operator Distribution { line } \n "
559
+ to_print = f"{ line } { self .cur } Operator Distribution { line } \n "
560
560
561
561
if (
562
562
self .cur
563
563
in (
564
- self . stage_name ( tester . Partition ) ,
565
- self . stage_name ( ToEdgeTransformAndLower ) ,
564
+ StageType . PARTITION ,
565
+ StageType . TO_EDGE_TRANSFORM_AND_LOWER ,
566
566
)
567
567
and print_table
568
568
):
@@ -602,9 +602,7 @@ def dump_dtype_distribution(
602
602
"""
603
603
604
604
line = "#" * 10
605
- to_print = (
606
- f"{ line } { self .cur .capitalize ()} Placeholder Dtype Distribution { line } \n "
607
- )
605
+ to_print = f"{ line } { self .cur } Placeholder Dtype Distribution { line } \n "
608
606
609
607
graph = self .get_graph (self .cur )
610
608
tosa_spec = get_tosa_spec (self .compile_spec )
@@ -653,7 +651,7 @@ def run_transform_for_annotation_pipeline(
653
651
stage = self .cur
654
652
# We need to clone the artifact in order to ensure that the state_dict is preserved after passes are run.
655
653
artifact = self .get_artifact (stage )
656
- if self .cur == self . stage_name ( tester . Export ) :
654
+ if self .cur == StageType . EXPORT :
657
655
new_gm = ArmPassManager (get_tosa_spec (self .compile_spec )).transform_for_annotation_pipeline ( # type: ignore[arg-type]
658
656
graph_module = artifact .graph_module
659
657
)
0 commit comments