Skip to content

Commit ce06d16

Browse files
authored
Merge branch 'pytorch:main' into Arm-backend-Make-run.sh-handle-portable-ops-using-aten<OP>.<modifier>out
2 parents ac4ac12 + dcacde0 commit ce06d16

File tree

26 files changed

+584
-62
lines changed

26 files changed

+584
-62
lines changed

.github/scripts/check_labels.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,15 @@ def main() -> None:
4545

4646
try:
4747
if not has_required_labels(pr):
48-
print(LABEL_ERR_MSG)
48+
print(LABEL_ERR_MSG, flush=True)
4949
add_label_err_comment(pr)
5050
if args.exit_non_zero:
51-
sys.exit(1)
51+
raise RuntimeError("PR does not have required labels")
5252
else:
5353
delete_all_label_err_comments(pr)
5454
except Exception as e:
5555
if args.exit_non_zero:
56-
sys.exit(1)
56+
raise RuntimeError(f"Error checking labels: {e}") from e
5757

5858
sys.exit(0)
5959

.github/scripts/github_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,10 @@ def gh_fetch_url(
7272
headers: Optional[Dict[str, str]] = None,
7373
data: Union[Optional[Dict[str, Any]], str] = None,
7474
method: Optional[str] = None,
75-
reader: Callable[[Any], Any] = lambda x: x.read(),
75+
reader: Callable[[Any], Any] = json.load,
7676
) -> Any:
7777
return gh_fetch_url_and_headers(
78-
url, headers=headers, data=data, reader=json.load, method=method
78+
url, headers=headers, data=data, reader=reader, method=method
7979
)[1]
8080

8181

@@ -169,7 +169,7 @@ def gh_post_commit_comment(
169169

170170
def gh_delete_comment(org: str, repo: str, comment_id: int) -> None:
171171
url = f"{GITHUB_API_URL}/repos/{org}/{repo}/issues/comments/{comment_id}"
172-
gh_fetch_url(url, method="DELETE")
172+
gh_fetch_url(url, method="DELETE", reader=lambda x: x.read())
173173

174174

175175
def gh_fetch_merge_base(org: str, repo: str, base: str, head: str) -> str:

.github/workflows/android-perf.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ jobs:
136136
fail-fast: false
137137
with:
138138
runner: linux.4xlarge
139-
docker-image: executorch-ubuntu-22.04-clang12-android
139+
docker-image: executorch-ubuntu-22.04-qnn-sdk
140140
submodules: 'true'
141141
timeout: 60
142142
upload-artifact: android-models

.github/workflows/trunk.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ jobs:
302302
fail-fast: false
303303
with:
304304
runner: linux.2xlarge
305-
docker-image: executorch-ubuntu-22.04-clang12-android
305+
docker-image: executorch-ubuntu-22.04-qnn-sdk
306306
submodules: 'true'
307307
ref: ${{ github.event_name == 'pull_request' && github.event.pull_request.head.sha || github.sha }}
308308
timeout: 900

.gitmodules

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
[submodule "backends/arm/third-party/ethos-u-core-driver"]
22
path = backends/arm/third-party/ethos-u-core-driver
3-
url = https://review.mlplatform.org/ml/ethos-u/ethos-u-core-driver
3+
url = https://git.mlplatform.org/ml/ethos-u/ethos-u-core-driver.git/
44
[submodule "backends/arm/third-party/serialization_lib"]
55
path = backends/arm/third-party/serialization_lib
6-
url = https://review.mlplatform.org/tosa/serialization_lib
6+
url = https://git.mlplatform.org/tosa/serialization_lib.git/
77
[submodule "backends/vulkan/third-party/Vulkan-Headers"]
88
path = backends/vulkan/third-party/Vulkan-Headers
99
url = https://github.com/KhronosGroup/Vulkan-Headers

backends/arm/operators/op_add.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def define_node(
8282

8383
if needs_rescale:
8484
# Scale output back to 8 bit
85+
# pyre-ignore
8586
tqutils.rescale_node_back_to_int8(node, add_output, scale, tosa_graph)
8687

8788

backends/cadence/aot/TARGETS

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ python_library(
3838
deps = [
3939
":passes",
4040
":utils",
41+
":ops_registrations",
4142
"//caffe2:torch",
4243
"//executorch/backends/cadence/aot/quantizer:fusion_pass",
4344
"//executorch/backends/cadence/aot/quantizer:quantizer",
@@ -71,6 +72,8 @@ python_library(
7172
],
7273
deps = [
7374
":utils",
75+
":fuse_ops",
76+
":simplify_ops",
7477
"//caffe2:torch",
7578
"//executorch/exir:pass_base",
7679
"//executorch/exir/dialects:lib",
@@ -132,6 +135,18 @@ python_library(
132135
],
133136
)
134137

138+
python_library(
139+
name = "graph_builder",
140+
srcs = [
141+
"graph_builder.py",
142+
],
143+
typing = True,
144+
deps = [
145+
"fbcode//caffe2:torch",
146+
"fbcode//executorch/exir:pass_base",
147+
],
148+
)
149+
135150
python_library(
136151
name = "fuse_ops",
137152
srcs = [
@@ -150,3 +165,34 @@ python_library(
150165
"//executorch/exir/passes:spec_prop_pass",
151166
],
152167
)
168+
169+
python_library(
170+
name = "simplify_ops",
171+
srcs = [
172+
"simplify_ops.py",
173+
],
174+
typing = True,
175+
deps = [
176+
":pass_utils",
177+
"//executorch/backends/cadence/aot:pass_utils",
178+
"//executorch/exir:pass_base",
179+
"//executorch/exir/dialects:lib",
180+
],
181+
)
182+
183+
python_unittest(
184+
name = "test_graph_builder",
185+
srcs = [
186+
"tests/test_graph_builder.py",
187+
],
188+
typing = True,
189+
deps = [
190+
"//caffe2:torch",
191+
"//executorch/backends/cadence/aot:graph_builder",
192+
"//executorch/backends/cadence/aot:pass_utils",
193+
"//executorch/exir:pass_base",
194+
"//executorch/exir/dialects:lib",
195+
"//later:lib",
196+
":ops_registrations"
197+
],
198+
)

backends/cadence/aot/compiler.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from pathlib import Path
1111
from typing import Callable, cast, Optional
1212

13+
import executorch.backends.cadence.aot.ops_registrations # noqa
1314
import torch
1415

1516
from executorch.backends.cadence.aot.passes import ReplaceSafeSoftmaxWithSoftmax
@@ -196,7 +197,26 @@ def export_to_edge(
196197
# Export the model and lower it to an EdgeProgramManager (in edge IR), and
197198
# apply passes specific to Cadence DSP execution. Return both to print the
198199
# differences.
199-
def export_to_cadence_edge_executorch(
200+
def export_to_cadence(
201+
model: torch.nn.Module,
202+
inputs: tuple[object, ...],
203+
dump_graphs: bool = False,
204+
output_dir: Optional[str] = None,
205+
opt_level: int = 1,
206+
) -> EdgeProgramManager:
207+
edge_prog_manager = export_to_edge(model, inputs)
208+
cadence_passes = get_cadence_passes(opt_level)
209+
210+
# Run a couple required passes for quant/dequant ops
211+
cadence_prog_manager = edge_prog_manager.transform(
212+
cast(
213+
list[Callable[[torch.fx.GraphModule], Optional[PassResult]]], cadence_passes
214+
)
215+
)
216+
return cadence_prog_manager
217+
218+
219+
def export_to_executorch_gen_etrecord(
200220
model: torch.nn.Module,
201221
inputs: tuple[object, ...],
202222
dump_graphs: bool = False,

backends/cadence/aot/export_example.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from executorch.backends.cadence.aot.compiler import (
1818
convert_pt2,
19-
export_to_cadence_edge_executorch,
19+
export_to_executorch_gen_etrecord,
2020
fuse_pt2,
2121
)
2222

@@ -86,8 +86,8 @@ def export_model(
8686
quantized_model = fuse_pt2(converted_model, quantizer)
8787

8888
# Get edge program after Cadence specific passes
89-
exec_prog: ExecutorchProgramManager = export_to_cadence_edge_executorch(
90-
quantized_model, example_inputs, working_dir
89+
exec_prog: ExecutorchProgramManager = export_to_executorch_gen_etrecord(
90+
quantized_model, example_inputs, output_dir=working_dir
9191
)
9292

9393
logging.info("Final exported graph:\n")

backends/cadence/aot/fuse_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1022,7 +1022,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
10221022
return PassResult(graph_module, True)
10231023

10241024

1025-
class FuseOpsInGraph:
1025+
class CadenceFuseOpsInGraph:
10261026
passes = [
10271027
FuseMMWithAdd,
10281028
FuseBatchNormWithConv,

0 commit comments

Comments
 (0)