Skip to content

Commit 0a4fe14

Browse files
authored
Merge branch 'main' into add-dim-order-clone-kernel
2 parents 19d14e1 + a8fe653 commit 0a4fe14

File tree

127 files changed

+3759
-1276
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

127 files changed

+3759
-1276
lines changed

.github/workflows/trunk.yml

Lines changed: 53 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -78,25 +78,71 @@ jobs:
7878
mkdir -p zephyr_scratch/
7979
cd zephyr_scratch
8080
export ZEPHYR_PROJ_ROOT=$(realpath $(pwd))
81+
export ARM_FVP_TUTORIALS_ROOT=$ZEPHYR_PROJ_ROOT/zephyr/samples/modules/executorch/arm-fvp-tutorials
8182
83+
# TODO @Bujji: Should see if this can be moved into the docker image itself
8284
download_arm_zephyr_sdk
8385
./zephyr-sdk-0.16.0/setup.sh -c -t arm-zephyr-eabi
84-
8586
cd $ZEPHYR_PROJ_ROOT
8687
setup_zephyr_et_module
8788
89+
# Run setup scripts for Arm FVP and Arm AOT Compilation
8890
cd $ZEPHYR_PROJ_ROOT/modules/lib/executorch
8991
install_executorch "--use-pt-pinned-commit"
9092
.ci/scripts/setup-arm-baremetal-tools.sh --target-toolchain zephyr
9193
source examples/arm/ethos-u-scratch/setup_path.sh
9294
source $ZEPHYR_PROJ_ROOT/zephyr/zephyr-env.sh
93-
cd $ZEPHYR_PROJ_ROOT/zephyr/samples/modules/executorch/arm/hello_world
94-
west build -p always -b mps3/corstone300/fvp
95-
FVP_Corstone_SSE-300_Ethos-U55 -a build/zephyr/zephyr.elf -C mps3_board.visualisation.disable-visualisation=1 -C mps3_board.telnetterminal0.start_telnet=0 -C mps3_board.uart0.out_file='sim.out' -C cpu0.CFGITCMSZ=15 -C cpu0.CFGDTCMSZ=15 --simlimit 120
9695
97-
grep -qF "Output[0][0]: (float) 2.000000" sim.out
98-
exit_status=$? #store 0 if found (success), 1 if not (failure)
99-
exit $exit_status
96+
# Get the model as PTE
97+
python -m examples.arm.aot_arm_compiler \
98+
--model_name="${MODEL_NAME}" \
99+
--output="${MODEL_NAME}.pte"
100+
101+
# Generate the C-style header
102+
cd $ARM_FVP_TUTORIALS_ROOT
103+
python build_model.py \
104+
--executorch-root $ZEPHYR_PROJ_ROOT/modules/lib/executorch \
105+
--pte-file $ZEPHYR_PROJ_ROOT/modules/lib/executorch/${MODEL_NAME}.pte \
106+
--output-path $ARM_FVP_TUTORIALS_ROOT/models/${MODEL_NAME}/src/
107+
108+
cd $ARM_FVP_TUTORIALS_ROOT/models/${MODEL_NAME}/
109+
110+
# Build the zephyr elf
111+
west build -p always -b mps3/corstone300/fvp -- \
112+
-DET_PTE_FILE_PATH_FOR_SELECTIVE_BUILD=$ZEPHYR_PROJ_ROOT/modules/lib/executorch/${MODEL_NAME}.pte
113+
114+
# Run the simulation
115+
FVP_Corstone_SSE-300_Ethos-U55 -a build/zephyr/zephyr.elf \
116+
-C mps3_board.visualisation.disable-visualisation=1 \
117+
-C mps3_board.telnetterminal0.start_telnet=0 \
118+
-C mps3_board.uart0.out_file='sim.out' \
119+
-C cpu0.CFGITCMSZ=15 \
120+
-C cpu0.CFGDTCMSZ=15 \
121+
--simlimit 120
122+
123+
# Disable exit on error
124+
set +e
125+
# Report failure if any of the ouptut verification checks fail
126+
# store 0 if found (failure), 1 if not (success)
127+
grep -qF "ERROR" sim.out
128+
exit_status=$?
129+
if [[ "$exit_status" -eq "0" ]]; then
130+
cat sim.out
131+
set -e
132+
exit 1
133+
fi
134+
135+
# Report fail if simulation does not complete successfully
136+
# store 0 if found (success), 1 if not (failure)
137+
grep -qF "SUCCESS: Program complete, exiting." sim.out
138+
exit_status=$?
139+
if [[ "$exit_status" -eq "1" ]]; then
140+
cat sim.out
141+
set -e
142+
exit 1
143+
fi
144+
# Re-enable exit on error
145+
set -e
100146
101147
test-models-linux-aarch64:
102148
name: test-models-linux-aarch64

backends/apple/coreml/compiler/torch_ops.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
# coremltools than is used by ExecuTorch. Each op registered here should have a link to a PR in coremltools that adds
99
# the op to the coremltools library.
1010

11+
import numpy as np
1112
import torch as _torch
1213
from coremltools import _logger
1314
from coremltools.converters.mil.frontend import _utils
@@ -21,7 +22,6 @@
2122
transpose,
2223
unbind,
2324
)
24-
2525
from coremltools.converters.mil.frontend.torch.torch_op_registry import (
2626
register_torch_op,
2727
)
@@ -132,3 +132,43 @@ def dequantize_affine(context, node):
132132
name=node.name,
133133
)
134134
context.add(output, node.name)
135+
136+
137+
@register_torch_op(
138+
torch_alias=["quant::dequantize_codebook", "quant.dequantize_codebook"],
139+
override=False,
140+
)
141+
def dequantize_codebook(context, node):
142+
inputs = _get_inputs(context, node, expected=[4, 5])
143+
codes = inputs[0].val
144+
codebook = inputs[1].val
145+
nbits = inputs[2].val
146+
147+
# information in block_size is redundant with codebook.shape
148+
block_size = inputs[3].val # noqa: F841
149+
150+
assert len(codes.shape) == 2, "Only rank 2 inputs are supported"
151+
152+
# Assert codebook is as expected. codebook.dim() = codes.dim() + 2
153+
assert len(codebook.shape) == 4, "Only rank 4 inputs are supported for codebook"
154+
assert codebook.shape[0] == 1, "Only grouped_channel granularity is supported"
155+
n_luts = codebook.shape[1]
156+
assert (
157+
codes.shape[1] % n_luts == 0
158+
), "codes.shape[1] must be divisible by codebook.shape[1]"
159+
assert codebook.shape[2] == 2**nbits
160+
assert codebook.shape[3] == 1, "Only scalar look up values are supported"
161+
162+
if len(inputs) > 4:
163+
output_dtype = inputs[4].val
164+
out_np_dtype = NUM_TO_NUMPY_DTYPE[output_dtype]
165+
_logger.warning(
166+
f"Core ML ignores output_dtype {out_np_dtype} on torchao.dequantize_affine and instead uses the native precision."
167+
)
168+
169+
output = _utils._construct_constexpr_lut_op(
170+
codes.astype(np.int8),
171+
codebook,
172+
name=node.name,
173+
)
174+
context.add(output, node.name)

backends/apple/coreml/runtime/delegate/coreml_backend_delegate.mm

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,17 +88,17 @@
8888
ET_LOG(Error, "%s: DataType=%d is not supported", ETCoreMLStrings.delegateIdentifier.UTF8String, (int)tensor.scalar_type());
8989
return std::nullopt;
9090
}
91-
91+
9292
std::vector<ssize_t> strides(tensor.strides().begin(), tensor.strides().end());
9393
std::vector<size_t> shape(tensor.sizes().begin(), tensor.sizes().end());
94-
94+
9595
// If tensor is rank 0, wrap in rank 1
9696
// See https://github.com/apple/coremltools/blob/8.2/coremltools/converters/mil/frontend/torch/exir_utils.py#L73
9797
if (shape.size() == 0) {
9898
shape.push_back(1);
9999
strides.push_back(1);
100100
}
101-
101+
102102
MultiArray::MemoryLayout layout(dataType.value(), std::move(shape), std::move(strides));
103103
switch (argType) {
104104
case ArgType::Input: {
@@ -281,9 +281,11 @@ ModelLoggingOptions get_logging_options(BackendExecutionContext& context) {
281281
}
282282

283283
namespace {
284-
auto cls = CoreMLBackendDelegate();
285-
Backend backend{ETCoreMLStrings.delegateIdentifier.UTF8String, &cls};
286-
static auto success_with_compiler = register_backend(backend);
284+
#ifndef LAZY_LOAD_IOS_PYTORCH_INITIALIZER
285+
auto cls = CoreMLBackendDelegate();
286+
Backend backend{ETCoreMLStrings.delegateIdentifier.UTF8String, &cls};
287+
static auto success_with_compiler = register_backend(backend);
288+
#endif
287289
}
288290

289291
} // namespace coreml

backends/apple/coreml/test/test_torch_ops.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414

1515
from executorch.backends.apple.coreml.compiler import CoreMLBackend
1616
from executorch.backends.apple.coreml.partition import CoreMLPartitioner
17+
from executorch.exir.backend.utils import format_delegated_graph
18+
19+
from torchao.prototype.quantization.codebook_coreml import CodebookWeightOnlyConfig
1720
from torchao.quantization import IntxWeightOnlyConfig, PerAxis, PerGroup, quantize_
1821

1922

@@ -164,6 +167,61 @@ def test_dequantize_affine_c8w_embedding_b4w_linear(self):
164167
et_prog = delegated_program.to_executorch()
165168
self._compare_outputs(et_prog, model, example_inputs)
166169

170+
def test_dequantize_codebook_linear(self):
171+
model, example_inputs = self._get_test_model()
172+
quantize_(
173+
model,
174+
CodebookWeightOnlyConfig(dtype=torch.uint2, block_size=[-1, 16]),
175+
)
176+
ep = torch.export.export(model, example_inputs)
177+
assert "torch.ops.quant.dequantize_codebook.default" in ep.graph_module.code
178+
delegated_program = executorch.exir.to_edge_transform_and_lower(
179+
ep,
180+
partitioner=[self._coreml_partitioner()],
181+
)
182+
for node in delegated_program.exported_program().graph.nodes:
183+
if node.op == "call_function":
184+
assert node.target.__name__ in [
185+
"executorch_call_delegate",
186+
"getitem",
187+
], f"Got unexpected node target after delegation: {node.target.__name__}"
188+
189+
assert (
190+
"executorch.exir.dialects.edge._ops.quant.dequantize_codebook.default"
191+
in format_delegated_graph(delegated_program.exported_program().graph_module)
192+
)
193+
194+
et_prog = delegated_program.to_executorch()
195+
self._compare_outputs(et_prog, model, example_inputs)
196+
197+
def test_dequantize_codebook_embedding(self):
198+
model, example_inputs = self._get_test_model()
199+
quantize_(
200+
model,
201+
CodebookWeightOnlyConfig(dtype=torch.uint3, block_size=[-1, 16]),
202+
lambda m, fqn: isinstance(m, torch.nn.Embedding),
203+
)
204+
ep = torch.export.export(model, example_inputs)
205+
assert "torch.ops.quant.dequantize_codebook.default" in ep.graph_module.code
206+
delegated_program = executorch.exir.to_edge_transform_and_lower(
207+
ep,
208+
partitioner=[self._coreml_partitioner()],
209+
)
210+
for node in delegated_program.exported_program().graph.nodes:
211+
if node.op == "call_function":
212+
assert node.target.__name__ in [
213+
"executorch_call_delegate",
214+
"getitem",
215+
], f"Got unexpected node target after delegation: {node.target.__name__}"
216+
217+
assert (
218+
"executorch.exir.dialects.edge._ops.quant.dequantize_codebook.default"
219+
in format_delegated_graph(delegated_program.exported_program().graph_module)
220+
)
221+
222+
et_prog = delegated_program.to_executorch()
223+
self._compare_outputs(et_prog, model, example_inputs)
224+
167225

168226
if __name__ == "__main__":
169227
test_runner = TestTorchOps()
@@ -172,3 +230,5 @@ def test_dequantize_affine_c8w_embedding_b4w_linear(self):
172230
test_runner.test_dequantize_affine_c4w_embedding()
173231
test_runner.test_dequantize_affine_c4w_linear()
174232
test_runner.test_dequantize_affine_c8w_embedding_b4w_linear()
233+
test_runner.test_dequantize_codebook_linear()
234+
test_runner.test_dequantize_codebook_embedding()

backends/arm/_passes/decompose_avg_pool2d.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,10 @@ def call_operator(self, op, args, kwargs, meta):
4545
x = args[0]
4646
kernel_h, kernel_w = args[1]
4747
kernel_size = kernel_h * kernel_w
48-
stride_h, stride_w = args[2]
48+
if len(args) > 2 and args[2] is not None:
49+
stride_h, stride_w = args[2]
50+
else:
51+
stride_h, stride_w = kernel_h, kernel_w
4952
pad_h, pad_w = new_pad_h, new_pad_w = args[3] if len(args) > 3 else (0, 0)
5053
ceil_mode = args[4] if len(args) > 4 else False
5154
count_include_pad = args[5] if len(args) > 5 else True
@@ -108,7 +111,14 @@ def call_operator(self, op, args, kwargs, meta):
108111
x = super().call_operator(cat_op, (cat_nodes, 2), kwargs, meta)
109112
new_pad_h = 0
110113

111-
avgpool_args = (x, args[1], args[2], [new_pad_h, new_pad_w], ceil_mode, False)
114+
avgpool_args = (
115+
x,
116+
args[1],
117+
[stride_h, stride_w],
118+
[new_pad_h, new_pad_w],
119+
ceil_mode,
120+
False,
121+
)
112122
x = super().call_operator(avgpool_op, avgpool_args, kwargs, meta)
113123

114124
# Multiply by factor (kernel_size / divisor_override) if divisor_override

backends/arm/_passes/decompose_grouped_conv.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from copy import copy
77

88
import torch
9-
from executorch.backends.arm.tosa_quant_utils import QuantArgs
9+
from executorch.backends.arm._passes.quant_args import QuantArgs
1010
from executorch.exir.dialects._ops import ops as exir_ops
1111
from executorch.exir.pass_base import ExportPass
1212

backends/arm/_passes/fold_qdq_with_annotated_qparams_pass.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@
1515
get_param_tensor,
1616
is_param_node,
1717
)
18-
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
1918

20-
from executorch.backends.arm.tosa_quant_utils import QuantArgs
19+
from executorch.backends.arm._passes.quant_args import QuantArgs
20+
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
2121

2222
from executorch.exir.dialects._ops import ops as exir_ops
2323
from executorch.exir.dialects.edge._ops import EdgeOpOverload

backends/arm/_passes/fuse_quantized_activation_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
# pyre-unsafe
77

88
import torch
9+
from executorch.backends.arm._passes.quant_args import QuantArgs
910
from executorch.backends.arm.constants import Q_OPS
10-
from executorch.backends.arm.tosa_quant_utils import QuantArgs
1111
from executorch.exir.dialects._ops import ops as exir_ops
1212
from executorch.exir.pass_base import ExportPass, PassResult
1313
from torch.fx import Node

backends/arm/_passes/insert_rescales_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99

1010
import torch
1111
from executorch.backends.arm._passes.arm_pass_utils import create_node
12+
from executorch.backends.arm._passes.quant_args import QuantArgs
1213
from executorch.backends.arm.constants import DQ_OPS, Q_OPS
13-
from executorch.backends.arm.tosa_quant_utils import QuantArgs
1414
from executorch.exir.pass_base import ExportPass, PassResult
1515
from torch import Tensor
1616
from torch.fx import GraphModule, Node

backends/arm/_passes/insert_table_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
import torch
1212
from executorch.backends.arm._passes.arm_pass_utils import create_node
13-
from executorch.backends.arm.tosa_quant_utils import QuantArgs
13+
from executorch.backends.arm._passes.quant_args import QuantArgs
1414
from executorch.exir import ExportedProgram
1515

1616
from executorch.exir.dialects._ops import ops as exir_ops

0 commit comments

Comments
 (0)