@@ -240,8 +240,29 @@ def _transform(
240
240
isinstance (p , (list , Verifier )) for p in passes
241
241
), f"Expected all passes to be of PassType, not list or Verifier. Use override_verifiers kwarg instead. Got: { list (passes )} "
242
242
243
- pm = PassManager (list (passes ))
244
- res = pm (self .graph_module )
243
+ return _transform_with_pass_manager (
244
+ self , PassManager (list (passes )), override_verifiers
245
+ )
246
+
247
+
248
+ def _transform_with_pass_manager (
249
+ self ,
250
+ pass_manager : PassManager ,
251
+ override_verifiers : None | list [Type [Verifier ]] = None ,
252
+ ) -> "ExportedProgram" :
253
+ """
254
+ Transforms the program using the provided pass_manager.
255
+
256
+ Args:
257
+ self: The ExportedProgram instance to transform
258
+ pass_manager: An instance of PassManager to apply transformations.
259
+ override_verifiers: Optional list of verifier classes to use instead of the default verifiers.
260
+ This is needed if the transforms yields illegal graph that the default verifier cannot handle.
261
+
262
+ Returns:
263
+ ExportedProgram: A new ExportedProgram with the transformations applied, or self if no changes were made
264
+ """
265
+ res = pass_manager (self .graph_module )
245
266
transformed_gm = res .graph_module if res is not None else self .graph_module
246
267
assert transformed_gm is not None
247
268
@@ -1230,7 +1251,7 @@ def collect_named_data_store_outputs(
1230
1251
def to_edge_transform_and_lower ( # noqa: C901
1231
1252
programs : Union [ExportedProgram , Dict [str , ExportedProgram ]],
1232
1253
transform_passes : Optional [
1233
- Union [Sequence [PassType ], Dict [str , Sequence [PassType ]]]
1254
+ Union [Sequence [PassType ], Dict [str , Sequence [PassType ]], PassManager ]
1234
1255
] = None ,
1235
1256
partitioner : Optional [
1236
1257
Union [List [Partitioner ], Dict [str , List [Partitioner ]]]
@@ -1259,11 +1280,15 @@ def to_edge_transform_and_lower( # noqa: C901
1259
1280
to their corresponding ExportedPrograms. If only a single ExportedProgram is
1260
1281
provided it will be assigned the name "forward".
1261
1282
1262
- transform_passes: The passes can either be a list of passes, or a dictionary
1263
- mapping method names to lists of passes. If it is just a list of passes, all methods
1264
- in the given EdgeProgramManager will be transformed with the provided passes. If it
1265
- is a dictionary, only method names specified in the dictionary will be transformed
1266
- with their corresponding passes.
1283
+ transform_passes: The transform_passes can be one of:
1284
+ 1) a list of passes -
1285
+ all methods in the given EdgeProgramManager will be transformed with the provided passes.
1286
+ 2) a dictionary -
1287
+ only method names specified in the dictionary will be transformed
1288
+ with their corresponding passes
1289
+ 3) an instance of a PassManager -
1290
+ all methods in the given EdgeProgramManager will be
1291
+ transformed with the given PassManager instance.
1267
1292
1268
1293
partitioner: The partitioner can either be a Partitioner subclass instance, or a
1269
1294
dictionary mapping method names to Partitioner subclass instance. If it is a
@@ -1493,19 +1518,23 @@ def exported_program(self, method_name: str = "forward") -> ExportedProgram:
1493
1518
@et_logger ("transform" )
1494
1519
def transform (
1495
1520
self ,
1496
- passes : Union [Sequence [PassType ], Dict [str , Sequence [PassType ]]],
1521
+ passes : Union [Sequence [PassType ], Dict [str , Sequence [PassType ]], PassManager ],
1497
1522
compile_config : Optional [EdgeCompileConfig ] = None ,
1498
1523
) -> "EdgeProgramManager" :
1499
1524
"""
1500
1525
Transforms the program according to the provided passes.
1501
1526
1502
1527
Args:
1503
- passes: The passes can either be a list of passes, or a
1504
- dictionary mapping method names to lists of passes. If it is
1505
- just a list of passes, all methods in the given EdgeProgramManager
1506
- will be transformed with the provided passes. If it is a
1507
- dictionary, only method names specified in the dictionary will be
1508
- transformed with their corresponding passes.
1528
+ passes: This param can be one of:
1529
+ 1) a list of passes -
1530
+ all methods in the given EdgeProgramManager
1531
+ will be transformed with the provided passes.
1532
+ 2) a dictionary mapping method names to lists of passes -
1533
+ only method names specified in the dictionary will be
1534
+ transformed with their corresponding passes.
1535
+ 3) a PassManager instance -
1536
+ all methods in the given EdgeProgramManager will be
1537
+ transformed with the given PassManager instance.
1509
1538
compile_config: Compile config to use for veriy the correctness of model
1510
1539
graph after each pass. If not specified, the compile config of the
1511
1540
calling EdgeProgramManager will be used. It will be used in as compile
@@ -1515,24 +1544,44 @@ def transform(
1515
1544
EdgeProgramManager: A copy of the calling EdgeProgramManager with the
1516
1545
transformations applied.
1517
1546
"""
1547
+
1518
1548
compile_config = compile_config or self .compile_config
1519
1549
new_programs : Dict [str , ExportedProgram ] = {}
1550
+
1551
+ # Cast passes parameter upfront.
1552
+ passes_seq : Optional [Sequence [PassType ]] = None
1553
+ passes_dict : Optional [Dict [str , Sequence [PassType ]]] = None
1554
+ pass_manager : Optional [PassManager ] = None
1555
+
1556
+ if isinstance (passes , Sequence ):
1557
+ passes_seq = passes
1520
1558
if isinstance (passes , dict ):
1521
- for name , program in self ._edge_programs .items ():
1522
- if name in passes .keys ():
1523
- new_programs [name ] = _transform (program , * passes [name ])
1524
- EXIREdgeDialectVerifier (edge_compile_config = compile_config )(
1525
- new_programs [name ].graph_module
1526
- )
1527
- else :
1528
- new_programs [name ] = copy .deepcopy (program )
1559
+ passes_dict = passes
1560
+ if isinstance (passes , PassManager ):
1561
+ pass_manager = passes
1529
1562
1530
- else : # apply passes to every method
1531
- for name , program in self ._edge_programs .items ():
1532
- new_programs [name ] = _transform (program , * passes )
1533
- EXIREdgeDialectVerifier (edge_compile_config = compile_config )(
1534
- new_programs [name ].graph_module
1535
- )
1563
+ for name , program in self ._edge_programs .items ():
1564
+ # If the method name is enforced, but not matched, we skip transformation.
1565
+ if (
1566
+ isinstance (passes , dict )
1567
+ and passes_dict
1568
+ and name not in passes_dict .keys ()
1569
+ ):
1570
+ new_programs [name ] = copy .deepcopy (program )
1571
+ continue
1572
+
1573
+ # Depending on the passes parameter, call the corresponding transform function.
1574
+ if passes_seq is not None :
1575
+ new_programs [name ] = _transform (program , * passes_seq )
1576
+ elif passes_dict is not None :
1577
+ new_programs [name ] = _transform (program , * passes_dict [name ])
1578
+ elif pass_manager is not None :
1579
+ new_programs [name ] = _transform_with_pass_manager (program , pass_manager )
1580
+
1581
+ # Verify the correctness of model graph after each transformation.
1582
+ EXIREdgeDialectVerifier (edge_compile_config = compile_config )(
1583
+ new_programs [name ].graph_module
1584
+ )
1536
1585
1537
1586
return EdgeProgramManager (
1538
1587
new_programs , copy .deepcopy (self ._config_methods ), compile_config
0 commit comments