Skip to content

Commit 1d80837

Browse files
authored
Call ExportPass() inside ReplaceNopTransposeOrPermuteWithViewPass::call().
Differential Revision: D79212506 Pull Request resolved: #13005
1 parent d80dfa3 commit 1d80837

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

backends/cadence/aot/pass_utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from executorch.backends.cadence.aot.utils import get_edge_overload_packet
1414

1515
from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket
16-
from executorch.exir.pass_base import PassBase
16+
from executorch.exir.pass_base import PassBase, PassResult
1717

1818
from torch._ops import OpOverloadPacket
1919

@@ -224,3 +224,8 @@ def set_arg(
224224
node.update_arg(idx, value)
225225
else:
226226
node.update_kwarg(kwarg_name, value)
227+
228+
229+
def none_throws(x: Optional[PassResult]) -> PassResult:
230+
assert x is not None
231+
return x

backends/cadence/aot/replace_ops.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
)
4040
from executorch.backends.cadence.aot.pass_utils import (
4141
CadencePassAttribute,
42+
none_throws,
4243
register_cadence_pass,
4344
)
4445
from executorch.backends.cadence.aot.remove_ops import RemoveNopSelectOpPass
@@ -1661,8 +1662,8 @@ def call_operator(self, op, args, kwargs, meta):
16611662

16621663
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
16631664
result = super().call(graph_module)
1664-
result = FuseCascadedViewOps()(result.graph_module)
1665-
assert result is not None
1665+
fuse_cascaded_result = none_throws(FuseCascadedViewOps()(result.graph_module))
1666+
result = none_throws(ExportPass()(fuse_cascaded_result.graph_module))
16661667
return result
16671668

16681669

0 commit comments

Comments
 (0)