Skip to content

Commit ab56fd0

Browse files
eigen-kfacebook-github-bot
authored andcommitted
Extend tranform() method params with PassManager. (#13140)
Summary: Pull Request resolved: #13140 There is a user request to pass a custom PassManager to the EdgeProgramManager::transform(). This diff extends method API appropriately. Reviewed By: larryliu0820 Differential Revision: D79292048
1 parent 72ef7b1 commit ab56fd0

File tree

2 files changed

+99
-29
lines changed

2 files changed

+99
-29
lines changed

exir/program/_program.py

Lines changed: 71 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -240,8 +240,27 @@ def _transform(
240240
isinstance(p, (list, Verifier)) for p in passes
241241
), f"Expected all passes to be of PassType, not list or Verifier. Use override_verifiers kwarg instead. Got: {list(passes)}"
242242

243-
pm = PassManager(list(passes))
244-
res = pm(self.graph_module)
243+
return _transform_with_pass_manager(self, PassManager(list(passes)), override_verifiers)
244+
245+
246+
def _transform_with_pass_manager(
247+
self,
248+
pass_manager: PassManager,
249+
override_verifiers: None | list[Type[Verifier]] = None,
250+
) -> "ExportedProgram":
251+
"""
252+
Transforms the program using the provided pass_manager.
253+
254+
Args:
255+
self: The ExportedProgram instance to transform
256+
pass_manager: An instance of PassManager to apply transformations.
257+
override_verifiers: Optional list of verifier classes to use instead of the default verifiers.
258+
This is needed if the transforms yields illegal graph that the default verifier cannot handle.
259+
260+
Returns:
261+
ExportedProgram: A new ExportedProgram with the transformations applied, or self if no changes were made
262+
"""
263+
res = pass_manager(self.graph_module)
245264
transformed_gm = res.graph_module if res is not None else self.graph_module
246265
assert transformed_gm is not None
247266

@@ -1230,7 +1249,7 @@ def collect_named_data_store_outputs(
12301249
def to_edge_transform_and_lower( # noqa: C901
12311250
programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
12321251
transform_passes: Optional[
1233-
Union[Sequence[PassType], Dict[str, Sequence[PassType]]]
1252+
Union[Sequence[PassType], Dict[str, Sequence[PassType]], PassManager]
12341253
] = None,
12351254
partitioner: Optional[
12361255
Union[List[Partitioner], Dict[str, List[Partitioner]]]
@@ -1259,11 +1278,15 @@ def to_edge_transform_and_lower( # noqa: C901
12591278
to their corresponding ExportedPrograms. If only a single ExportedProgram is
12601279
provided it will be assigned the name "forward".
12611280
1262-
transform_passes: The passes can either be a list of passes, or a dictionary
1263-
mapping method names to lists of passes. If it is just a list of passes, all methods
1264-
in the given EdgeProgramManager will be transformed with the provided passes. If it
1265-
is a dictionary, only method names specified in the dictionary will be transformed
1266-
with their corresponding passes.
1281+
transform_passes: The transform_passes can be one of:
1282+
1) a list of passes -
1283+
all methods in the given EdgeProgramManager will be transformed with the provided passes.
1284+
2) a dictionary -
1285+
only method names specified in the dictionary will be transformed
1286+
with their corresponding passes
1287+
3) an instance of a PassManager -
1288+
all methods in the given EdgeProgramManager will be
1289+
transformed with the given PassManager instance.
12671290
12681291
partitioner: The partitioner can either be a Partitioner subclass instance, or a
12691292
dictionary mapping method names to Partitioner subclass instance. If it is a
@@ -1493,19 +1516,23 @@ def exported_program(self, method_name: str = "forward") -> ExportedProgram:
14931516
@et_logger("transform")
14941517
def transform(
14951518
self,
1496-
passes: Union[Sequence[PassType], Dict[str, Sequence[PassType]]],
1519+
passes: Union[Sequence[PassType], Dict[str, Sequence[PassType]], PassManager],
14971520
compile_config: Optional[EdgeCompileConfig] = None,
14981521
) -> "EdgeProgramManager":
14991522
"""
15001523
Transforms the program according to the provided passes.
15011524
15021525
Args:
1503-
passes: The passes can either be a list of passes, or a
1504-
dictionary mapping method names to lists of passes. If it is
1505-
just a list of passes, all methods in the given EdgeProgramManager
1506-
will be transformed with the provided passes. If it is a
1507-
dictionary, only method names specified in the dictionary will be
1508-
transformed with their corresponding passes.
1526+
passes: This param can be one of:
1527+
1) a list of passes -
1528+
all methods in the given EdgeProgramManager
1529+
will be transformed with the provided passes.
1530+
2) a dictionary mapping method names to lists of passes -
1531+
only method names specified in the dictionary will be
1532+
transformed with their corresponding passes.
1533+
3) a PassManager instance -
1534+
all methods in the given EdgeProgramManager will be
1535+
transformed with the given PassManager instance.
15091536
compile_config: Compile config to use for veriy the correctness of model
15101537
graph after each pass. If not specified, the compile config of the
15111538
calling EdgeProgramManager will be used. It will be used in as compile
@@ -1515,24 +1542,39 @@ def transform(
15151542
EdgeProgramManager: A copy of the calling EdgeProgramManager with the
15161543
transformations applied.
15171544
"""
1545+
15181546
compile_config = compile_config or self.compile_config
15191547
new_programs: Dict[str, ExportedProgram] = {}
1548+
1549+
# Cast passes parameter upfront.
1550+
passes_seq: Optional[Sequence[PassType]] = None
1551+
passes_dict: Optional[Dict[str, Sequence[PassType]]] = None
1552+
pass_manager: Optional[PassManager] = None
1553+
1554+
if isinstance(passes, Sequence):
1555+
passes_seq = passes
15201556
if isinstance(passes, dict):
1521-
for name, program in self._edge_programs.items():
1522-
if name in passes.keys():
1523-
new_programs[name] = _transform(program, *passes[name])
1524-
EXIREdgeDialectVerifier(edge_compile_config=compile_config)(
1525-
new_programs[name].graph_module
1526-
)
1527-
else:
1528-
new_programs[name] = copy.deepcopy(program)
1557+
passes_dict = passes
1558+
if isinstance(passes, PassManager):
1559+
pass_manager = passes
15291560

1530-
else: # apply passes to every method
1531-
for name, program in self._edge_programs.items():
1532-
new_programs[name] = _transform(program, *passes)
1533-
EXIREdgeDialectVerifier(edge_compile_config=compile_config)(
1534-
new_programs[name].graph_module
1535-
)
1561+
for name, program in self._edge_programs.items():
1562+
# If the method name is enforced, but not matched, we skip transformation.
1563+
if isinstance(passes, dict) and passes_dict and name not in passes_dict.keys():
1564+
new_programs[name] = copy.deepcopy(program)
1565+
1566+
# Depending on the passes parameter, call the corresponding transform function.
1567+
if isinstance(passes, Sequence) and passes_seq:
1568+
new_programs[name] = _transform(program, *passes_seq)
1569+
if isinstance(passes, dict) and passes_dict:
1570+
new_programs[name] = _transform(program, *passes_dict[name])
1571+
if isinstance(passes, PassManager) and pass_manager:
1572+
new_programs[name] = _transform_with_pass_manager(program, pass_manager)
1573+
1574+
# Verify the correctness of model graph after each transformation.
1575+
EXIREdgeDialectVerifier(edge_compile_config=compile_config)(
1576+
new_programs[name].graph_module
1577+
)
15361578

15371579
return EdgeProgramManager(
15381580
new_programs, copy.deepcopy(self._config_methods), compile_config

exir/program/test/test_program.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
AddMulPartitionerDemo,
1717
NonDecompTestPartitioner,
1818
)
19+
from torch.fx.passes.infra.pass_manager import PassManager
1920
from executorch.exir.dialects._ops import ops as exir_ops
2021
from executorch.exir.error import ExportError, InternalError
2122
from executorch.exir.lowered_backend_module import get_lowered_submodules
@@ -470,6 +471,33 @@ def test_transform_dict_api(self):
470471
torch.ones(1) + 1, # x + 1
471472
)
472473

474+
def test_transform_pass_manager_api(self):
475+
edge_manager = to_edge(get_exported_programs(), get_config_methods())
476+
477+
pm = PassManager()
478+
pm.add_pass(AddToMulPassEdge())
479+
480+
transformed_edge = edge_manager.transform(pm)
481+
482+
x = torch.ones(1) * 2
483+
y = torch.ones(1) * 3
484+
485+
# x * y + x -> x * y * x
486+
self.assertEqual(
487+
transformed_edge.exported_program("forward").module()(
488+
x, y
489+
),
490+
x * y * x
491+
)
492+
493+
# x + 1 -> x * 1
494+
self.assertEqual(
495+
transformed_edge.exported_program("foo").module()(
496+
x,
497+
),
498+
x * 1
499+
)
500+
473501
def test_edge_to_backend_replaces_subgraph(self):
474502
edge_manager: EdgeProgramManager = to_edge(
475503
get_exported_programs(), get_config_methods()

0 commit comments

Comments
 (0)