77import copy
88import logging
99from abc import ABC , abstractmethod
10- from typing import Any , Callable , Dict , List , Optional , Sequence
10+ from collections import defaultdict
11+ from typing import Any , Callable , Dict , List , Optional
1112
1213import torch
1314from executorch .devtools .backend_debug import get_delegation_info
14- from executorch .exir import EdgeCompileConfig
15+ from executorch .exir import EdgeCompileConfig , ExportedProgram
1516from executorch .exir .backend .backend_api import validation_disabled
1617from executorch .exir .program import to_edge , to_edge_transform_and_lower
17- from executorch .exir .program ._program import _transform
1818from executorch .export .recipe import LoweringRecipe , QuantizationRecipe
1919from executorch .export .types import StageType
2020from torch import nn
@@ -107,10 +107,12 @@ class TorchExportStage(Stage):
107107
108108 def __init__ (
109109 self ,
110- pre_edge_transform_passes : Optional [List [PassType ]] = None ,
110+ aten_transform_passes : Optional [
111+ List [Callable [[str , ExportedProgram ], ExportedProgram ]]
112+ ] = None ,
111113 ) -> None :
112114 super ().__init__ ()
113- self ._pre_edge_transform_passes = pre_edge_transform_passes
115+ self ._aten_transform_passes = aten_transform_passes
114116
115117 @property
116118 def stage_type (self ) -> str :
@@ -149,9 +151,13 @@ def run(self, artifact: PipelineArtifact) -> None:
149151 )
150152
151153 # Apply pre-edge transform passes if available
152- for pass_ in self ._pre_edge_transform_passes or []:
153- exported_programs [method_name ] = _transform (
154- exported_programs [method_name ], pass_
154+ for pass_ in self ._aten_transform_passes or []:
155+ if not callable (pass_ ):
156+ raise ValueError (
157+ "Aten transform passes must be a callable that can transform and return an exported program"
158+ )
159+ exported_programs [method_name ] = pass_ (
160+ method_name , exported_programs [method_name ]
155161 )
156162
157163 self ._artifact = artifact .copy_with_new_data (exported_programs )
@@ -165,7 +171,9 @@ class EdgeTransformAndLowerStage(Stage):
165171 def __init__ (
166172 self ,
167173 partitioners : Optional [List [Any ]] = None ,
168- transform_passes : Optional [Sequence [Callable [[Any ], Optional [Any ]]]] = None ,
174+ transform_passes : (
175+ None | List [Callable [[str , ExportedProgram ], List [PassType ]]]
176+ ) = None ,
169177 compile_config : Optional [Any ] = None ,
170178 ) -> None :
171179 self ._partitioners = partitioners
@@ -205,11 +213,28 @@ def run(self, artifact: PipelineArtifact) -> None:
205213 constant_methods = artifact .get_context ("constant_methods" )
206214 generate_etrecord = artifact .get_context ("generate_etrecord" , False )
207215
216+ # per method transform passes
217+ transform_passes = defaultdict (list )
218+ for method_name , ep in exported_programs .items ():
219+ # Resolve transform passes from callable
220+ for pass_ in self ._transform_passes or []:
221+ if not callable (pass_ ):
222+ raise ValueError (
223+ "Transform passes must be a callable that resolves to a list of passes"
224+ )
225+ passes = pass_ (method_name , ep )
226+ if isinstance (passes , list ):
227+ transform_passes [method_name ].extend (passes )
228+ else :
229+ raise ValueError (
230+ "Transform passes must be a callable that resolves to a list of passes"
231+ )
232+
208233 with validation_disabled ():
209234 edge_program_manager = to_edge_transform_and_lower (
210235 exported_programs ,
211236 partitioner = self ._partitioners ,
212- transform_passes = self . _transform_passes ,
237+ transform_passes = transform_passes ,
213238 constant_methods = constant_methods ,
214239 compile_config = self ._compile_config ,
215240 generate_etrecord = generate_etrecord ,
@@ -396,7 +421,7 @@ def run(self, artifact: PipelineArtifact) -> None:
396421 captured_graph = torch .export .export (model , inputs , strict = True ).module ()
397422
398423 quantizer = self ._get_quantizer_for_prepare_pt2e (
399- self ._quantization_recipe .quantizers
424+ self ._quantization_recipe .quantizers # pyre-ignore
400425 )
401426 prepared_model = prepare_pt2e (captured_graph , quantizer )
402427
@@ -471,7 +496,9 @@ class ToBackendStage(Stage):
471496 def __init__ (
472497 self ,
473498 partitioners : Optional [List [Any ]] = None ,
474- transform_passes : Optional [Sequence [Callable [[Any ], Optional [Any ]]]] = None ,
499+ transform_passes : (
500+ None | List [Callable [[str , ExportedProgram ], List [PassType ]]]
501+ ) = None ,
475502 ) -> None :
476503 super ().__init__ ()
477504 self ._partitioners = partitioners
@@ -513,11 +540,24 @@ def run(self, artifact: PipelineArtifact) -> None:
513540 if edge_program_manager is None :
514541 raise RuntimeError ("Edge program manager is not set." )
515542
516- # Apply transform passes if available
517- if self ._transform_passes :
518- edge_program_manager = edge_program_manager .transform (
519- self ._transform_passes
520- )
543+ # per method transform passes
544+ transform_passes = defaultdict (list )
545+ for method_name in edge_program_manager .methods :
546+ # Resolve transform passes if it's a callable
547+ ep = edge_program_manager .exported_program (method_name )
548+ for pass_ in self ._transform_passes or []:
549+ if not callable (pass_ ):
550+ raise ValueError (
551+ "Transform passes must be a callable that resolves to a list of passes"
552+ )
553+ passes = pass_ (method_name , ep )
554+ if isinstance (passes , list ):
555+ transform_passes [method_name ].extend (passes )
556+ else :
557+ raise ValueError ("Transform passes must return list of passes" )
558+
559+ # Apply transform passes
560+ edge_program_manager = edge_program_manager .transform (transform_passes )
521561
522562 # Apply partitioners if available
523563 if self ._partitioners is not None and len (self ._partitioners ) > 0 :
0 commit comments