7
7
import copy
8
8
import logging
9
9
from abc import ABC , abstractmethod
10
- from typing import Any , Callable , Dict , List , Optional , Sequence
10
+ from collections import defaultdict
11
+ from typing import Any , Callable , Dict , List , Optional
11
12
12
13
import torch
13
14
from executorch .devtools .backend_debug import get_delegation_info
14
- from executorch .exir import EdgeCompileConfig
15
+ from executorch .exir import EdgeCompileConfig , ExportedProgram
15
16
from executorch .exir .backend .backend_api import validation_disabled
16
17
from executorch .exir .program import to_edge , to_edge_transform_and_lower
17
- from executorch .exir .program ._program import _transform
18
18
from executorch .export .recipe import LoweringRecipe , QuantizationRecipe
19
19
from executorch .export .types import StageType
20
20
from torch import nn
@@ -107,10 +107,12 @@ class TorchExportStage(Stage):
107
107
108
108
def __init__ (
109
109
self ,
110
- pre_edge_transform_passes : Optional [List [PassType ]] = None ,
110
+ aten_transform_passes : Optional [
111
+ List [Callable [[str , ExportedProgram ], ExportedProgram ]]
112
+ ] = None ,
111
113
) -> None :
112
114
super ().__init__ ()
113
- self ._pre_edge_transform_passes = pre_edge_transform_passes
115
+ self ._aten_transform_passes = aten_transform_passes
114
116
115
117
@property
116
118
def stage_type (self ) -> str :
@@ -149,9 +151,13 @@ def run(self, artifact: PipelineArtifact) -> None:
149
151
)
150
152
151
153
# Apply pre-edge transform passes if available
152
- for pass_ in self ._pre_edge_transform_passes or []:
153
- exported_programs [method_name ] = _transform (
154
- exported_programs [method_name ], pass_
154
+ for pass_ in self ._aten_transform_passes or []:
155
+ if not callable (pass_ ):
156
+ raise ValueError (
157
+ "Aten transform passes must be a callable that can transform and return an exported program"
158
+ )
159
+ exported_programs [method_name ] = pass_ (
160
+ method_name , exported_programs [method_name ]
155
161
)
156
162
157
163
self ._artifact = artifact .copy_with_new_data (exported_programs )
@@ -165,7 +171,9 @@ class EdgeTransformAndLowerStage(Stage):
165
171
def __init__ (
166
172
self ,
167
173
partitioners : Optional [List [Any ]] = None ,
168
- transform_passes : Optional [Sequence [Callable [[Any ], Optional [Any ]]]] = None ,
174
+ transform_passes : (
175
+ None | List [Callable [[str , ExportedProgram ], List [PassType ]]]
176
+ ) = None ,
169
177
compile_config : Optional [Any ] = None ,
170
178
) -> None :
171
179
self ._partitioners = partitioners
@@ -205,11 +213,28 @@ def run(self, artifact: PipelineArtifact) -> None:
205
213
constant_methods = artifact .get_context ("constant_methods" )
206
214
generate_etrecord = artifact .get_context ("generate_etrecord" , False )
207
215
216
+ # per method transform passes
217
+ transform_passes = defaultdict (list )
218
+ for method_name , ep in exported_programs .items ():
219
+ # Resolve transform passes from callable
220
+ for pass_ in self ._transform_passes or []:
221
+ if not callable (pass_ ):
222
+ raise ValueError (
223
+ "Transform passes must be a callable that resolves to a list of passes"
224
+ )
225
+ passes = pass_ (method_name , ep )
226
+ if isinstance (passes , list ):
227
+ transform_passes [method_name ].extend (passes )
228
+ else :
229
+ raise ValueError (
230
+ "Transform passes must be a callable that resolves to a list of passes"
231
+ )
232
+
208
233
with validation_disabled ():
209
234
edge_program_manager = to_edge_transform_and_lower (
210
235
exported_programs ,
211
236
partitioner = self ._partitioners ,
212
- transform_passes = self . _transform_passes ,
237
+ transform_passes = transform_passes ,
213
238
constant_methods = constant_methods ,
214
239
compile_config = self ._compile_config ,
215
240
generate_etrecord = generate_etrecord ,
@@ -396,7 +421,7 @@ def run(self, artifact: PipelineArtifact) -> None:
396
421
captured_graph = torch .export .export (model , inputs , strict = True ).module ()
397
422
398
423
quantizer = self ._get_quantizer_for_prepare_pt2e (
399
- self ._quantization_recipe .quantizers
424
+ self ._quantization_recipe .quantizers # pyre-ignore
400
425
)
401
426
prepared_model = prepare_pt2e (captured_graph , quantizer )
402
427
@@ -471,7 +496,9 @@ class ToBackendStage(Stage):
471
496
def __init__ (
472
497
self ,
473
498
partitioners : Optional [List [Any ]] = None ,
474
- transform_passes : Optional [Sequence [Callable [[Any ], Optional [Any ]]]] = None ,
499
+ transform_passes : (
500
+ None | List [Callable [[str , ExportedProgram ], List [PassType ]]]
501
+ ) = None ,
475
502
) -> None :
476
503
super ().__init__ ()
477
504
self ._partitioners = partitioners
@@ -513,11 +540,24 @@ def run(self, artifact: PipelineArtifact) -> None:
513
540
if edge_program_manager is None :
514
541
raise RuntimeError ("Edge program manager is not set." )
515
542
516
- # Apply transform passes if available
517
- if self ._transform_passes :
518
- edge_program_manager = edge_program_manager .transform (
519
- self ._transform_passes
520
- )
543
+ # per method transform passes
544
+ transform_passes = defaultdict (list )
545
+ for method_name in edge_program_manager .methods :
546
+ # Resolve transform passes if it's a callable
547
+ ep = edge_program_manager .exported_program (method_name )
548
+ for pass_ in self ._transform_passes or []:
549
+ if not callable (pass_ ):
550
+ raise ValueError (
551
+ "Transform passes must be a callable that resolves to a list of passes"
552
+ )
553
+ passes = pass_ (method_name , ep )
554
+ if isinstance (passes , list ):
555
+ transform_passes [method_name ].extend (passes )
556
+ else :
557
+ raise ValueError ("Transform passes must return list of passes" )
558
+
559
+ # Apply transform passes
560
+ edge_program_manager = edge_program_manager .transform (transform_passes )
521
561
522
562
# Apply partitioners if available
523
563
if self ._partitioners is not None and len (self ._partitioners ) > 0 :
0 commit comments