Skip to content

Commit fdfeaa4

Browse files
authored
Arm Backend: Enable Pybindings for tosa_serialization lib (#15356)
These bindings should drastically improve serialization performance, in both speed and memory usage. This patch also enables serialization of TOSA flatbuffers that are over 2 GB in size.
- Ensure each TOSA operator has an attribute.
- Replace TosaOp.Op() calls with the Op enum.
- Convert TOSA clamp data to numpy int8 to allow C++ byte serialization.
- Replace "import tosa_serializer.serializer" with "import tosa_serializer" for pybind.
- Remove serialization to JSON.
- Remove the Darwin support workaround in the mlsdk dependencies.
- Update mlsdk to the latest tag, but override tosa_mlir_translator with a version that supports offset buffers.
- Update tosa-tools to the monorepo.
- Update slice to clamp the end and start dims. This prevents the values stored in the object's number field from being too large or too small.
cc @freddan80 @per @zingo @oscarandersson8218 @digantdesai
Signed-off-by: Ryan O'Shea <[email protected]>
1 parent bda5173 commit fdfeaa4

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

70 files changed

+489
-345
lines changed

.mypy.ini

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,12 @@ ignore_missing_imports = True
8383
[mypy-tosa_tools.*]
8484
ignore_missing_imports = True
8585

86+
[mypy-tosa_serializer]
87+
ignore_missing_imports = True
88+
89+
[mypy-tosa_serializer.*]
90+
ignore_missing_imports = True
91+
8692
[mypy-setuptools.*]
8793
ignore_missing_imports = True
8894

backends/arm/common/debug.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77
import os
88
from typing import Optional
99

10-
import serializer.tosa_serializer as ts
1110
import torch
11+
12+
import tosa_serializer as ts
1213
from executorch.exir.print_program import inspect_node
1314

1415
logger = logging.getLogger(__name__)
@@ -50,29 +51,20 @@ def get_node_debug_info(
5051
return output
5152

5253

53-
# Output TOSA flatbuffer and test harness file
54-
def debug_tosa_dump(tosa_graph: ts.TosaSerializer, path: str, suffix: str = ""):
54+
# Output TOSA flatbuffer for debugging
55+
def debug_tosa_dump(tosa_graph: bytes, path: str, suffix: str = ""):
5556
filename = f"output{suffix}.tosa"
5657

5758
logger.info(f"Emitting debug output to: {path=}, {suffix=}")
5859

5960
os.makedirs(path, exist_ok=True)
6061

61-
fb = tosa_graph.serialize()
62-
js = tosa_graph.writeJson(filename)
63-
6462
filepath_tosa_fb = os.path.join(path, filename)
6563
with open(filepath_tosa_fb, "wb") as f:
66-
f.write(fb)
64+
f.write(tosa_graph)
6765
if not os.path.exists(filepath_tosa_fb):
6866
raise IOError("Failed to write TOSA flatbuffer")
6967

70-
filepath_desc_json = os.path.join(path, f"desc{suffix}.json")
71-
with open(filepath_desc_json, "w") as f:
72-
f.write(js)
73-
if not os.path.exists(filepath_desc_json):
74-
raise IOError("Failed to write TOSA JSON")
75-
7668

7769
def debug_fail(
7870
node,
@@ -81,7 +73,7 @@ def debug_fail(
8173
path: Optional[str] = None,
8274
):
8375
logger.warning("Internal error due to poorly handled node:")
84-
if tosa_graph is not None and path is not None:
85-
debug_tosa_dump(tosa_graph, path)
76+
if tosa_graph is not None and path:
77+
debug_tosa_dump(tosa_graph.serialize(), path)
8678
logger.warning(f"Debug output captured in '{path}'.")
8779
debug_node(node, graph_module)

backends/arm/debug/schema.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
from dataclasses import asdict, dataclass
1111
from typing import Any, Optional
1212

13-
import serializer.tosa_serializer as ts
1413
import torch
14+
import tosa_serializer as ts
1515

1616
from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec
1717

@@ -114,23 +114,18 @@ def to_dict(self) -> dict[str, Any]:
114114
class DebugHook:
115115
def __init__(self, debug_mode: ArmCompileSpec.DebugMode) -> None:
116116
self._debug_events: list[DebugSchema] = []
117-
self.__op_id_to_name = {}
118117
self.mode = debug_mode
119118

120-
# Build up a mapping from TOSA 1.0 operator IDs to their names
121-
for name, val in vars(ts.Op).items():
122-
self.__op_id_to_name[val] = name
123-
124-
def add(self, node: torch.fx.Node, tosa_op: Any, tosa_op_id: int) -> DebugSchema:
119+
def add(self, node: torch.fx.Node, tosa_op: Any, tosa_op_id: ts.Op) -> DebugSchema:
125120
tosa_debug_info = None
126121

127122
# If the debug data is being embedded into the TOSA flatbuffer
128123
# do not collect TOSADebugSchema data, it's redundent
129124
if self.mode != ArmCompileSpec.DebugMode.TOSA:
130125
tosa_debug_info = TosaDebugSchema(
131126
node_name=str(tosa_op),
132-
operator_name=self.__op_id_to_name[tosa_op_id],
133-
operator_id=tosa_op_id,
127+
operator_name=str(tosa_op_id),
128+
operator_id=int(tosa_op_id),
134129
)
135130

136131
aten_debug_info = ATenDebugSchema.from_node(node)

backends/arm/ethosu/backend.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,15 @@ def _compile_tosa_flatbuffer(
5151
"compile_flags are required in the CompileSpec list for EthosUBackend"
5252
)
5353

54+
# Vela tooling only supports flatbuffers up to 2 GiB.
55+
max_flatbuffer_size = 2 * 1024 * 1024 * 1024
56+
flatbuffer_size = len(tosa_flatbuffer)
57+
if flatbuffer_size > max_flatbuffer_size:
58+
raise RuntimeError(
59+
"TOSA flatbuffer is too large for Vela "
60+
f"({flatbuffer_size} bytes > {max_flatbuffer_size} bytes limit)."
61+
)
62+
5463
# Pass on the TOSA flatbuffer to the vela compiler.
5564
binary = vela_compile(
5665
tosa_flatbuffer,

backends/arm/operators/node_visitor.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import Any, Dict, List, Optional
1010

1111
import torch
12+
import tosa_serializer as ts
1213

1314
from executorch.backends.arm.common.arm_compile_spec import ArmCompileSpec
1415
from executorch.backends.arm.debug.schema import DebugHook
@@ -46,12 +47,12 @@ def _serialize_operator(
4647
self,
4748
node: torch.fx.Node,
4849
tosa_graph: Any,
49-
tosa_op: Any,
50+
tosa_op: ts.Op,
5051
inputs: List[str],
5152
outputs: List[str],
5253
attributes: Optional[Any] = None,
5354
) -> None:
54-
op_location = ""
55+
op_location = ts.TosaOpLocation()
5556
if self.debug_hook:
5657
debug_info = self.debug_hook.add(
5758
node,
@@ -60,7 +61,7 @@ def _serialize_operator(
6061
)
6162

6263
if self.debug_hook.mode == ArmCompileSpec.DebugMode.TOSA:
63-
op_location = json.dumps(debug_info.to_dict())
64+
op_location.text = json.dumps(debug_info.to_dict())
6465

6566
tosa_graph.addOperator(
6667
tosa_op,

backends/arm/operators/op_abs.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
# pyre-unsafe
77
from typing import Any, List
88

9-
import serializer.tosa_serializer as ts
9+
import tosa_serializer as ts
1010

1111
from executorch.backends.arm.operators.node_visitor import (
1212
NodeVisitor,
@@ -48,11 +48,13 @@ def define_node(
4848
output.tosa_spec,
4949
)
5050

51-
tosa_graph.addOperator(
52-
ts.TosaOp.Op().ABS,
53-
[
54-
inputs[0].name,
55-
],
51+
attr = ts.TosaSerializerAttribute()
52+
attr.AbsAttribute()
53+
self._serialize_operator(
54+
node,
55+
tosa_graph,
56+
ts.Op.ABS,
57+
[inputs[0].name],
5658
[output.name],
57-
None,
59+
attr,
5860
)

backends/arm/operators/op_add.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
import executorch.backends.arm.tosa.quant_utils as tqutils
1111
import executorch.backends.arm.tosa.utils as tutils
12-
import serializer.tosa_serializer as ts
12+
import tosa_serializer as ts
1313

1414
from executorch.backends.arm.operators.node_visitor import (
1515
NodeVisitor,
@@ -81,15 +81,16 @@ def define_node(
8181
add_output = output
8282

8383
input1, input2 = rescaled_inputs
84-
84+
attr = ts.TosaSerializerAttribute()
85+
attr.AddAttribute()
8586
# Do the INT32 Add
8687
self._serialize_operator(
8788
node,
8889
tosa_graph,
89-
ts.TosaOp.Op().ADD,
90+
ts.Op.ADD,
9091
[input1.name, input2.name],
9192
[add_output.name],
92-
None,
93+
attr,
9394
)
9495

9596
if output.dtype == ts.DType.INT8:
@@ -143,13 +144,14 @@ def define_node(
143144
)
144145

145146
input1, input2 = inputs
146-
147+
attr = ts.TosaSerializerAttribute()
148+
attr.AddAttribute()
147149
# FP lowering
148150
self._serialize_operator(
149151
node,
150152
tosa_graph,
151-
ts.TosaOp.Op().ADD,
153+
ts.Op.ADD,
152154
[input1.name, input2.name],
153155
[output.name],
154-
None,
156+
attr,
155157
)

backends/arm/operators/op_amax.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55
from typing import Any, List
66

7-
import serializer.tosa_serializer as ts
7+
import tosa_serializer as ts
88

99
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
1010
from executorch.backends.arm.operators.node_visitor import (
@@ -60,11 +60,12 @@ def define_node(
6060
)
6161

6262
attr = ts.TosaSerializerAttribute()
63-
attr.ReduceMaxAttribute(axis=input.dim_order.index(dim), nan_mode=1)
63+
nan_mode = ts.NanPropagationMode.PROPAGATE
64+
attr.ReduceMaxAttribute(axis=input.dim_order.index(dim), nan_mode=nan_mode)
6465
self._serialize_operator(
6566
node,
6667
tosa_graph,
67-
ts.TosaOp.Op().REDUCE_MAX,
68+
ts.Op.REDUCE_MAX,
6869
[input.name],
6970
[output.name],
7071
attr,

backends/arm/operators/op_amin.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# LICENSE file in the root directory of this source tree.
55
from typing import Any, List
66

7-
import serializer.tosa_serializer as ts
7+
import tosa_serializer as ts
88

99
from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor
1010
from executorch.backends.arm.operators.node_visitor import (
@@ -60,11 +60,13 @@ def define_node(
6060
)
6161

6262
attr = ts.TosaSerializerAttribute()
63-
attr.ReduceMinAttribute(axis=input.dim_order.index(dim), nan_mode=1)
63+
attr.ReduceMinAttribute(
64+
axis=input.dim_order.index(dim), nan_mode=ts.NanPropagationMode.PROPAGATE
65+
)
6466
self._serialize_operator(
6567
node,
6668
tosa_graph,
67-
ts.TosaOp.Op().REDUCE_MIN,
69+
ts.Op.REDUCE_MIN,
6870
[input.name],
6971
[output.name],
7072
attr,

backends/arm/operators/op_any.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
# pyre-unsafe
77
from typing import Any, cast, List
88

9-
import serializer.tosa_serializer as ts
9+
import tosa_serializer as ts
1010

1111
from executorch.backends.arm.operators.node_visitor import ( # type: ignore
1212
NodeVisitor,
@@ -55,7 +55,7 @@ def define_node(
5555
self._serialize_operator(
5656
node,
5757
tosa_graph,
58-
ts.TosaOp.Op().REDUCE_ANY,
58+
ts.Op.REDUCE_ANY,
5959
[inputs[0].name],
6060
[output.name],
6161
attr,

0 commit comments

Comments
 (0)