Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions backends/cadence/aot/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,26 @@ python_library(
],
)

python_library(
name = "export_example",
srcs = [
"export_example.py",
],
deps = [
":passes",
":utils",
":ops_registrations",
":replace_ops",
"//caffe2:torch",
"//executorch/backends/cadence/aot/quantizer:fusion_pass",
"//executorch/backends/cadence/runtime:runtime",
"//executorch/backends/cadence/aot/quantizer:quantizer",
"//executorch/backends/transforms:decompose_sdpa",
"//executorch/backends/transforms:remove_clone_ops",
"//executorch/exir:lib",
"//executorch/devtools:lib",
],
)

python_library(
name = "pass_utils",
Expand Down
14 changes: 8 additions & 6 deletions backends/cadence/aot/export_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def export_model(
model: nn.Module,
example_inputs: Tuple[Any, ...],
file_name: str = "CadenceDemoModel",
run_and_compare: bool = True,
):
# create work directory for outputs and model binary
working_dir = tempfile.mkdtemp(dir="/tmp")
Expand Down Expand Up @@ -112,9 +113,10 @@ def export_model(
)

# TODO: move to test infra
runtime.run_and_compare(
executorch_prog=exec_prog,
inputs=example_inputs,
ref_outputs=ref_outputs,
working_dir=working_dir,
)
if run_and_compare:
runtime.run_and_compare(
executorch_prog=exec_prog,
inputs=example_inputs,
ref_outputs=ref_outputs,
working_dir=working_dir,
)
3 changes: 2 additions & 1 deletion backends/cadence/aot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,8 @@ def print_ops_info(

# Print the final ops and their counts in a tabular format
logging.info(
tabulate(
"\n"
+ tabulate(
sorted_ops_count,
headers=[
"Final Operators ", # one character longer than the longest op name
Expand Down
2 changes: 2 additions & 0 deletions backends/cadence/runtime/TARGETS
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ python_library(
srcs = [
"__init__.py",
"executor.py",
"runtime.py",
"utils.py"
] + glob([
"xtsc-cfg/**/*",
]),
Expand Down
26 changes: 26 additions & 0 deletions examples/cadence/operators/TARGETS
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

load("@fbcode_macros//build_defs:python_unittest.bzl", "python_unittest")

oncall("odai_jarvis")


python_unittest(
name = "test_add_op",
srcs = [
"test_add_op.py",
],
typing = True,
supports_static_listing = False,
deps = [
"fbsource//third-party/pypi/parameterized:parameterized",
"//caffe2:torch",
"//executorch/backends/cadence/aot:ops_registrations",
"//executorch/backends/cadence/aot:export_example",
"//executorch/backends/cadence/aot:compiler",
],
)
115 changes: 115 additions & 0 deletions examples/cadence/operators/test_add_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

import unittest
from typing import Tuple

from parameterized import parameterized

from executorch.backends.cadence.aot.ops_registrations import * # noqa

import torch
import torch.nn as nn
from executorch.backends.cadence.aot.export_example import export_model


class ATenOpTestCases(unittest.TestCase):
@parameterized.expand(
[
[(7, 5, 6), (7, 5, 6)],
[(7, 5, 6), (1)],
[(1), (7, 5, 6)],
[(1), (7, 5, 6), 2.23],
[(1), (7, 5, 6), -1.0],
[(1), (7, 5, 6), -2.23],
[(7, 5, 6), (7, 5, 6), 1.23],
[(6, 7), (6, 7)],
[(6, 7), (6, 7), 2],
# Broadcast tests (should be optimized on G3)
[(1, 32, 64), (1, 1, 64)],
[(1, 32, 64), (64)],
[(1, 1, 32), (32)],
[(16, 1, 16), (1, 1, 16)],
[(16, 1, 16), (16)],
[(1, 4, 8, 8), (1, 1, 8, 8)],
[(1, 4, 8, 8), (8, 8)],
# Broadcast tests (should go to portable ops)
[(1, 10, 1, 8), (4, 1, 4, 1)],
[(1, 1, 16), (1, 8, 1), 2.5],
# # aten.upsample_nearest2d tests
[(5, 6, 6, 8), (5, 6, 6, 8)],
[(1, 1, 12, 16), (1, 1, 12, 16)],
]
)
def test_aten_add_out(
self, Xshape: Tuple[int], Yshape: Tuple[int], alpha: float = 1
) -> None:
class AddTensor(nn.Module):
def __init__(self, alpha: float):
super().__init__()
self.alpha = alpha

def forward(self, x: torch.Tensor, y: torch.Tensor):
return torch.add(x, y, alpha=self.alpha)

model = AddTensor(alpha)

X = torch.randn(Xshape)
Y = torch.randn(Yshape)

model.eval()
export_model(
model, (X, Y), file_name=self._testMethodName, run_and_compare=False
)

@parameterized.expand(
[
[(7, 5, 6), (7, 5, 6)],
[(7, 5, 6), (1)],
[(1), (7, 5, 6)],
[(1), (7, 5, 6), 2.23],
[(1), (7, 5, 6), -1.0],
[(1), (7, 5, 6), -2.23],
[(7, 5, 6), (7, 5, 6), 1.23],
[(6, 7), (6, 7)],
[(6, 7), (6, 7), 2],
# Broadcast tests (should be optimized on G3)
[(1, 32, 64), (1, 1, 64)],
[(1, 32, 64), (64)],
[(1, 1, 32), (32)],
[(16, 1, 16), (1, 1, 16)],
[(16, 1, 16), (16)],
[(1, 4, 8, 8), (1, 1, 8, 8)],
[(1, 4, 8, 8), (8, 8)],
# Broadcast tests (should go to portable ops)
[(1, 10, 1, 8), (4, 1, 4, 1)],
[(1, 1, 16), (1, 8, 1), 2.5],
# # aten.upsample_nearest2d tests
[(5, 6, 6, 8), (5, 6, 6, 8)],
[(1, 1, 12, 16), (1, 1, 12, 16)],
]
)
def test_aten_add_scalar_out(
self, Xshape: Tuple[int], Yshape: Tuple[int], alpha: float = 1
) -> None:
# Tensor-Scalar addition
class AddScalar(nn.Module):
def __init__(self, alpha: float):
super().__init__()
self.alpha = alpha

def forward(self, x: torch.Tensor, y: float):
return torch.add(x, y, alpha=self.alpha)

model = AddScalar(alpha)

X = torch.randn(Xshape)
Y = 2.34

model.eval()
export_model(
model, (X, Y), file_name=self._testMethodName, run_and_compare=False
)


if __name__ == "__main__":
unittest.main()
Loading