
Commit 18b829c

Replace custom op pad with aten op, post-export
Differential Revision: D60941693
Pull Request resolved: #4603
Parent: ce7f5a0
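
In short, this commit makes the flamingo preprocess pad exportable by tracing a custom preprocess::pad op and then rewriting it to aten.constant_pad_nd after export. A minimal sketch of that flow, mirroring the new test added below (the import path of the pass module is assumed to follow the new examples/models/flamingo/passes package):

import torch
from executorch.exir import EdgeCompileConfig, to_edge

# Importing this module registers the preprocess::pad custom op.
from executorch.extension.llm.custom_ops import preprocess_custom_ops  # noqa

# Assumed location of the new pass (the flamingo passes package added in this commit).
from examples.models.flamingo.passes.replace_custom_ops_with_aten_ops_pass import (
    ReplaceCustomOpsWithAtenOpsPass,
)


class Pad(torch.nn.Module):
    def forward(self, x, padding):
        # The custom op stands in for the aten pad so that export succeeds.
        return torch.ops.preprocess.pad.default(x, padding)


ep = torch.export.export(Pad(), (torch.ones(3, 4, 5), [0, 2, 0, 1]), strict=False)
edge = to_edge(ep, compile_config=EdgeCompileConfig(_check_ir_validity=False))
# Post-export, swap preprocess::pad back to aten.constant_pad_nd.
edge = edge.transform([ReplaceCustomOpsWithAtenOpsPass()])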

File tree: 4 files changed (+83, -0 lines)


examples/models/flamingo/passes/__init__.py

Whitespace-only changes.
examples/models/flamingo/passes/replace_custom_ops_with_aten_ops_pass.py

Lines changed: 31 additions & 0 deletions
@@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import torch
from executorch.exir.pass_base import ExportPass
from executorch.extension.llm.custom_ops import preprocess_custom_ops  # noqa


class ReplaceCustomOpsWithAtenOpsPass(ExportPass):
    """
    Goes through all ops and replaces custom ops with aten ops. In some cases
    aten ops cannot be exported due to dynamism, eg. pad in flamingo preprocess.
    Use a custom op to pass export, and replace it with the aten op post-export,
    which avoids re-writing the op in C++.
    """

    def __init__(self) -> None:
        super().__init__()

    def call_operator(self, op, args, kwargs, meta):
        if op._name == "preprocess::pad":
            return super().call_operator(
                torch.ops.aten.constant_pad_nd.default, args, kwargs, meta
            )

        return super().call_operator(op, args, kwargs, meta)
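
For context, the preprocess::pad custom op that this pass targets is registered elsewhere (executorch.extension.llm.custom_ops.preprocess_custom_ops, not part of this diff). As a rough, hypothetical sketch of how such a pad wrapper could be registered with the standard torch.library API so that export does not trace through the aten pad (the illustrative:: namespace and the fake-kernel shape logic are assumptions for illustration, not the actual ExecuTorch registration):

from typing import List

import torch


@torch.library.custom_op("illustrative::pad", mutates_args=())
def illustrative_pad(x: torch.Tensor, padding: List[int]) -> torch.Tensor:
    # Eager behavior: defer to the aten constant pad.
    return torch.ops.aten.constant_pad_nd.default(x, padding, 0.0)


@illustrative_pad.register_fake
def _(x: torch.Tensor, padding: List[int]) -> torch.Tensor:
    # Shape propagation for export: padding pairs apply from the last dim inward.
    sizes = list(x.shape)
    for i in range(len(padding) // 2):
        sizes[-(i + 1)] += padding[2 * i] + padding[2 * i + 1]
    return x.new_empty(sizes)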
Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import unittest

from typing import List

import torch
from executorch.exir import EdgeCompileConfig, to_edge

from .replace_custom_ops_with_aten_ops_pass import ReplaceCustomOpsWithAtenOpsPass


class TestPasses(unittest.TestCase):
    def test_replace_custom_ops_with_aten_ops_pass(self) -> None:
        from executorch.extension.llm.custom_ops import preprocess_custom_ops  # noqa

        class Pad(torch.nn.Module):
            def forward(self, x: torch.Tensor, padding: List[int]) -> torch.Tensor:
                return torch.ops.preprocess.pad.default(x, padding)

        pad = Pad()

        image_tensor = torch.ones([3, 4, 5])
        padding = [0, 2, 0, 1]

        edge_prog = to_edge(
            torch.export.export(pad, (image_tensor, padding), strict=False),
            compile_config=EdgeCompileConfig(_check_ir_validity=False),
        )

        # Check that the custom op exists in the graph, and aten op does not.
        edge_nodes = [node.name for node in edge_prog.exported_program().graph.nodes]
        assert "constant_pad_nd" not in edge_nodes
        assert "preprocess_pad_default" in edge_nodes

        edge_prog = edge_prog.transform([ReplaceCustomOpsWithAtenOpsPass()])

        # After running replace_custom_ops_with_aten_ops pass, the custom op
        # should be replaced with aten op.
        post_transform_nodes = [
            node.name for node in edge_prog.exported_program().graph.nodes
        ]
        assert "constant_pad_nd" in post_transform_nodes
        assert "preprocess_pad_default" not in post_transform_nodes

exir/passes/replace_aten_with_edge_pass.py

Lines changed: 2 additions & 0 deletions
@@ -4,6 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+# pyre-unsafe
+
 import torch
 from executorch.exir.dialects._ops import ops
 from executorch.exir.dialects.edge._ops import EdgeOpOverload
