Skip to content

Commit 0c18261

Browse files
eigen-kfacebook-github-bot
authored andcommitted
Extend tranform() method params with PassManager. (#13140)
Summary: Pull Request resolved: #13140 There is a user request to pass a custom PassManager to the EdgeProgramManager::transform(). This diff extends method API appropriately. Differential Revision: D79292048
1 parent 047587e commit 0c18261

File tree

1 file changed

+50
-14
lines changed

1 file changed

+50
-14
lines changed

exir/program/_program.py

Lines changed: 50 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -240,8 +240,27 @@ def _transform(
240240
isinstance(p, (list, Verifier)) for p in passes
241241
), f"Expected all passes to be of PassType, not list or Verifier. Use override_verifiers kwarg instead. Got: {list(passes)}"
242242

243-
pm = PassManager(list(passes))
244-
res = pm(self.graph_module)
243+
return _transform_with_pass_manager(self, PassManager(list(passes)), override_verifiers)
244+
245+
246+
def _transform_with_pass_manager(
247+
self,
248+
pass_manager: PassManager,
249+
override_verifiers: None | list[Type[Verifier]] = None,
250+
) -> "ExportedProgram":
251+
"""
252+
Transforms the program using the provided pass_manager.
253+
254+
Args:
255+
self: The ExportedProgram instance to transform
256+
pass_manager: An instance of PassManager to apply transformations.
257+
override_verifiers: Optional list of verifier classes to use instead of the default verifiers.
258+
This is needed if the transforms yields illegal graph that the default verifier cannot handle.
259+
260+
Returns:
261+
ExportedProgram: A new ExportedProgram with the transformations applied, or self if no changes were made
262+
"""
263+
res = pass_manager(self.graph_module)
245264
transformed_gm = res.graph_module if res is not None else self.graph_module
246265
assert transformed_gm is not None
247266

@@ -1493,19 +1512,23 @@ def exported_program(self, method_name: str = "forward") -> ExportedProgram:
14931512
@et_logger("transform")
14941513
def transform(
14951514
self,
1496-
passes: Union[Sequence[PassType], Dict[str, Sequence[PassType]]],
1515+
passes_or_pass_manager: Union[Sequence[PassType], Dict[str, Sequence[PassType]], PassManager],
14971516
compile_config: Optional[EdgeCompileConfig] = None,
14981517
) -> "EdgeProgramManager":
14991518
"""
15001519
Transforms the program according to the provided passes.
15011520
15021521
Args:
1503-
passes: The passes can either be a list of passes, or a
1504-
dictionary mapping method names to lists of passes. If it is
1505-
just a list of passes, all methods in the given EdgeProgramManager
1506-
will be transformed with the provided passes. If it is a
1507-
dictionary, only method names specified in the dictionary will be
1508-
transformed with their corresponding passes.
1522+
passes_or_pass_manager: Can either one of:
1523+
1) a list of passes -
1524+
all methods in the given EdgeProgramManager
1525+
will be transformed with the provided passes.
1526+
2) a dictionary mapping method names to lists of passes -
1527+
only method names specified in the dictionary will be
1528+
transformed with their corresponding passes.
1529+
3) a PassManager instance -
1530+
all methods in the given EdgeProgramManager will be
1531+
transformed with the given PassManager instance.
15091532
compile_config: Compile config to use for veriy the correctness of model
15101533
graph after each pass. If not specified, the compile config of the
15111534
calling EdgeProgramManager will be used. It will be used in as compile
@@ -1517,19 +1540,32 @@ def transform(
15171540
"""
15181541
compile_config = compile_config or self.compile_config
15191542
new_programs: Dict[str, ExportedProgram] = {}
1520-
if isinstance(passes, dict):
1543+
if isinstance(passes_or_pass_manager, dict):
1544+
# Transform only those methods which match the keys in the provided dictionary.
1545+
passes_dict : Dict[str, Sequence[PassType]] = passes_or_pass_manager
15211546
for name, program in self._edge_programs.items():
1522-
if name in passes.keys():
1523-
new_programs[name] = _transform(program, *passes[name])
1547+
if name in passes_dict.keys():
1548+
new_programs[name] = _transform(program, *passes_dict[name])
15241549
EXIREdgeDialectVerifier(edge_compile_config=compile_config)(
15251550
new_programs[name].graph_module
15261551
)
15271552
else:
15281553
new_programs[name] = copy.deepcopy(program)
15291554

1530-
else: # apply passes to every method
1555+
if isinstance(passes_or_pass_manager, Sequence):
1556+
# Transform all methods with the provided passes.
1557+
passes_seq : Sequence[PassType] = passes_or_pass_manager
1558+
for name, program in self._edge_programs.items():
1559+
new_programs[name] = _transform(program, *passes_seq)
1560+
EXIREdgeDialectVerifier(edge_compile_config=compile_config)(
1561+
new_programs[name].graph_module
1562+
)
1563+
1564+
if isinstance(passes_or_pass_manager, PassManager):
1565+
# Transform all methods with the provided PassManager.
1566+
pass_manager : PassManager = passes_or_pass_manager
15311567
for name, program in self._edge_programs.items():
1532-
new_programs[name] = _transform(program, *passes)
1568+
new_programs[name] = _transform_with_pass_manager(program, pass_manager)
15331569
EXIREdgeDialectVerifier(edge_compile_config=compile_config)(
15341570
new_programs[name].graph_module
15351571
)

0 commit comments

Comments
 (0)