Skip to content

Commit 77e7ad1

Browse files
Erik-Lundellfacebook-github-bot
authored andcommitted
Add access to edge_program in ArmPassManager (pytorch#5542)
Summary: Change-Id: Iecb1af0001dee7c48961129f0d49644cddfc18a0 Pull Request resolved: pytorch#5542 Reviewed By: cccclai Differential Revision: D63262789 Pulled By: digantdesai fbshipit-source-id: 8219a97aeb3ee25e318724297a39f86987557600
1 parent 905b88c commit 77e7ad1

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

backends/arm/arm_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ def preprocess( # noqa: C901
217217
# const data directly. Path created and data written only in debug builds.
218218
tosa_graph = ts.TosaSerializer(artifact_path)
219219
graph_module = ArmPassManager().transform_to_backend_pipeline(
220-
graph_module=edge_program.graph_module, compile_spec=compile_spec
220+
exported_program=edge_program, compile_spec=compile_spec
221221
)
222222

223223
node_visitors = get_node_visitors(edge_program)

backends/arm/passes/arm_pass_manager.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
)
2323
from executorch.backends.arm.passes.remove_clone_pass import RemoveClonePass
2424
from executorch.backends.arm.passes.size_adjust_conv2d_pass import SizeAdjustConv2DPass
25+
from executorch.exir import ExportedProgram
2526
from executorch.exir.backend.compile_spec_schema import CompileSpec
2627
from executorch.exir.pass_manager import PassManager
2728

@@ -32,7 +33,7 @@ def _transform(self, graph_module: torch.fx.GraphModule):
3233
return self(graph_module).graph_module
3334

3435
def transform_to_backend_pipeline(
35-
self, graph_module: torch.fx.GraphModule, compile_spec: list[CompileSpec]
36+
self, exported_program: ExportedProgram, compile_spec: list[CompileSpec]
3637
):
3738
"""Apply passes before transforming program to backend"""
3839
self.add_pass(SizeAdjustConv2DPass())
@@ -46,4 +47,4 @@ def transform_to_backend_pipeline(
4647
if memory_format == "nhwc":
4748
self.add_pass(AnnotateChannelsLastDimOrder())
4849

49-
return self._transform(graph_module)
50+
return self._transform(exported_program.graph_module)

0 commit comments

Comments
 (0)