Skip to content

Commit afdbb85

Browse files
authored
Extend EdgeProgramManager::transform() method declaration with PassManager as a "passes" param variant.
Differential Revision: D79292048 Pull Request resolved: #13140
1 parent c52f6a0 commit afdbb85

File tree

2 files changed

+103
-29
lines changed

2 files changed

+103
-29
lines changed

exir/program/_program.py

Lines changed: 78 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -240,8 +240,29 @@ def _transform(
240240
isinstance(p, (list, Verifier)) for p in passes
241241
), f"Expected all passes to be of PassType, not list or Verifier. Use override_verifiers kwarg instead. Got: {list(passes)}"
242242

243-
pm = PassManager(list(passes))
244-
res = pm(self.graph_module)
243+
return _transform_with_pass_manager(
244+
self, PassManager(list(passes)), override_verifiers
245+
)
246+
247+
248+
def _transform_with_pass_manager(
249+
self,
250+
pass_manager: PassManager,
251+
override_verifiers: None | list[Type[Verifier]] = None,
252+
) -> "ExportedProgram":
253+
"""
254+
Transforms the program using the provided pass_manager.
255+
256+
Args:
257+
self: The ExportedProgram instance to transform
258+
pass_manager: An instance of PassManager to apply transformations.
259+
override_verifiers: Optional list of verifier classes to use instead of the default verifiers.
260+
This is needed if the transforms yields illegal graph that the default verifier cannot handle.
261+
262+
Returns:
263+
ExportedProgram: A new ExportedProgram with the transformations applied, or self if no changes were made
264+
"""
265+
res = pass_manager(self.graph_module)
245266
transformed_gm = res.graph_module if res is not None else self.graph_module
246267
assert transformed_gm is not None
247268

@@ -1230,7 +1251,7 @@ def collect_named_data_store_outputs(
12301251
def to_edge_transform_and_lower( # noqa: C901
12311252
programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
12321253
transform_passes: Optional[
1233-
Union[Sequence[PassType], Dict[str, Sequence[PassType]]]
1254+
Union[Sequence[PassType], Dict[str, Sequence[PassType]], PassManager]
12341255
] = None,
12351256
partitioner: Optional[
12361257
Union[List[Partitioner], Dict[str, List[Partitioner]]]
@@ -1259,11 +1280,15 @@ def to_edge_transform_and_lower( # noqa: C901
12591280
to their corresponding ExportedPrograms. If only a single ExportedProgram is
12601281
provided it will be assigned the name "forward".
12611282
1262-
transform_passes: The passes can either be a list of passes, or a dictionary
1263-
mapping method names to lists of passes. If it is just a list of passes, all methods
1264-
in the given EdgeProgramManager will be transformed with the provided passes. If it
1265-
is a dictionary, only method names specified in the dictionary will be transformed
1266-
with their corresponding passes.
1283+
transform_passes: The transform_passes can be one of:
1284+
1) a list of passes -
1285+
all methods in the given EdgeProgramManager will be transformed with the provided passes.
1286+
2) a dictionary -
1287+
only method names specified in the dictionary will be transformed
1288+
with their corresponding passes
1289+
3) an instance of a PassManager -
1290+
all methods in the given EdgeProgramManager will be
1291+
transformed with the given PassManager instance.
12671292
12681293
partitioner: The partitioner can either be a Partitioner subclass instance, or a
12691294
dictionary mapping method names to Partitioner subclass instance. If it is a
@@ -1493,19 +1518,23 @@ def exported_program(self, method_name: str = "forward") -> ExportedProgram:
14931518
@et_logger("transform")
14941519
def transform(
14951520
self,
1496-
passes: Union[Sequence[PassType], Dict[str, Sequence[PassType]]],
1521+
passes: Union[Sequence[PassType], Dict[str, Sequence[PassType]], PassManager],
14971522
compile_config: Optional[EdgeCompileConfig] = None,
14981523
) -> "EdgeProgramManager":
14991524
"""
15001525
Transforms the program according to the provided passes.
15011526
15021527
Args:
1503-
passes: The passes can either be a list of passes, or a
1504-
dictionary mapping method names to lists of passes. If it is
1505-
just a list of passes, all methods in the given EdgeProgramManager
1506-
will be transformed with the provided passes. If it is a
1507-
dictionary, only method names specified in the dictionary will be
1508-
transformed with their corresponding passes.
1528+
passes: This param can be one of:
1529+
1) a list of passes -
1530+
all methods in the given EdgeProgramManager
1531+
will be transformed with the provided passes.
1532+
2) a dictionary mapping method names to lists of passes -
1533+
only method names specified in the dictionary will be
1534+
transformed with their corresponding passes.
1535+
3) a PassManager instance -
1536+
all methods in the given EdgeProgramManager will be
1537+
transformed with the given PassManager instance.
15091538
compile_config: Compile config to use for veriy the correctness of model
15101539
graph after each pass. If not specified, the compile config of the
15111540
calling EdgeProgramManager will be used. It will be used in as compile
@@ -1515,24 +1544,44 @@ def transform(
15151544
EdgeProgramManager: A copy of the calling EdgeProgramManager with the
15161545
transformations applied.
15171546
"""
1547+
15181548
compile_config = compile_config or self.compile_config
15191549
new_programs: Dict[str, ExportedProgram] = {}
1550+
1551+
# Cast passes parameter upfront.
1552+
passes_seq: Optional[Sequence[PassType]] = None
1553+
passes_dict: Optional[Dict[str, Sequence[PassType]]] = None
1554+
pass_manager: Optional[PassManager] = None
1555+
1556+
if isinstance(passes, Sequence):
1557+
passes_seq = passes
15201558
if isinstance(passes, dict):
1521-
for name, program in self._edge_programs.items():
1522-
if name in passes.keys():
1523-
new_programs[name] = _transform(program, *passes[name])
1524-
EXIREdgeDialectVerifier(edge_compile_config=compile_config)(
1525-
new_programs[name].graph_module
1526-
)
1527-
else:
1528-
new_programs[name] = copy.deepcopy(program)
1559+
passes_dict = passes
1560+
if isinstance(passes, PassManager):
1561+
pass_manager = passes
15291562

1530-
else: # apply passes to every method
1531-
for name, program in self._edge_programs.items():
1532-
new_programs[name] = _transform(program, *passes)
1533-
EXIREdgeDialectVerifier(edge_compile_config=compile_config)(
1534-
new_programs[name].graph_module
1535-
)
1563+
for name, program in self._edge_programs.items():
1564+
# If the method name is enforced, but not matched, we skip transformation.
1565+
if (
1566+
isinstance(passes, dict)
1567+
and passes_dict
1568+
and name not in passes_dict.keys()
1569+
):
1570+
new_programs[name] = copy.deepcopy(program)
1571+
continue
1572+
1573+
# Depending on the passes parameter, call the corresponding transform function.
1574+
if passes_seq is not None:
1575+
new_programs[name] = _transform(program, *passes_seq)
1576+
elif passes_dict is not None:
1577+
new_programs[name] = _transform(program, *passes_dict[name])
1578+
elif pass_manager is not None:
1579+
new_programs[name] = _transform_with_pass_manager(program, pass_manager)
1580+
1581+
# Verify the correctness of model graph after each transformation.
1582+
EXIREdgeDialectVerifier(edge_compile_config=compile_config)(
1583+
new_programs[name].graph_module
1584+
)
15361585

15371586
return EdgeProgramManager(
15381587
new_programs, copy.deepcopy(self._config_methods), compile_config

exir/program/test/test_program.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from torch._export.verifier import Verifier
3838
from torch.export import Dim, export, ExportedProgram
3939
from torch.export._trace import _export
40+
from torch.fx.passes.infra.pass_manager import PassManager
4041

4142
from torch.library import impl, Library
4243
from torch.nn import functional as F
@@ -470,6 +471,30 @@ def test_transform_dict_api(self):
470471
torch.ones(1) + 1, # x + 1
471472
)
472473

474+
def test_transform_pass_manager_api(self):
475+
edge_manager = to_edge(get_exported_programs(), get_config_methods())
476+
477+
pm = PassManager()
478+
pm.add_pass(AddToMulPassEdge())
479+
480+
transformed_edge = edge_manager.transform(pm)
481+
482+
x = torch.ones(1) * 2
483+
y = torch.ones(1) * 3
484+
485+
# x * y + x -> x * y * x
486+
self.assertEqual(
487+
transformed_edge.exported_program("forward").module()(x, y), x * y * x
488+
)
489+
490+
# x + 1 -> x * 1
491+
self.assertEqual(
492+
transformed_edge.exported_program("foo").module()(
493+
x,
494+
),
495+
x * 1,
496+
)
497+
473498
def test_edge_to_backend_replaces_subgraph(self):
474499
edge_manager: EdgeProgramManager = to_edge(
475500
get_exported_programs(), get_config_methods()

0 commit comments

Comments
 (0)