77
88# pyre-unsafe
99
10+
11+ from collections import defaultdict
12+
1013import executorch .backends .arm .tosa .dialect # noqa: unused
1114from executorch .backends .arm ._passes import (
1215 AddBiasPass ,
9497 UnsqueezeScalarPlaceholdersPass ,
9598)
9699
100+ from executorch .backends .arm ._passes .arm_pass import ArmPass
97101from executorch .backends .arm .tosa .specification import (
98102 TosaLoweringContext ,
99103 TosaSpecification ,
@@ -115,6 +119,32 @@ def __init__(self, tosa_spec: TosaSpecification) -> None:
115119 self .tosa_spec = tosa_spec
116120 super ().__init__ ()
117121
122+ def validate_constraints_mandatory (self ):
123+ """
124+ Validates that necessary passes have run before transforming to backend.
125+
126+ Note that this differs from the original validate_constraints function, which
127+ only checks the order of passes.
128+ """
129+ passes_to_run = defaultdict (list )
130+
131+ for current_pass in self .passes :
132+ current_pass_name = ArmPass .get_name (current_pass )
133+ for required_pass_name in ArmPass .get_required_passes (current_pass ):
134+ passes_to_run [required_pass_name ].append (current_pass_name )
135+
136+ passes_to_run .pop (current_pass_name , None )
137+
138+ if len (passes_to_run ) > 0 :
139+ error_msg = "The following constraints for passes are not met:\n "
140+ for required_pass , requiring_passes in passes_to_run .items ():
141+ for requiring_pass in requiring_passes :
142+ error_msg += (
143+ f" - { required_pass } must run after { requiring_pass } \n "
144+ )
145+
146+ raise RuntimeError (error_msg )
147+
118148 def _transform (self , graph_module : GraphModule ):
119149 with TosaLoweringContext (self .tosa_spec ):
120150 return self (graph_module ).graph_module
@@ -125,7 +155,6 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
125155 self .add_pass (RemoveGetItemPass ())
126156 self .add_pass (ConvertSplitToSlicePass ())
127157 self .add_pass (ConvertMmToBmmPass ())
128- self .add_pass (DecomposeLinearVectorNormPass ())
129158 self .add_pass (
130159 DecomposeMeanDimPass (exported_program .graph_module , self .tosa_spec )
131160 )
@@ -175,6 +204,7 @@ def _tosa_INT_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
175204 self .add_pass (RemoveNoopPass ())
176205 self .add_pass (InsertRescalePass ())
177206
207+ self .validate_constraints_mandatory ()
178208 return self ._transform (exported_program .graph_module )
179209
180210 def _tosa_FP_pipeline (self , exported_program : ExportedProgram ) -> GraphModule :
@@ -258,6 +288,7 @@ def _tosa_FP_pipeline(self, exported_program: ExportedProgram) -> GraphModule:
258288 self .add_pass (RemoveNoopPass ())
259289 self .add_pass (InsertRescalePass ())
260290
291+ self .validate_constraints_mandatory ()
261292 return self ._transform (exported_program .graph_module )
262293
263294 def transform_to_backend_pipeline (self , exported_program : ExportedProgram ):
0 commit comments