Skip to content

Commit a8c592e

Browse files
authored
Buckify backends/arm for meta internal use.
Differential Revision: D62062674 Pull Request resolved: #5023
1 parent cd1c833 commit a8c592e

21 files changed

+232
-42
lines changed

backends/arm/TARGETS

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
2+
3+
python_library(
4+
name = "arm_partitioner",
5+
srcs = [
6+
"arm_partitioner.py",
7+
],
8+
typing = True,
9+
deps = [
10+
":arm_backend",
11+
"//executorch/backends/arm/passes:passes",
12+
"//executorch/exir:lib",
13+
],
14+
)
15+
16+
python_library(
17+
name = "arm_backend",
18+
srcs = [
19+
"arm_backend.py",
20+
],
21+
typing = True,
22+
deps = [
23+
"fbsource//third-party/pypi/flatbuffers:flatbuffers",
24+
"fbsource//third-party/pypi/ml-dtypes:ml-dtypes",
25+
"fbsource//third-party/serialization_lib/python/serializer:serializer",
26+
"fbsource//third-party/serialization_lib/python/tosa:tosa",
27+
":arm_vela",
28+
"//executorch/backends/arm/operators:lib",
29+
"//executorch/backends/arm/operators:node_visitor",
30+
"//executorch/backends/arm/passes:passes",
31+
],
32+
)
33+
34+
python_library(
35+
name = "arm_vela",
36+
srcs = [
37+
"arm_vela.py",
38+
],
39+
typing = True,
40+
deps = [
41+
"fbsource//third-party/pypi/ethos-u-vela:ethos-u-vela",
42+
],
43+
)
44+
45+
python_library(
46+
name = "tosa_mapping",
47+
srcs = [
48+
"tosa_mapping.py",
49+
],
50+
typing = True,
51+
deps = [
52+
"fbsource//third-party/serialization_lib/python/serializer:serializer",
53+
"//caffe2:torch",
54+
],
55+
)
56+
57+
python_library(
58+
name = "tosa_quant_utils",
59+
srcs = [
60+
"tosa_quant_utils.py",
61+
],
62+
typing = True,
63+
deps = [
64+
"fbsource//third-party/pypi/numpy:numpy",
65+
"fbsource//third-party/serialization_lib/python/serializer:serializer",
66+
"fbsource//third-party/serialization_lib/python/tosa:tosa",
67+
":tosa_mapping",
68+
"//executorch/exir/dialects:lib",
69+
],
70+
)
71+
72+
python_library(
73+
name = "tosa_utils",
74+
srcs = [
75+
"tosa_utils.py",
76+
],
77+
typing = True,
78+
deps = [
79+
"fbsource//third-party/serialization_lib/python/serializer:serializer",
80+
":tosa_quant_utils",
81+
"//executorch/backends/arm/operators:node_visitor",
82+
],
83+
)

backends/arm/arm_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def is_tosa(compile_spec: List[CompileSpec]) -> bool:
159159
return False
160160

161161

162-
def get_intermediate_path(compile_spec: List[CompileSpec]) -> str:
162+
def get_intermediate_path(compile_spec: List[CompileSpec]) -> Optional[str]:
163163
for spec in compile_spec:
164164
if spec.key == "debug_artifact_path":
165165
return spec.value.decode()

backends/arm/arm_vela.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55

66
import os
77
import struct
8-
import subprocess
98
import tempfile
109

1110
from typing import List
1211

1312
import numpy as np
13+
from ethosu.vela import vela
1414

1515

1616
# Pack either input or output tensor block, compose the related arrays into
@@ -38,21 +38,17 @@ def vela_compile(tosa_graph, args: List[str]):
3838
with tempfile.TemporaryDirectory() as tmpdir:
3939
tosaname = "out.tosa"
4040
flatbuffer = tosa_graph.serialize()
41-
with open(os.path.join(tmpdir, tosaname), "wb") as f:
41+
tosa_path = os.path.join(tmpdir, tosaname)
42+
with open(tosa_path, "wb") as f:
4243
f.write(flatbuffer)
4344

4445
# invoke vela
45-
vela_command = f"cd {tmpdir}; vela {' '.join(args)} {tosaname}"
46-
try:
47-
subprocess.run([vela_command], shell=True, check=True, capture_output=True)
48-
except subprocess.CalledProcessError as process_error:
49-
raise RuntimeError(
50-
f"Vela compiler ('{vela_command}') failed with error:\n \
51-
{process_error.stderr.decode()}\n \
52-
Stdout:\n{process_error.stdout.decode()}"
53-
)
54-
55-
np_path = os.path.join(tmpdir, "output", "out_sg0_vela.npz")
46+
output_dir = os.path.join(tmpdir, "output")
47+
args.append(f"--output-dir={output_dir}")
48+
args.append(tosa_path)
49+
vela.main(" ".join(args).split(" "))
50+
51+
np_path = os.path.join(output_dir, "out_sg0_vela.npz")
5652
blocks = b""
5753

5854
with np.load(np_path, allow_pickle=False) as data:

backends/arm/operators/TARGETS

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
2+
3+
python_library(
4+
name = "node_visitor",
5+
srcs = ["node_visitor.py"],
6+
typing = True,
7+
deps = [
8+
"//executorch/backends/arm:tosa_mapping",
9+
],
10+
)
11+
12+
python_library(
13+
name = "ops",
14+
srcs = glob(["op_*.py"]),
15+
typing = True,
16+
deps = [
17+
"fbsource//third-party/serialization_lib/python/tosa:tosa",
18+
":node_visitor",
19+
"//executorch/backends/arm:tosa_mapping",
20+
"//executorch/backends/arm:tosa_quant_utils",
21+
"//executorch/backends/arm:tosa_utils",
22+
"//executorch/exir:lib",
23+
],
24+
)
25+
26+
python_library(
27+
name = "lib",
28+
srcs = ["__init__.py"],
29+
typing = True,
30+
deps = [
31+
":node_visitor",
32+
":ops",
33+
],
34+
)

backends/arm/operators/op_bmm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def define_node(
7272
build_rescale(
7373
tosa_fb=tosa_graph,
7474
scale=final_output_scale,
75+
# pyre-ignore[61]: Uninitialized local [61]: Local variable `bmm_result` is undefined, or not always defined.
7576
input_node=bmm_result,
7677
output_name=output.name,
7778
output_type=ts.DType.INT8,

backends/arm/operators/op_conv2d.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
5-
from typing import List
5+
from typing import cast, List
66

77
import serializer.tosa_serializer as ts
88
import torch
@@ -156,11 +156,12 @@ def define_node(
156156
# integer value domain of the next op. Otherwise return float32 output.
157157
if is_quant_node:
158158
# Get scale_factor from input, weight, and output.
159-
_, input_scale, _, _, _, _ = getNodeArgs(node.args[0])
160-
_, weight_scale, _, _, _, _ = getNodeArgs(node.args[1])
159+
_, input_scale, _, _, _, _ = getNodeArgs(cast(torch.fx.Node, node.args[0]))
160+
_, weight_scale, _, _, _, _ = getNodeArgs(cast(torch.fx.Node, node.args[1]))
161161
_, output_scale, output_zp, _, _, _ = getNodeArgs(list(node.users)[0])
162162
build_rescale_conv_output(
163163
tosa_graph,
164+
# pyre-fixme[61]: Uninitialized local [61]: Local variable `conv2d_res` is undefined, or not always defined.
164165
conv2d_res,
165166
output.name,
166167
actual_out_type,

backends/arm/operators/op_mm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ def define_node(
9696
build_rescale(
9797
tosa_fb=tosa_graph,
9898
scale=final_output_scale,
99+
# pyre-ignore[61]: Uninitialized local [61]: Local variable `reshape_intermediate` is undefined, or not always defined.
99100
input_node=reshape_intermediate,
100101
output_name=output.name,
101102
output_type=ts.DType.INT8,

backends/arm/operators/op_mul.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
from typing import List
6+
from typing import cast, List
77

88
import executorch.backends.arm.tosa_quant_utils as tqutils
99
import executorch.backends.arm.tosa_utils as tutils
@@ -35,8 +35,12 @@ def define_node(
3535
if is_quant_node:
3636
input_A = inputs[0]
3737
input_B = inputs[1]
38-
input_A_qargs = tqutils.get_quant_node_args(node.args[0])
39-
input_B_qargs = tqutils.get_quant_node_args(node.args[1])
38+
input_A_qargs = tqutils.get_quant_node_args(
39+
cast(torch.fx.Node, node.args[0])
40+
)
41+
input_B_qargs = tqutils.get_quant_node_args(
42+
cast(torch.fx.Node, node.args[1])
43+
)
4044

4145
input_A.shape = tutils.tosa_shape(input_A.shape, input_A.dim_order)
4246
input_B.shape = tutils.tosa_shape(input_B.shape, input_B.dim_order)

backends/arm/operators/op_output.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
# This source code is licensed under the BSD-style license found in the
44
# LICENSE file in the root directory of this source tree.
55

6+
from typing import cast
7+
68
import serializer.tosa_serializer as ts
79
import torch
810

@@ -11,7 +13,7 @@ def process_output(
1113
node: torch.fx.Node,
1214
tosa_graph: ts.TosaSerializer,
1315
):
14-
for output in node.args[0]:
16+
for output in cast(tuple[torch.fx.Node, ...], node.args[0]):
1517
tosa_graph.addOutputTensor(
1618
tosa_graph.currRegion.currBasicBlock.tensors[output.name]
1719
)

backends/arm/operators/op_view.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@
66

77
import serializer.tosa_serializer as ts
88
import torch
9+
import tosa.Op as TosaOp
910

1011
from executorch.backends.arm.operators.node_visitor import (
1112
NodeVisitor,
1213
register_node_visitor,
1314
)
1415
from executorch.backends.arm.tosa_mapping import TosaArg
1516
from executorch.backends.arm.tosa_utils import tosa_shape
16-
from serializer.tosa_serializer import TosaOp
1717

1818

1919
@register_node_visitor

0 commit comments

Comments
 (0)