Skip to content

Commit eb3fa53

Browse files
committed
Merge remote-tracking branch 'origin/torchtune-update' into split-utils
2 parents 2002fce + b96ef98 commit eb3fa53

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+713
-248
lines changed

.lintrunner.toml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,13 @@ exclude_patterns = [
220220
'extension/**',
221221
'kernels/optimized/**',
222222
# Justified <functional> include.
223+
'kernels/portable/cpu/op_bitwise*.cpp',
224+
'kernels/portable/cpu/op_eq.cpp',
225+
'kernels/portable/cpu/op_ge.cpp',
226+
'kernels/portable/cpu/op_gt.cpp',
227+
'kernels/portable/cpu/op_le.cpp',
228+
'kernels/portable/cpu/op_lt.cpp',
229+
'kernels/portable/cpu/op_ne.cpp',
223230
'runtime/kernel/thread_parallel_interface.h',
224231
'scripts/**',
225232
'third-party/**',

CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,8 @@ option(EXECUTORCH_USE_DL "Use libdl library" ON)
242242

243243
option(EXECUTORCH_BUILD_CADENCE "Build the Cadence DSP backend" OFF)
244244

245+
option(EXECUTORCH_BUILD_CORTEX_M "Build the Cortex-M backend" OFF)
246+
245247
#
246248
# pthreadpool: build pthreadpool library. Disable on unsupported platforms
247249
#
@@ -717,6 +719,10 @@ if(EXECUTORCH_BUILD_XNNPACK)
717719
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/xnnpack)
718720
endif()
719721

722+
if(EXECUTORCH_BUILD_CORTEX_M)
723+
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/cortex_m)
724+
endif()
725+
720726
if(EXECUTORCH_BUILD_DEVTOOLS)
721727
if(NOT EXECUTORCH_BUILD_ARM_BAREMETAL)
722728
set(EXECUTORCH_BUILD_EXTENSION_DATA_LOADER

backends/arm/scripts/build_executorch.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ cmake \
129129
-DEXECUTORCH_BUILD_ARM_BAREMETAL=ON \
130130
-DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
131131
-DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \
132+
-DEXECUTORCH_BUILD_CORTEX_M=ON \
132133
-DEXECUTORCH_ENABLE_LOGGING=ON \
133134
${build_devtools_flags} \
134135
${build_with_etdump_flags} \

backends/arm/scripts/pre-push

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,13 @@
88
# non-interactive mode. "$#" gives the number of positional arguments.
99
[ "$#" -eq 0 ] && is_script_interactive=1 || is_script_interactive=0
1010

11-
RESET='\e[0m'
12-
RED='\e[31m'
13-
GREEN='\e[32m'
14-
YELLOW='\e[33m'
15-
BLUE='\e[34m'
11+
if [ $is_script_interactive -eq 1 ]; then
12+
RESET='\e[0m'
13+
RED='\e[31m'
14+
GREEN='\e[32m'
15+
YELLOW='\e[33m'
16+
BLUE='\e[34m'
17+
fi
1618

1719
INFO="${BLUE}[INFO]${RESET}"
1820
WARNING="${YELLOW}[WARNING]${RESET}"

backends/arm/test/test_arm_baremetal.sh

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,13 @@ test_run_ethosu_fvp() { # End to End model tests using run.sh
154154
echo "${TEST_SUITE_NAME}: Test ethos-u target Ethos-U85"
155155
examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u85-128 --model_name=add
156156
examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u85-128 --model_name=mul
157+
158+
# Cortex-M op tests
159+
examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u55-128 --model_name=qadd --bundleio
160+
examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u55-128 --model_name=qops --bundleio
161+
examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u55-128 --model_name=qops --bundleio --no_delegate --portable_kernels="aten::sub.out,aten::add.out,aten::mul.out"
162+
examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u85-128 --model_name=qops --bundleio
163+
157164
echo "${TEST_SUITE_NAME}: PASS"
158165
}
159166

backends/cadence/aot/fuse_ops.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -885,6 +885,9 @@ class FuseTransposeOrPermuteOpPairsPass(FuseOpPairsAcrossBranchesPass):
885885
"""
886886
Fuse transpose or permute op pairs to a single view op.
887887
(transpose or permutation) -> (quant or dequant) -> (transpose or permutation)
888+
This happens when op2(op1) == identity, modulo unitary dimensions.
889+
'unitary dimensions' example: a tensor of shape [1, 5, 30] is equivalent (in memory) to [5, 1, 30]
890+
so transpose(1, 2) then transpose(0, 2) is a pseudo identity and should be fused.
888891
"""
889892

890893
# A list of ops that can be bypassed when looking for a
@@ -908,7 +911,7 @@ def can_fuse_for_chain(
908911
if not super().can_fuse_for_chain(producer, consumer, consumer_op_packets):
909912
return False
910913

911-
# checking that permut2(permut1(identify)) == identity
914+
# checking that permut2(permut1(identity)) == identity, modulo unitary dimensions
912915
input_shape = cast(torch.fx.Node, producer.args[0]).meta["val"].shape
913916
ident_dims = list(range(len(input_shape)))
914917
# this mapping helps to handle both transpose and permutations
@@ -918,14 +921,20 @@ def can_fuse_for_chain(
918921
}
919922
in_dims = f[producer.target](producer, ident_dims)
920923
out_dims = f[consumer.target](consumer, in_dims)
921-
return out_dims == ident_dims
924+
# Filtering out unitary dimensions
925+
non_unit_ident_dims = [dim for dim in ident_dims if input_shape[dim] != 1]
926+
non_unit_out_dims = [dim for dim in out_dims if input_shape[dim] != 1]
927+
return non_unit_out_dims == non_unit_ident_dims
922928

923929
def get_fused_node(
924930
self,
925931
producer: torch.fx.Node,
926932
consumer: torch.fx.Node,
927933
graph_module: torch.fx.GraphModule,
928934
) -> torch.fx.Node:
935+
# This step is important because of how we can fuse transpositions that are not perfectly
936+
# reverse one of another but will be fused if there are unitary dimensions.
937+
# The fused operation must have the same output shape as the consumer.
929938
output_shape = consumer.meta["val"].shape
930939
with graph_module.graph.inserting_after(consumer):
931940
view = graph_module.graph.call_function(

backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -584,6 +584,28 @@ def _create_operator(
584584
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
585585
False,
586586
),
587+
# transpose -> quant -> transpose is not the reverse BUT there is a UNITARY dimension
588+
# so it ends up being the same on memory => fuse
589+
(
590+
True,
591+
[0, 1],
592+
True,
593+
[0, 2],
594+
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
595+
True,
596+
[5, 40, 1],
597+
),
598+
# transpose -> quant -> transpose is not the reverse, and unitary dimensions
599+
# don't help => don't fuse
600+
(
601+
True,
602+
[0, 1],
603+
True,
604+
[1, 3],
605+
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
606+
False,
607+
[5, 40, 1, 4],
608+
),
587609
# permutation -> quant -> opposite permutation => fuse
588610
(
589611
False,
@@ -622,6 +644,28 @@ def _create_operator(
622644
False,
623645
[4, 4, 4],
624646
),
647+
# permutation -> quant -> a non reverse permutation BUT there is a UNITARY dimension
648+
# so it ends up being the same on memory => fuse
649+
(
650+
False,
651+
[1, 3, 2, 0],
652+
False,
653+
[3, 2, 1, 0],
654+
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
655+
True,
656+
[3, 1, 8, 10],
657+
),
658+
# permutation -> quant -> a non reverse permutation, and unitary dimensions
659+
# don't help => don't fuse
660+
(
661+
False,
662+
[1, 3, 2, 0],
663+
False,
664+
[3, 1, 2, 0],
665+
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
666+
False,
667+
[3, 1, 8, 10],
668+
),
625669
# transpose -> quant -> transpose as a permutation => fuse
626670
(
627671
True,

backends/cortex_m/CMakeLists.txt

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# Kernel library for Cortex-M operators. Please keep this file formatted by running:
8+
# ~~~
9+
# cmake-format -i CMakeLists.txt
10+
# ~~~
11+
cmake_minimum_required(VERSION 3.19)
12+
13+
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
14+
if(NOT CMAKE_CXX_STANDARD)
15+
set(CMAKE_CXX_STANDARD 17)
16+
endif()
17+
18+
# Source root directory for executorch.
19+
if(NOT EXECUTORCH_ROOT)
20+
set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../..)
21+
endif()
22+
23+
include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
24+
include(${EXECUTORCH_ROOT}/tools/cmake/Codegen.cmake)
25+
26+
if(NOT PYTHON_EXECUTABLE)
27+
resolve_python_executable()
28+
endif()
29+
30+
# Cortex-M ops kernel sources
31+
set(_cortex_m_kernels__srcs
32+
${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantize_per_tensor.cpp
33+
${CMAKE_CURRENT_SOURCE_DIR}/ops/op_dequantize_per_tensor.cpp
34+
)
35+
36+
# Generate C++ bindings to register kernels into Executorch (for runtime).
37+
# Here select all ops in operators.yaml
38+
set(_yaml_file ${CMAKE_CURRENT_LIST_DIR}/ops/operators.yaml)
39+
gen_selected_ops(LIB_NAME "cortex_m_ops_lib" OPS_SCHEMA_YAML "${_yaml_file}")
40+
41+
# Generate bindings for the kernels
42+
generate_bindings_for_kernels(
43+
LIB_NAME "cortex_m_ops_lib" CUSTOM_OPS_YAML "${_yaml_file}"
44+
)
45+
message("Generated files ${gen_command_sources}")
46+
47+
# Build a library for _cortex_m_kernels_srcs
48+
add_library(cortex_m_kernels ${_cortex_m_kernels__srcs})
49+
target_link_libraries(cortex_m_kernels PRIVATE executorch)
50+
target_compile_options(cortex_m_kernels PUBLIC ${_common_compile_options})
51+
52+
# cortex_m_ops_lib: Register Cortex-M ops kernels into Executorch runtime
53+
gen_operators_lib(
54+
LIB_NAME "cortex_m_ops_lib" KERNEL_LIBS cortex_m_kernels DEPS executorch
55+
)
56+
57+
install(
58+
TARGETS cortex_m_kernels cortex_m_ops_lib
59+
DESTINATION lib
60+
PUBLIC_HEADER DESTINATION include/executorch/backends/cortex_m/ops/
61+
)

backends/cortex_m/ops/op_dequantize_per_tensor.cpp

Lines changed: 70 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ namespace {
2929
*/
3030
void check_dequantize_args(
3131
const Tensor& input,
32+
int64_t zero_point,
3233
int64_t quant_min,
3334
int64_t quant_max,
3435
ScalarType dtype,
@@ -39,6 +40,18 @@ void check_dequantize_args(
3940
"input.scalar_type() %" PRId8 " is not char type",
4041
static_cast<int8_t>(input.scalar_type()));
4142

43+
// Check zp range
44+
ET_CHECK_MSG(
45+
zero_point >= quant_min,
46+
"zero_point must be %" PRId64 " <= quant_min %" PRId64,
47+
zero_point,
48+
quant_min);
49+
ET_CHECK_MSG(
50+
zero_point <= quant_max,
51+
"zero_point must be %" PRId64 " >= quant_max %" PRId64,
52+
zero_point,
53+
quant_max);
54+
4255
// Check output dtype is float
4356
ET_CHECK_MSG(
4457
out.scalar_type() == ScalarType::Float,
@@ -73,18 +86,10 @@ void check_dequantize_args(
7386
/**
7487
* Scalar implementation of quantization for a single value.
7588
*/
76-
template <typename K, typename T>
77-
T dequantize_val(
78-
float scale,
79-
int32_t zero_point,
80-
K value,
81-
int64_t quant_min,
82-
int64_t quant_max) {
83-
(void)quant_min;
84-
(void)quant_max;
85-
return static_cast<T>((static_cast<int32_t>(value) - zero_point) * scale);
89+
template <typename Q, typename F>
90+
F dequantize_val(float scale, int32_t zero_point, Q qvalue) {
91+
return static_cast<F>((static_cast<int32_t>(qvalue) - zero_point) * scale);
8692
}
87-
8893
} // namespace
8994

9095
Tensor& dequantize_per_tensor_out(
@@ -106,29 +111,71 @@ Tensor& dequantize_per_tensor_out(
106111
"Failed to resize out Tensor in dequantize_per_tensor_out");
107112

108113
// Validate input parameters
109-
check_dequantize_args(input, quant_min, quant_max, dtype, out);
114+
check_dequantize_args(input, zero_point, quant_min, quant_max, dtype, out);
110115

111-
// Pre-compute inverse scale for better performance
112116
int32_t zp = static_cast<int32_t>(zero_point);
113-
int32_t qmin = static_cast<int32_t>(quant_min);
114-
int32_t qmax = static_cast<int32_t>(quant_max);
115117

116118
// Get pointers to input and output data
117119
const int8_t* input_data = input.const_data_ptr<int8_t>();
118120
float* out_data = out.mutable_data_ptr<float>();
119121
const size_t numel = input.numel();
120122

123+
size_t i = 0;
121124
#if defined(HAS_HELIUM_SIMD)
122-
// Helium MVE implementation for float32 to int8 quantization
123-
#Error "Implement MVE version!"
124-
#else
125-
// Scalar implementation for float32 to int8 quantization
126-
for (size_t i = 0; i < numel; i++) {
127-
out_data[i] =
128-
dequantize_val<int8_t, float>(scale, zp, input_data[i], qmin, qmax);
125+
// Helium MVE implementation for int8 to float quantization
126+
static uint8x16_t voffset{
127+
0x0,
128+
0x8,
129+
0x4,
130+
0xC,
131+
0x1,
132+
0x9,
133+
0x5,
134+
0xD,
135+
0x2,
136+
0xA,
137+
0x6,
138+
0xE,
139+
0x3,
140+
0xB,
141+
0x7,
142+
0xF};
143+
144+
int16x8_t vzp = vdupq_n_s16(static_cast<int16_t>(zp));
145+
float32x4_t vscale = vdupq_n_f32(static_cast<float>(scale));
146+
147+
for (; i + 15 < numel; i += 16) {
148+
int8x16_t in_084C195D2A6E3B7F =
149+
vldrbq_gather_offset_s8(input_data, voffset);
150+
151+
int16x8_t in_04152637 = vsubq_s16(vmovlbq_s8(in_084C195D2A6E3B7F), vzp);
152+
int16x8_t in_8C9DAEBF = vsubq_s16(vmovltq_s8(in_084C195D2A6E3B7F), vzp);
153+
154+
float32x4_t inf_0123 = vcvtq_f32_s32(vmovlbq_s16(in_04152637));
155+
float32x4_t inf_4567 = vcvtq_f32_s32(vmovltq_s16(in_04152637));
156+
float32x4_t inf_89AB = vcvtq_f32_s32(vmovlbq_s16(in_8C9DAEBF));
157+
float32x4_t inf_CDEF = vcvtq_f32_s32(vmovltq_s16(in_8C9DAEBF));
158+
159+
float32x4_t out_0123 = vmulq_f32(inf_0123, vscale);
160+
float32x4_t out_4567 = vmulq_f32(inf_4567, vscale);
161+
float32x4_t out_89AB = vmulq_f32(inf_89AB, vscale);
162+
float32x4_t out_CDEF = vmulq_f32(inf_CDEF, vscale);
163+
164+
vstrwq_f32(out_data + 0, out_0123);
165+
vstrwq_f32(out_data + 4, out_4567);
166+
vstrwq_f32(out_data + 8, out_89AB);
167+
vstrwq_f32(out_data + 12, out_CDEF);
168+
169+
input_data += 16;
170+
out_data += 16;
129171
}
130-
#endif
172+
#endif // defined(HAS_HELIUM_SIMD)
131173

174+
for (; i < numel; i++) {
175+
*out_data = dequantize_val<int8_t, float>(scale, zp, *input_data);
176+
*input_data++;
177+
*out_data++;
178+
}
132179
return out;
133180
}
134181

0 commit comments

Comments
 (0)