1+ from typing import Any , Callable , List , Optional , Sequence , Type , Tuple , Union
2+
3+ import torch
4+
5+ from executorch .backends .test .harness .stages .stage import Stage , StageType
6+ from executorch .exir import (
7+ EdgeCompileConfig ,
8+ EdgeProgramManager ,
9+ )
10+ from executorch .exir .backend .partitioner import Partitioner
11+ from torch ._export .pass_base import PassType
12+ from torch .export import ExportedProgram
13+
14+ class RunPasses (Stage ):
15+ def __init__ (
16+ self ,
17+ pass_manager_cls : Type ,
18+ pass_list : Optional [List [Type [PassType ]]] = None ,
19+ pass_functions : Optional [List [Callable ]] = None ,
20+ ):
21+ self .pass_manager_cls = pass_manager_cls
22+ self .pass_list = pass_list
23+ self .pass_functions = pass_functions
24+ self .edge_or_aten_program = None
25+
26+ def stage_type (self ) -> StageType :
27+ return StageType .RUN_PASSES
28+
29+ def run (
30+ self , artifact : Union [EdgeProgramManager , ExportedProgram ], inputs = None
31+ ) -> None :
32+ if isinstance (artifact , EdgeProgramManager ):
33+ self .edge_or_aten_program = artifact
34+ if self .pass_list :
35+ pass_manager = self .pass_manager_cls (
36+ artifact .exported_program (), self .pass_list
37+ )
38+ self .edge_or_aten_program ._edge_programs ["forward" ] = (
39+ pass_manager .transform ()
40+ )
41+ if self .pass_functions :
42+ assert isinstance (self .pass_functions , list )
43+ for pass_function in self .pass_functions :
44+ self .edge_or_aten_program ._edge_programs ["forward" ] = pass_function (
45+ self .edge_or_aten_program .exported_program ()
46+ )
47+ else :
48+ transformed_ep = artifact
49+ if self .pass_list :
50+ assert isinstance (self .pass_list , list )
51+ for pass_ in self .pass_list :
52+ transformed_ep = _transform (transformed_ep , pass_ ())
53+
54+ if self .pass_functions :
55+ assert isinstance (self .pass_functions , list )
56+ for pass_function in self .pass_functions :
57+ transformed_ep = pass_function (transformed_ep )
58+
59+ self .edge_or_aten_program = transformed_ep
60+
61+ @property
62+ def artifact (self ) -> Union [EdgeProgramManager , ExportedProgram ]:
63+ return self .edge_or_aten_program
64+
65+ @property
66+ def graph_module (self ) -> str :
67+ if isinstance (self .edge_or_aten_program , EdgeProgramManager ):
68+ return self .edge_or_aten_program .exported_program ().graph_module
69+ else :
70+ return self .edge_or_aten_program .graph_module
0 commit comments