
Commit 362f3f7

sidart authored and Github Executorch committed
Summary: Initial CMSIS-NN integration for Quantized Add Op
Test Plan:
a) Set up the Arm FVP and run 'examples/arm/run.sh' (check that there are no regressions in the e2e test scenarios).
b) Add another run.sh iteration with qadd and only the --quantize flag, and verify that the quantized add op is called.
c) cd backends/cortex_m/test/; python test_quantize_add_fusion_pass.py
----------------------------------------------------------------------
Ran 8 tests in 11.128s

OK

Reviewers:
Subscribers:
Tasks:
Tags:
1 parent fc87462 commit 362f3f7

16 files changed: +1145 −107 lines changed

backends/cortex_m/CMakeLists.txt

Lines changed: 45 additions & 1 deletion
@@ -24,11 +24,41 @@ endif()
 
  include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
  include(${EXECUTORCH_ROOT}/tools/cmake/Codegen.cmake)
+ include(ExternalProject)
+
+ # Download and build CMSIS-NN from GitHub
+ set(CMSIS_NN_VERSION
+     "v4.1.0"
+     CACHE STRING "CMSIS-NN version to download"
+ )
+ set(CMSIS_NN_ROOT ${CMAKE_CURRENT_BINARY_DIR}/cmsis-nn)
+ set(CMSIS_NN_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR}/cmsis-nn-build)
+ set(CMSIS_NN_LIB_PATH ${CMSIS_NN_BINARY_DIR}/libcmsis-nn.a)
+
+ set(TARGET_CPU
+     "cortex-m55"
+     CACHE STRING "Target CPU for CMSIS-NN build"
+ )
+ ExternalProject_Add(
+   cmsis_nn_external
+   GIT_REPOSITORY https://github.com/ARM-software/CMSIS-NN.git
+   GIT_TAG ${CMSIS_NN_VERSION}
+   SOURCE_DIR ${CMSIS_NN_ROOT}
+   BINARY_DIR ${CMSIS_NN_BINARY_DIR}
+   CMAKE_ARGS
+     -DCMAKE_TOOLCHAIN_FILE=${EXECUTORCH_ROOT}/examples/arm/ethos-u-setup/arm-none-eabi-gcc.cmake
+     -DTARGET_CPU=${TARGET_CPU}
+     -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE}
+   BUILD_COMMAND ${CMAKE_COMMAND} --build <BINARY_DIR> --parallel
+   INSTALL_COMMAND ""
+   BUILD_BYPRODUCTS ${CMSIS_NN_LIB_PATH}
+ )
 
  # Cortex-M ops kernel sources
  set(_cortex_m_kernels__srcs
      ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantize_per_tensor.cpp
      ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_dequantize_per_tensor.cpp
+     ${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantized_add.cpp
  )
 
  # Generate C++ bindings to register kernels into Executorch (for runtime). Here
@@ -44,9 +74,23 @@ message("Generated files ${gen_command_sources}")
 
  # Build a library for _cortex_m_kernels_srcs
  add_library(cortex_m_kernels ${_cortex_m_kernels__srcs})
- target_link_libraries(cortex_m_kernels PRIVATE executorch)
  target_compile_options(cortex_m_kernels PUBLIC ${_common_compile_options})
 
+ # Add dependency on CMSIS-NN external project
+ add_dependencies(cortex_m_kernels cmsis_nn_external)
+
+ # Set include directories - Include is directly in CMSIS-NN root
+ target_include_directories(
+   cortex_m_kernels
+   PRIVATE ${EXECUTORCH_ROOT}/..
+           ${EXECUTORCH_ROOT}/runtime/core/portable_type/c10
+           $<BUILD_INTERFACE:${CMSIS_NN_ROOT}/Include>
+           $<BUILD_INTERFACE:${CMSIS_NN_ROOT}>
+ )
+
+ # Link against the CMSIS-NN static library directly
+ target_link_libraries(cortex_m_kernels PUBLIC ${CMSIS_NN_LIB_PATH} executorch)
+
  # cortex_m_ops_lib: Register Cortex-M ops kernels into Executorch runtime
  gen_operators_lib(
    LIB_NAME "cortex_m_ops_lib" KERNEL_LIBS cortex_m_kernels DEPS executorch

backends/cortex_m/ops/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@ python_library(
      ],
      deps = [
          "fbcode//caffe2:torch",
+         "//executorch/backends/cortex_m/passes:passes_utils",
      ],
  )

backends/cortex_m/ops/cortex_m_ops_common.h

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+ /*
+  * Copyright (c) Meta Platforms, Inc. and affiliates.
+  * All rights reserved.
+  *
+  * This source code is licensed under the BSD-style license found in the
+  * LICENSE file in the root directory of this source tree.
+  */
+
+ #pragma once
+
+ #include <executorch/kernels/portable/cpu/util/broadcast_util.h>
+ #include <executorch/runtime/kernel/kernel_includes.h>
+
+ // Include CMSIS-NN headers with C linkage
+ extern "C" {
+ #include "arm_nnfunctions.h"
+ }
+
+ using Tensor = torch::executor::Tensor;
+ using ScalarType = executorch::aten::ScalarType;
+ using Scalar = torch::executor::Scalar;
+ using KernelRuntimeContext = torch::executor::KernelRuntimeContext;
+ using Error = executorch::runtime::Error;
+
+ inline void validate_quantized_inputs(
+     KernelRuntimeContext& context,
+     const Tensor& input1,
+     const Tensor& input2,
+     Tensor& output) {
+   ET_CHECK_MSG(input1.scalar_type() == ScalarType::Char, "Input1 must be int8");
+   ET_CHECK_MSG(input2.scalar_type() == ScalarType::Char, "Input2 must be int8");
+   ET_CHECK_MSG(output.scalar_type() == ScalarType::Char, "Output must be int8");
+   ET_CHECK_MSG(
+       input1.sizes() == input2.sizes(), "Input tensors must be the same shape");
+   ET_CHECK_MSG(
+       input1.scalar_type() == input2.scalar_type(),
+       "Input tensors must be the same dtype");
+   ET_CHECK_MSG(
+       (torch::executor::resize_to_broadcast_target_size(
+            input1, input2, output) == Error::Ok),
+       "Broadcast error: resize_to_broadcast_target_size failed");
+   return;
+ }
backends/cortex_m/ops/op_quantized_add.cpp

Lines changed: 137 additions & 0 deletions
@@ -0,0 +1,137 @@
+ /*
+  * Copyright (c) Meta Platforms, Inc. and affiliates.
+  * All rights reserved.
+  *
+  * This source code is licensed under the BSD-style license found in the
+  * LICENSE file in the root directory of this source tree.
+  */
+
+ #include "cortex_m_ops_common.h"
+
+ namespace cortex_m {
+ namespace native {
+
+ Tensor& quantized_add_out(
+     KernelRuntimeContext& context,
+     const Tensor& input1_int8,
+     const Scalar& input1_zero_point,
+     const Scalar& input1_multiplier,
+     const Scalar& input1_shift,
+     const Tensor& input2_int8,
+     const Scalar& input2_zero_point,
+     const Scalar& input2_multiplier,
+     const Scalar& input2_shift,
+     const Scalar& output_zero_point,
+     const Scalar& output_multiplier,
+     const Scalar& output_shift,
+     Tensor& out) {
+   validate_quantized_inputs(context, input1_int8, input2_int8, out);
+
+   ET_LOG(
+       Info,
+       "quantized_add_out: input1_int8.sizes() = %zu",
+       input1_int8.sizes().size());
+
+   // Use to<int64_t>(), which ExecuTorch commonly instantiates, and narrow
+   // to the integer widths that CMSIS-NN expects.
+   int32_t zp1 = static_cast<int32_t>(input1_zero_point.to<int64_t>());
+   int32_t input1_mult = static_cast<int32_t>(input1_multiplier.to<int64_t>());
+   int input1_shift_val = static_cast<int>(input1_shift.to<int64_t>());
+
+   int32_t zp2 = static_cast<int32_t>(input2_zero_point.to<int64_t>());
+   int32_t input2_mult = static_cast<int32_t>(input2_multiplier.to<int64_t>());
+   int input2_shift_val = static_cast<int>(input2_shift.to<int64_t>());
+
+   int32_t out_zp = static_cast<int32_t>(output_zero_point.to<int64_t>());
+   int32_t output_mult = static_cast<int32_t>(output_multiplier.to<int64_t>());
+   int output_shift_val = static_cast<int>(output_shift.to<int64_t>());
+
+   // Left shift to maximize precision (tune as needed)
+   const int32_t left_shift = 20;
+   const int32_t activation_min = std::numeric_limits<int8_t>::min();
+   const int32_t activation_max = std::numeric_limits<int8_t>::max();
+
+   // Resize output tensor to match input shape
+   auto err = torch::executor::resize_tensor(out, input1_int8.sizes());
+   if (err != executorch::runtime::Error::Ok) {
+     ET_LOG(
+         Error,
+         "quantized_add_out: resize_tensor failed with error code [%d]",
+         static_cast<int>(err));
+     std::memset(out.mutable_data_ptr<int8_t>(), 0, out.nbytes());
+     return out;
+   }
+
+   ET_LOG(
+       Info,
+       "Using AoT-computed parameters: input1[mult=%d, shift=%d], input2[mult=%d, shift=%d], output[mult=%d, shift=%d]",
+       input1_mult,
+       input1_shift_val,
+       input2_mult,
+       input2_shift_val,
+       output_mult,
+       output_shift_val);
+
+   // Call CMSIS-NN kernel with precomputed parameters
+   arm_cmsis_nn_status status = arm_elementwise_add_s8(
+       input1_int8.const_data_ptr<int8_t>(),
+       input2_int8.const_data_ptr<int8_t>(),
+       static_cast<int32_t>(zp1),
+       input1_mult,
+       input1_shift_val,
+       static_cast<int32_t>(zp2),
+       input2_mult,
+       input2_shift_val,
+       left_shift,
+       out.mutable_data_ptr<int8_t>(),
+       static_cast<int32_t>(out_zp),
+       output_mult,
+       output_shift_val,
+       static_cast<int32_t>(out.numel()),
+       activation_min,
+       activation_max);
+
+   if (status != ARM_CMSIS_NN_SUCCESS) {
+     ET_LOG(
+         Error,
+         "quantized_add_out: arm_elementwise_add_s8 failed with status [%d]",
+         status);
+     std::memset(out.mutable_data_ptr<int8_t>(), 0, out.nbytes());
+   } else {
+     ET_LOG(
+         Info,
+         "quantized_add_out: Successfully completed with AoT-computed parameters!");
+   }
+
+   return out;
+ }
+
+ // Stub implementation: non-out (functional) variant for compatibility.
+ // EXIR/ExecuTorch runs an out-variant pass that converts .default operations
+ // to .out variants before memory planning. The fusion pass calls quantized_add's
+ // default variant, but ExecuTorch's kernel dispatch mechanism ends up calling
+ // the out variant; this stub only exists so the compiler doesn't complain.
+ Tensor quantized_add(
+     KernelRuntimeContext& context,
+     const Tensor& input1_int8,
+     const Scalar& input1_zero_point,
+     const Scalar& input1_multiplier,
+     const Scalar& input1_shift,
+     const Tensor& input2_int8,
+     const Scalar& input2_zero_point,
+     const Scalar& input2_multiplier,
+     const Scalar& input2_shift,
+     const Scalar& output_zero_point,
+     const Scalar& output_multiplier,
+     const Scalar& output_shift) {
+   ET_LOG(
+       Info,
+       "quantized_add: input1_int8.sizes() = %zu",
+       input1_int8.sizes().size());
+
+   // Crash on debug builds if invoked
+   assert(false);
+   // This is to make sure the compiler doesn't complain.
+   return const_cast<Tensor&>(input1_int8);
+ }
+
+ } // namespace native
+ } // namespace cortex_m
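
For readers unfamiliar with the multiplier/shift arguments above: the AoT side decomposes each real-valued quantization scale into an int32 fixed-point multiplier plus a power-of-two shift, and the kernel applies them with integer-only arithmetic. Below is a rough Python sketch of that convention (scale ~= multiplier * 2**shift / 2**31). The helper names quantize_multiplier and apply_multiplier are illustrative only, not part of this commit, and the rounding is simplified relative to CMSIS-NN's exact behavior.

import math

def quantize_multiplier(real_scale: float) -> tuple[int, int]:
    """Decompose a positive real scale into (multiplier, shift) such that
    real_scale ~= multiplier * 2**shift / 2**31, with multiplier in [2**30, 2**31)."""
    if real_scale == 0.0:
        return 0, 0
    significand, shift = math.frexp(real_scale)  # real_scale = significand * 2**shift, significand in [0.5, 1)
    multiplier = round(significand * (1 << 31))
    if multiplier == (1 << 31):  # rounding pushed the significand up to 1.0
        multiplier //= 2
        shift += 1
    return multiplier, shift

def apply_multiplier(x: int, multiplier: int, shift: int) -> int:
    """Rescale an integer by the represented real scale, rounding to nearest.
    Simplified sketch: assumes shift < 31 and rounds negative values toward -inf."""
    product = x * multiplier           # would be a 64-bit intermediate in C
    total_shift = 31 - shift
    rounding = 1 << (total_shift - 1)
    return (product + rounding) >> total_shift

# Example: encode a scale of 0.0037 and rescale an int32 accumulator with it.
mult, sh = quantize_multiplier(0.0037)
print(mult, sh, apply_multiplier(12345, mult, sh), round(12345 * 0.0037))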

backends/cortex_m/ops/operators.py

Lines changed: 110 additions & 1 deletion
@@ -7,9 +7,14 @@
  import torch
  from executorch.exir.dialects._ops import (
      ops as exir_ops,
- )  # To provide the implementation of the operators
+ )
+ # To provide the implementation of the operators
  from torch.library import impl, Library, register_fake
 
+ from executorch.backends.cortex_m.passes.passes_utils import (
+     dequantize_tensor, quantize_tensor,
+ )
+
  # New operator library with a custom namespace to allow fusion etc.
  lib = Library("cortex_m", "DEF")

@@ -96,3 +101,107 @@ def dequantize_per_tensor_impl(
      return exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default(
          input, scale, zero_point, quant_min, quant_max, dtype
      )
+
+
+ # Define the operator schema with multipliers and shifts (11 args)
+ lib.define(
+     "quantized_add("
+     "Tensor self, Scalar self_zero_point, Scalar self_multiplier, Scalar self_shift, "
+     "Tensor other, Scalar other_zero_point, Scalar other_multiplier, Scalar other_shift, "
+     "Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift) -> Tensor"
+ )
+
+
+ @register_fake("cortex_m::quantized_add")
+ def quantized_add_meta(
+     self: torch.Tensor,
+     self_zero_point: int,
+     self_multiplier: int,
+     self_shift: int,
+     other: torch.Tensor,
+     other_zero_point: int,
+     other_multiplier: int,
+     other_shift: int,
+     output_zero_point: int,
+     output_multiplier: int,
+     output_shift: int,
+ ) -> torch.Tensor:
+     return torch.empty_like(self, dtype=torch.int8)
+
+
+ @impl(lib, "quantized_add", "CompositeExplicitAutograd")
+ def quantized_add_impl(
+     self: torch.Tensor,
+     self_zero_point: int,
+     self_multiplier: int,
+     self_shift: int,
+     other: torch.Tensor,
+     other_zero_point: int,
+     other_multiplier: int,
+     other_shift: int,
+     output_zero_point: int,
+     output_multiplier: int,
+     output_shift: int,
+ ) -> torch.Tensor:
+     self_fp = dequantize_tensor(self, self_zero_point, self_multiplier, self_shift)
+     other_fp = dequantize_tensor(other, other_zero_point, other_multiplier, other_shift)
+     result_fp = self_fp + other_fp
+     result_quantized = quantize_tensor(result_fp, output_zero_point, output_multiplier, output_shift)
+     return result_quantized
+
+
+ # Define the operator schema with multipliers and shifts (11 args + out tensor)
+ lib.define(
+     "quantized_add.out("
+     "Tensor self, Scalar self_zero_point, Scalar self_multiplier, Scalar self_shift, "
+     "Tensor other, Scalar other_zero_point, Scalar other_multiplier, Scalar other_shift, "
+     "Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift, "
+     "*, Tensor(a!) out) -> Tensor(a!)"
+ )
+
+
+ # Fake meta function for shape and dtype inference during compilation
+ @register_fake("cortex_m::quantized_add.out")
+ def quantized_add_out_meta(
+     self: torch.Tensor,
+     self_zero_point: int,
+     self_multiplier: int,
+     self_shift: int,
+     other: torch.Tensor,
+     other_zero_point: int,
+     other_multiplier: int,
+     other_shift: int,
+     output_zero_point: int,
+     output_multiplier: int,
+     output_shift: int,
+     out: torch.Tensor,
+ ) -> torch.Tensor:
+     # Validate shape compatibility if needed; output dtype is int8
+     assert out.shape == self.shape, "Output shape must match input shape"
+     return out
+
+
+ # Actual implementation delegating to backend or custom kernel
+ @impl(lib, "quantized_add.out", "CompositeExplicitAutograd")
+ def quantized_add_out_impl(
+     self: torch.Tensor,
+     self_zero_point: int,
+     self_multiplier: int,
+     self_shift: int,
+     other: torch.Tensor,
+     other_zero_point: int,
+     other_multiplier: int,
+     other_shift: int,
+     output_zero_point: int,
+     output_multiplier: int,
+     output_shift: int,
+     *,
+     out: torch.Tensor,
+ ) -> torch.Tensor:
+     self_fp = dequantize_tensor(self, self_zero_point, self_multiplier, self_shift)
+     other_fp = dequantize_tensor(other, other_zero_point, other_multiplier, other_shift)
+     result_fp = self_fp + other_fp
+     result_quantized = quantize_tensor(result_fp, output_zero_point, output_multiplier, output_shift)
+
+     # Write into the provided output tensor
+     out.copy_(result_quantized)
+
+     return out
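
The eager implementations above rely on dequantize_tensor and quantize_tensor from backends/cortex_m/passes/passes_utils, which are not part of this diff. As a rough sketch of what such helpers could look like under the same assumed multiplier/shift convention (scale ~= multiplier * 2**shift / 2**31) — an assumption, since the actual passes_utils code is not shown here:

import torch

def _scale_from(multiplier: int, shift: int) -> float:
    # Assumed convention: scale ~= multiplier * 2**shift / 2**31
    return multiplier * (2.0 ** shift) / (1 << 31)

def dequantize_tensor(t: torch.Tensor, zero_point: int, multiplier: int, shift: int) -> torch.Tensor:
    # int8 -> float: subtract the zero point, then apply the reconstructed scale
    return (t.to(torch.float32) - zero_point) * _scale_from(multiplier, shift)

def quantize_tensor(t: torch.Tensor, zero_point: int, multiplier: int, shift: int) -> torch.Tensor:
    # float -> int8: divide by the scale, add the zero point, round and clamp
    q = torch.round(t / _scale_from(multiplier, shift)) + zero_point
    return q.clamp(-128, 127).to(torch.int8)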

backends/cortex_m/ops/operators.yaml

Lines changed: 12 additions & 0 deletions
@@ -15,3 +15,15 @@
    kernels:
      - arg_meta: null
        kernel_name: cortex_m::dequantize_per_tensor_out
+
+ - func: cortex_m::quantized_add(Tensor self, Scalar self_zero_point, Scalar self_multiplier, Scalar self_shift, Tensor other, Scalar other_zero_point, Scalar other_multiplier, Scalar other_shift, Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift) -> Tensor
+   variants: function
+   kernels:
+     - arg_meta: null
+       kernel_name: cortex_m::quantized_add
+
+ - func: cortex_m::quantized_add.out(Tensor self, Scalar self_zero_point, Scalar self_multiplier, Scalar self_shift, Tensor other, Scalar other_zero_point, Scalar other_multiplier, Scalar other_shift, Scalar output_zero_point, Scalar output_multiplier, Scalar output_shift, *, Tensor(a!) out) -> Tensor(a!)
+   variants: function
+   kernels:
+     - arg_meta: null
+       kernel_name: cortex_m::quantized_add_out
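
Once backends/cortex_m/ops/operators.py has been imported, the new op is also callable eagerly through torch.ops for quick sanity checks, since it is registered in the "cortex_m" library with a CompositeExplicitAutograd implementation. A minimal sketch is below; the zero points, multipliers, and shifts are made-up illustrative values (corresponding to a scale of 1.0 under the assumed convention), whereas in practice the AoT fusion pass computes them, and the call assumes the passes_utils helpers are importable in the environment.

import torch
import executorch.backends.cortex_m.ops.operators  # registers the cortex_m op library

x = torch.randint(-128, 127, (1, 8), dtype=torch.int8)
y = torch.randint(-128, 127, (1, 8), dtype=torch.int8)

# Illustrative quantization parameters only; multiplier=1<<30 with shift=1
# encodes a scale of 1.0 under the assumed multiplier/shift convention.
out = torch.ops.cortex_m.quantized_add(
    x, 0, 1 << 30, 1,   # self: zero_point, multiplier, shift
    y, 0, 1 << 30, 1,   # other: zero_point, multiplier, shift
    0, 1 << 30, 1,      # output: zero_point, multiplier, shift
)
print(out.dtype, out.shape)  # torch.int8, same shape as the inputs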
