
Commit c6fcfb9

Update on "make to_edge support etrecord generation"
This makes the to_edge export flow support ETRecord generation. Details can be found in #12925. Differential Revision: [D79707919](https://our.internmc.facebook.com/intern/diff/D79707919/) [ghstack-poisoned]
2 parents 84b3e6f + 32ab32a commit c6fcfb9
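The exir-side change itself is not among the files shown below, so as a rough orientation only, here is a minimal sketch of what an ETRecord-producing to_edge flow could look like. The generate_etrecord argument and the way the record would later be consumed are assumptions inferred from the commit title and #12925, not APIs confirmed by this diff:

# Hypothetical usage sketch; the generate_etrecord flag is an assumption,
# not an API confirmed by this commit.
import torch
from executorch.exir import to_edge


class AddOne(torch.nn.Module):
    def forward(self, x):
        return x + 1.0


exported = torch.export.export(AddOne(), (torch.randn(4),))

# Previously, ETRecord generation was a separate devtools step; the commit
# makes the to_edge flow able to produce it as part of lowering.
edge = to_edge(exported, generate_etrecord=True)  # assumed flag
executorch_program = edge.to_executorch()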

80 files changed: +2,743 -446 lines


.github/workflows/trunk.yml

Lines changed: 53 additions & 7 deletions
@@ -78,25 +78,71 @@ jobs:
 mkdir -p zephyr_scratch/
 cd zephyr_scratch
 export ZEPHYR_PROJ_ROOT=$(realpath $(pwd))
+export ARM_FVP_TUTORIALS_ROOT=$ZEPHYR_PROJ_ROOT/zephyr/samples/modules/executorch/arm-fvp-tutorials

+# TODO @Bujji: Should see if this can be moved into the docker image itself
 download_arm_zephyr_sdk
 ./zephyr-sdk-0.16.0/setup.sh -c -t arm-zephyr-eabi
-
 cd $ZEPHYR_PROJ_ROOT
 setup_zephyr_et_module

+# Run setup scripts for Arm FVP and Arm AOT Compilation
 cd $ZEPHYR_PROJ_ROOT/modules/lib/executorch
 install_executorch "--use-pt-pinned-commit"
 .ci/scripts/setup-arm-baremetal-tools.sh --target-toolchain zephyr
 source examples/arm/ethos-u-scratch/setup_path.sh
 source $ZEPHYR_PROJ_ROOT/zephyr/zephyr-env.sh
-cd $ZEPHYR_PROJ_ROOT/zephyr/samples/modules/executorch/arm/hello_world
-west build -p always -b mps3/corstone300/fvp
-FVP_Corstone_SSE-300_Ethos-U55 -a build/zephyr/zephyr.elf -C mps3_board.visualisation.disable-visualisation=1 -C mps3_board.telnetterminal0.start_telnet=0 -C mps3_board.uart0.out_file='sim.out' -C cpu0.CFGITCMSZ=15 -C cpu0.CFGDTCMSZ=15 --simlimit 120

-grep -qF "Output[0][0]: (float) 2.000000" sim.out
-exit_status=$? #store 0 if found (success), 1 if not (failure)
-exit $exit_status
+# Get the model as PTE
+python -m examples.arm.aot_arm_compiler \
+  --model_name="${MODEL_NAME}" \
+  --output="${MODEL_NAME}.pte"
+
+# Generate the C-style header
+cd $ARM_FVP_TUTORIALS_ROOT
+python build_model.py \
+  --executorch-root $ZEPHYR_PROJ_ROOT/modules/lib/executorch \
+  --pte-file $ZEPHYR_PROJ_ROOT/modules/lib/executorch/${MODEL_NAME}.pte \
+  --output-path $ARM_FVP_TUTORIALS_ROOT/models/${MODEL_NAME}/src/
+
+cd $ARM_FVP_TUTORIALS_ROOT/models/${MODEL_NAME}/
+
+# Build the zephyr elf
+west build -p always -b mps3/corstone300/fvp -- \
+  -DET_PTE_FILE_PATH_FOR_SELECTIVE_BUILD=$ZEPHYR_PROJ_ROOT/modules/lib/executorch/${MODEL_NAME}.pte
+
+# Run the simulation
+FVP_Corstone_SSE-300_Ethos-U55 -a build/zephyr/zephyr.elf \
+  -C mps3_board.visualisation.disable-visualisation=1 \
+  -C mps3_board.telnetterminal0.start_telnet=0 \
+  -C mps3_board.uart0.out_file='sim.out' \
+  -C cpu0.CFGITCMSZ=15 \
+  -C cpu0.CFGDTCMSZ=15 \
+  --simlimit 120
+
+# Disable exit on error
+set +e
+# Report failure if any of the output verification checks fail
+# store 0 if found (failure), 1 if not (success)
+grep -qF "ERROR" sim.out
+exit_status=$?
+if [[ "$exit_status" -eq "0" ]]; then
+  cat sim.out
+  set -e
+  exit 1
+fi
+
+# Report fail if simulation does not complete successfully
+# store 0 if found (success), 1 if not (failure)
+grep -qF "SUCCESS: Program complete, exiting." sim.out
+exit_status=$?
+if [[ "$exit_status" -eq "1" ]]; then
+  cat sim.out
+  set -e
+  exit 1
+fi
+# Re-enable exit on error
+set -e

 test-models-linux-aarch64:
   name: test-models-linux-aarch64

backends/arm/_passes/decompose_avg_pool2d.py

Lines changed: 12 additions & 2 deletions
@@ -45,7 +45,10 @@ def call_operator(self, op, args, kwargs, meta):
         x = args[0]
         kernel_h, kernel_w = args[1]
         kernel_size = kernel_h * kernel_w
-        stride_h, stride_w = args[2]
+        if len(args) > 2 and args[2] is not None:
+            stride_h, stride_w = args[2]
+        else:
+            stride_h, stride_w = kernel_h, kernel_w
         pad_h, pad_w = new_pad_h, new_pad_w = args[3] if len(args) > 3 else (0, 0)
         ceil_mode = args[4] if len(args) > 4 else False
         count_include_pad = args[5] if len(args) > 5 else True
@@ -108,7 +111,14 @@ def call_operator(self, op, args, kwargs, meta):
         x = super().call_operator(cat_op, (cat_nodes, 2), kwargs, meta)
         new_pad_h = 0

-        avgpool_args = (x, args[1], args[2], [new_pad_h, new_pad_w], ceil_mode, False)
+        avgpool_args = (
+            x,
+            args[1],
+            [stride_h, stride_w],
+            [new_pad_h, new_pad_w],
+            ceil_mode,
+            False,
+        )
         x = super().call_operator(avgpool_op, avgpool_args, kwargs, meta)

         # Multiply by factor (kernel_size / divisor_override) if divisor_override
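For reference, the stride fallback introduced above mirrors avg_pool2d's own default in PyTorch: when the stride is omitted (or None), it defaults to the kernel size. A small self-contained check of that behaviour, independent of the Arm pass:

# When stride is not given, avg_pool2d strides by the kernel size,
# which is exactly the fallback the pass now applies.
import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 8, 8)
default_stride = F.avg_pool2d(x, kernel_size=(2, 2))
explicit_stride = F.avg_pool2d(x, kernel_size=(2, 2), stride=(2, 2))
assert torch.equal(default_stride, explicit_stride)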

backends/arm/_passes/decompose_grouped_conv.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 from copy import copy

 import torch
-from executorch.backends.arm.tosa_quant_utils import QuantArgs
+from executorch.backends.arm._passes.quant_args import QuantArgs
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass

backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 2 additions & 2 deletions
@@ -15,9 +15,9 @@
     get_param_tensor,
     is_param_node,
 )
-from executorch.backends.arm.constants import DQ_OPS, Q_OPS

-from executorch.backends.arm.tosa_quant_utils import QuantArgs
+from executorch.backends.arm._passes.quant_args import QuantArgs
+from executorch.backends.arm.constants import DQ_OPS, Q_OPS

 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.dialects.edge._ops import EdgeOpOverload

backends/arm/_passes/fuse_quantized_activation_pass.py

Lines changed: 1 addition & 1 deletion
@@ -6,8 +6,8 @@
 # pyre-unsafe

 import torch
+from executorch.backends.arm._passes.quant_args import QuantArgs
 from executorch.backends.arm.constants import Q_OPS
-from executorch.backends.arm.tosa_quant_utils import QuantArgs
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult
 from torch.fx import Node

backends/arm/_passes/insert_rescales_pass.py

Lines changed: 1 addition & 1 deletion
@@ -9,8 +9,8 @@

 import torch
 from executorch.backends.arm._passes.arm_pass_utils import create_node
+from executorch.backends.arm._passes.quant_args import QuantArgs
 from executorch.backends.arm.constants import DQ_OPS, Q_OPS
-from executorch.backends.arm.tosa_quant_utils import QuantArgs
 from executorch.exir.pass_base import ExportPass, PassResult
 from torch import Tensor
 from torch.fx import GraphModule, Node

backends/arm/_passes/insert_table_ops.py

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@

 import torch
 from executorch.backends.arm._passes.arm_pass_utils import create_node
-from executorch.backends.arm.tosa_quant_utils import QuantArgs
+from executorch.backends.arm._passes.quant_args import QuantArgs
 from executorch.exir import ExportedProgram

 from executorch.exir.dialects._ops import ops as exir_ops

backends/arm/_passes/quant_args.py

Lines changed: 125 additions & 0 deletions
@@ -0,0 +1,125 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import Any, cast, NamedTuple
+
+import torch
+from executorch.exir.dialects._ops import ops as exir_ops
+
+exir_ops = cast(Any, exir_ops)
+from executorch.backends.arm.constants import PER_CHANNEL_QDQ_OPS, PER_TENSOR_QDQ_OPS
+from torch import Tensor
+
+
+class QuantArgs(NamedTuple):
+    scale: list[float] | float
+    zp: list[int] | int
+    qmin: int
+    qmax: int
+    dtype: torch.dtype
+    axis: int = 0
+    per_channel: bool = False
+
+    def quantize_value(self, x: torch.Tensor | float) -> Tensor:
+        """Quantizes the input tensor or value to a quantized tensor. If the input is
+        not a tensor, it is converted to a tensor first. If self.per_channel is True,
+        the quantization is done per channel, otherwise it is done per tensor.
+        """
+        if not isinstance(x, torch.Tensor):
+            x = torch.Tensor([x])
+        x = x.to(torch.float32)
+        if self.per_channel:
+            q_op = exir_ops.edge.quantized_decomposed.quantize_per_channel.default
+            args = (
+                x,
+                torch.tensor(self.scale),
+                torch.tensor(self.zp),
+                self.axis,
+                self.qmin,
+                self.qmax,
+                self.dtype,
+            )
+        else:
+            q_op = exir_ops.edge.quantized_decomposed.quantize_per_tensor.default
+            args = (x, self.scale, self.zp, self.qmin, self.qmax, self.dtype)  # type: ignore[assignment]
+        return q_op(*args)
+
+    def dequantize_value(self, qx: torch.Tensor) -> torch.Tensor:
+        """Dequantizes the input tensor or value to a dequantized tensor If the input
+        is not a tensor, it is converted to a tensor first. If self.per_channel is True,
+        the dequantization is done per channel, otherwise it is done per tensor.
+        """
+        if self.per_channel:
+            dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_channel.default
+            args = (
+                qx,
+                torch.tensor(self.scale),
+                torch.tensor(self.zp),
+                self.axis,
+                self.qmin,
+                self.qmax,
+                self.dtype,
+            )
+        else:
+            dq_op = exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default
+            args = (qx, self.scale, self.zp, self.qmin, self.qmax, self.dtype)  # type: ignore[assignment]
+        return dq_op(*args)
+
+    @classmethod
+    def from_operator(cls, op, args):
+        if op in PER_TENSOR_QDQ_OPS:
+            return cls(
+                scale=cast(float, args[1]),
+                zp=cast(int, args[2]),
+                qmin=cast(int, args[3]),
+                qmax=cast(int, args[4]),
+                dtype=cast(torch.dtype, args[5]),
+                axis=0,
+                per_channel=False,
+            )
+        elif op in PER_CHANNEL_QDQ_OPS:
+            return cls(
+                scale=cast(list[float], args[1].tolist()),
+                zp=cast(list[int], args[2].tolist()),
+                axis=cast(int, args[3]),
+                qmin=cast(int, args[4]),
+                qmax=cast(int, args[5]),
+                dtype=cast(torch.dtype, args[6]),
+                per_channel=True,
+            )
+        else:
+            # We're only handling per tensor and per channel quantization
+            raise NotImplementedError(f"Unsupported quantization operation: {op}")
+
+    def get_scale_per_tensor(self) -> float:
+        if not isinstance(self.scale, float):
+            raise TypeError(
+                f"Expected scale {self.scale} to be a float but found scale of "
+                f"type {type(self.scale)}"
+            )
+        return self.scale
+
+    def get_zp_per_tensor(self) -> int:
+        if not isinstance(self.zp, int):
+            raise TypeError(
+                f"Expected zero point {self.zp} to be an int but found zp of "
+                f"type {type(self.zp)}"
+            )
+        return self.zp
+
+    def get_scale_per_channel(self) -> list[float]:
+        if not isinstance(self.scale, list):
+            raise TypeError(
+                f"Expected scale {self.scale} to be a list but found scale of "
+                f"type {type(self.scale)}"
+            )
+        return self.scale
+
+    def get_zp_per_channel(self) -> list[int]:
+        if not isinstance(self.zp, list):
+            raise TypeError(
+                f"Expected zero point {self.zp} to be a list but found zp of "
+                f"type {type(self.zp)}"
+            )
+        return self.zp
backends/arm/operator_support/ethos_u55_support.py

Lines changed: 2 additions & 0 deletions
@@ -149,6 +149,8 @@ class EthosU55NotSupported(OperatorSupportBase):
         exir_ops.edge.aten.ne.Scalar,
         exir_ops.edge.aten.flip.default,  # REVERSE
         exir_ops.edge.aten.grid_sampler_2d,  # GATHER
+        exir_ops.edge.aten.index.Tensor,  # GATHER
+        exir_ops.edge.aten.index_select.default,  # GATHER
         exir_ops.edge.aten.scatter.src,
         exir_ops.edge.aten.scatter.value,
         exir_ops.edge.aten.select_scatter.default,
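Both new entries are tagged GATHER because aten.index.Tensor and aten.index_select lower to a TOSA GATHER, which Ethos-U55 does not support, so nodes like the following would be rejected by this check and left out of the delegated partition (illustrative example only):

# An index_select such as this lowers to GATHER and is therefore flagged
# as unsupported for Ethos-U55 by the check above.
import torch

x = torch.randn(4, 8)
indices = torch.tensor([0, 2])
y = torch.index_select(x, dim=0, index=indices)  # aten.index_select.default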

backends/arm/runtime/VGFBackend.cpp

Lines changed: 50 additions & 5 deletions
@@ -264,15 +264,60 @@ VkResult vkml_allocate_basics(
       .engineVersion = 0,
       .apiVersion = VK_API_VERSION_1_3,
   };
+
+  std::vector<const char*> requested_extensions;
+  VkInstanceCreateFlags instance_flags = 0;
+
+#ifdef __APPLE__
+  instance_flags |= VK_INSTANCE_CREATE_ENUMERATE_PORTABILITY_BIT_KHR;
+
+  uint32_t extension_count = 0;
+  result = vkEnumerateInstanceExtensionProperties(
+      nullptr, &extension_count, nullptr);
+
+  if (result != VK_SUCCESS) {
+    ET_LOG(Error, "Failed to enumerate instance extensions");
+    return result;
+  }
+
+  std::vector<VkExtensionProperties> extension_properties(extension_count);
+  result = vkEnumerateInstanceExtensionProperties(
+      nullptr, &extension_count, extension_properties.data());
+
+  if (result != VK_SUCCESS) {
+    ET_LOG(Error, "Failed to enumerate instance extensions");
+    return result;
+  }
+
+  if (std::any_of(
+          extension_properties.begin(),
+          extension_properties.end(),
+          [](const auto& extension) {
+            return strcmp(
+                       VK_KHR_PORTABILITY_ENUMERATION_EXTENSION_NAME,
+                       extension.extensionName) == 0;
+          })) {
+    requested_extensions.push_back(
+        VK_KHR_PORTABILITY_ENUMERATION_EXTENSION_NAME);
+  }
+
+  if (requested_extensions.empty()) {
+    ET_LOG(Error, "VK_KHR_portability_enumeration not found");
+  }
+
+#endif
+
   VkInstanceCreateInfo instance_info{
       .sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO,
       .pNext = nullptr,
-      .flags = 0,
+      .flags = instance_flags,
       .pApplicationInfo = &app_info,
-      0,
-      nullptr,
-      0,
-      nullptr};
+      .enabledLayerCount = 0,
+      .ppEnabledLayerNames = nullptr,
+      .enabledExtensionCount =
+          static_cast<uint32_t>(requested_extensions.size()),
+      .ppEnabledExtensionNames = requested_extensions.data(),
+  };
   result = vkCreateInstance(&instance_info, nullptr, instance);
   if (result != VK_SUCCESS) {
     ET_LOG(Error, "Failed to create VkInstance");
