Skip to content

Commit 45b441a

Browse files
committed
support latest backend interface
1 parent e14146c commit 45b441a

File tree

6 files changed

+43
-11
lines changed

6 files changed

+43
-11
lines changed

backends/aoti/aoti_backend.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import copy
8+
import os
9+
import shutil
810

911
from subprocess import check_call
1012
from typing import final, List
@@ -16,8 +18,7 @@
1618
PreprocessResult,
1719
)
1820
from executorch.exir.backend.compile_spec_schema import CompileSpec
19-
import os
20-
import shutil
21+
2122

2223
@final
2324
class AotiBackend(BackendDetails):
@@ -33,7 +34,7 @@ def preprocess(
3334
graph_module = copy_edge_program.graph_module
3435
args, kwargs = copy_edge_program.example_inputs
3536
temp_so_path = torch._inductor.aot_compile(graph_module, args, kwargs, options={}) # type: ignore[arg-type]
36-
so_path = os.path.join(os.getcwd(), 'aoti.so')
37+
so_path = os.path.join(os.getcwd(), "aoti.so")
3738
print("so_path after aot_compile: ", temp_so_path)
3839
print("so path we will using ", so_path)
3940
shutil.copyfile(temp_so_path, so_path)

backends/aoti/aoti_partitioner.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,15 @@
2323

2424
from torch.fx.passes.operator_support import OperatorSupportBase
2525

26+
supported_fallback_operators = []
27+
2628

2729
class AOTISupportedOperators(OperatorSupportBase):
2830
def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
29-
supported = node.op == "call_function" and node.target in [
30-
exir_ops.edge.aten.add.Tensor,
31-
exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
32-
]
31+
supported = (
32+
node.op == "call_function"
33+
and node.target not in supported_fallback_operators
34+
)
3335

3436
return supported
3537

backends/aoti/runtime/AotiBackend.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ using executorch::runtime::FreeableBuffer;
5050
using executorch::runtime::MemoryAllocator;
5151
using executorch::runtime::Result;
5252
using executorch::runtime::etensor::Tensor;
53+
using executorch::runtime::Span;
5354

5455
extern "C" {
5556
using AOTITensorHandle = Tensor*;
@@ -494,7 +495,7 @@ class AOTIBackend final : public ::executorch::runtime::BackendInterface {
494495
Error execute(
495496
BackendExecutionContext& context,
496497
DelegateHandle* handle_,
497-
EValue** args) const override {
498+
Span<EValue*> args) const override {
498499
AOTIDelegateHandle* handle = (AOTIDelegateHandle*)handle_;
499500

500501
size_t num_inputs;

export_add.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,5 +27,5 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
2727
executorch_program = edge_program.to_executorch()
2828

2929
# 4. Save the compiled .pte program
30-
with open("add.pte", "wb") as file:
30+
with open("aoti_model.pte", "wb") as file:
3131
file.write(executorch_program.buffer)

export_and_run_aoti.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
./install_executorch.sh
2-
python export_add.py
2+
python $1
33
./install_executorch.sh --clean
44
mkdir -p cmake-out
55
cd cmake-out
@@ -8,4 +8,4 @@ cmake -DEXECUTORCH_BUILD_AOTI=ON \
88
..
99
cd ..
1010
cmake --build cmake-out -j9
11-
./cmake-out/executor_runner --model_path add.pte
11+
./cmake-out/executor_runner --model_path aoti_model.pte

export_mv2.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import torch
2+
from executorch.backends.aoti.aoti_partitioner import AotiPartitioner
3+
from executorch.examples.models.mobilenet_v2 import MV2Model
4+
from executorch.exir import to_edge
5+
from torch.export import export
6+
from torchvision import models
7+
from torchvision.models.mobilenetv2 import MobileNet_V2_Weights
8+
9+
mv2 = models.mobilenetv2.mobilenet_v2(weights=MobileNet_V2_Weights)
10+
mv2 = mv2.eval()
11+
12+
model_inputs = (torch.randn(1, 3, 224, 224),)
13+
14+
15+
# 1. torch.export: Defines the program with the ATen operator set.
16+
aten_dialect = export(mv2, model_inputs)
17+
18+
# 2. to_edge: Make optimizations for Edge devices
19+
edge_program = to_edge(aten_dialect)
20+
21+
edge_program = edge_program.to_backend(AotiPartitioner([]))
22+
23+
# 3. to_executorch: Convert the graph to an ExecuTorch program
24+
executorch_program = edge_program.to_executorch()
25+
26+
# 4. Save the compiled .pte program
27+
with open("aoti_model.pte", "wb") as file:
28+
file.write(executorch_program.buffer)

0 commit comments

Comments
 (0)