Skip to content

Commit 9de4c16

Browse files
psiddh, Github Executorch, and shoumikhin
authored
Summary: Minor cleanup post quantized_add op (#13824)
Test Plan: examples/arm/run.sh --et_build_root=arm_test/test_run --target=ethos-u55-128 --model_name=qadd2 --no_delegate Reviewers: Subscribers: Tasks: Tags: ### Summary Minor cleanup following the quantized_add op change: move the CMSIS-NN `extern "C"` include out of the common header into op_quantized_add.cpp, factor the repeated Scalar-to-int conversions into `extractScalarToInt32`/`extractScalarToInt` helpers, and gate the QuantizedOpFusionPass behind a new `--enable_qdq_fusion_pass` compiler flag (exposed in run.sh as `--qdq_fusion_op`). ### Test plan Ran examples/arm/run.sh with the command above. --------- Co-authored-by: Github Executorch <[email protected]> Co-authored-by: Anthony Shoumikhin <[email protected]>
1 parent b660c2e commit 9de4c16

File tree

4 files changed

+52
-25
lines changed

4 files changed

+52
-25
lines changed

backends/cortex_m/ops/cortex_m_ops_common.h

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,6 @@
1717
#include <executorch/runtime/kernel/kernel_includes.h>
1818
#include <executorch/runtime/platform/assert.h>
1919

20-
// Include CMSIS-NN headers with C linkage
21-
extern "C" {
22-
#include "arm_nnfunctions.h"
23-
}
24-
2520
using Tensor = torch::executor::Tensor;
2621
using ScalarType = executorch::aten::ScalarType;
2722
using Scalar = torch::executor::Scalar;
@@ -139,3 +134,19 @@ inline Error resize_to_broadcast_target_size(
139134
return executorch::runtime::resize_tensor(
140135
output, {expected_output_size, expected_output_dim});
141136
}
137+
138+
/**
139+
* Convert Scalar to CMSIS-NN int32 format
140+
* For multipliers, zero_points, etc. from quantize_multiplier_aot
141+
*/
142+
inline int32_t extractScalarToInt32(const Scalar& scalar_value) {
143+
return static_cast<int32_t>(scalar_value.to<int64_t>());
144+
}
145+
146+
/**
147+
* Convert Scalar to CMSIS-NN int format
148+
* For shift values from quantize_multiplier_aot
149+
*/
150+
inline int extractScalarToInt(const Scalar& scalar_value) {
151+
return static_cast<int>(scalar_value.to<int64_t>());
152+
}

backends/cortex_m/ops/op_quantized_add.cpp

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@
88

99
#include "cortex_m_ops_common.h"
1010

11+
// Include CMSIS-NN headers with C linkage
12+
extern "C" {
13+
#include "arm_nnfunctions.h"
14+
}
15+
1116
namespace cortex_m {
1217
namespace native {
1318
using KernelRuntimeContext = torch::executor::KernelRuntimeContext;
@@ -54,19 +59,15 @@ Tensor& quantized_add_out(
5459
"quantized_add_out: input1_int8.sizes() = %zu",
5560
input1_int8.sizes().size());
5661

57-
// FIX: Use template types that ExecutorTorch definitely provides
58-
// Use to<int64_t>() and to<double>() which are commonly instantiated
59-
int32_t zp1 = static_cast<int32_t>(input1_zero_point.to<int64_t>());
60-
int32_t input1_mult = static_cast<int32_t>(input1_multiplier.to<int64_t>());
61-
int input1_shift_val = static_cast<int>(input1_shift.to<int64_t>());
62-
63-
int32_t zp2 = static_cast<int32_t>(input2_zero_point.to<int64_t>());
64-
int32_t input2_mult = static_cast<int32_t>(input2_multiplier.to<int64_t>());
65-
int input2_shift_val = static_cast<int>(input2_shift.to<int64_t>());
66-
67-
int32_t out_zp = static_cast<int32_t>(output_zero_point.to<int64_t>());
68-
int32_t output_mult = static_cast<int32_t>(output_multiplier.to<int64_t>());
69-
int output_shift_val = static_cast<int>(output_shift.to<int64_t>());
62+
int32_t zp1 = extractScalarToInt32(input1_zero_point);
63+
int32_t input1_mult = extractScalarToInt32(input1_multiplier);
64+
int input1_shift_val = extractScalarToInt(input1_shift);
65+
int32_t zp2 = extractScalarToInt32(input2_zero_point);
66+
int32_t input2_mult = extractScalarToInt32(input2_multiplier);
67+
int input2_shift_val = extractScalarToInt(input2_shift);
68+
int32_t out_zp = extractScalarToInt32(output_zero_point);
69+
int32_t output_mult = extractScalarToInt32(output_multiplier);
70+
int output_shift_val = extractScalarToInt(output_shift);
7071

7172
// Left shift to maximize precision (tune as needed)
7273
const int32_t left_shift = 20;

examples/arm/aot_arm_compiler.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,12 @@ def get_args():
600600
action="store_false",
601601
help="Disable strict checking while exporting models.",
602602
)
603+
parser.add_argument(
604+
"--enable_qdq_fusion_pass",
605+
action="store",
606+
default=False,
607+
help="Enable the QuantizedOpFusionPass fusion step (default: False)",
608+
)
603609
args = parser.parse_args()
604610

605611
if args.evaluate and (
@@ -791,14 +797,20 @@ def to_edge_no_delegate(exported_program, args, model: torch.nn.Module, example_
791797
return model_int8, edge
792798

793799

794-
def transform_for_cortex_m_backend(edge):
800+
def transform_for_cortex_m_backend(edge, args):
795801
# Let's make sure we are using optimized Cortex M backend
796802
# NB: If we can't find and replace ops those are expected to be replaced,
797803
# bad things will happen at runtime, like "missing operator" errors!
798-
# Instantiate the pass
799-
replace_quant_pass = ReplaceQuantNodesPass()
800-
quantized_op_fusion_pass = QuantizedOpFusionPass()
801-
edge = edge.transform([replace_quant_pass, quantized_op_fusion_pass])
804+
805+
# Instantiate the mandatory ReplaceQuantNodesPass
806+
passes = [ReplaceQuantNodesPass()]
807+
808+
# Conditionally add the QuantizedOpFusionPass
809+
if args.enable_qdq_fusion_pass.lower() == "true":
810+
passes.append(QuantizedOpFusionPass())
811+
812+
# Apply the passes
813+
edge = edge.transform(passes)
802814

803815
return edge
804816

@@ -835,7 +847,7 @@ def transform_for_cortex_m_backend(edge):
835847
)
836848

837849
# Transform so we can use ops from the Cortex M backend
838-
edge = transform_for_cortex_m_backend(edge)
850+
edge = transform_for_cortex_m_backend(edge, args)
839851

840852
dump_delegation_info(edge, args.intermediates)
841853

examples/arm/run.sh

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ ethos_u_scratch_dir=${script_dir}/ethos-u-scratch
4040
scratch_dir_set=false
4141
toolchain=arm-none-eabi-gcc
4242
select_ops_list="aten::_softmax.out"
43+
qdq_fusion_op=false
4344

4445
function help() {
4546
echo "Usage: $(basename $0) [options]"
@@ -69,6 +70,7 @@ function help() {
6970
echo " --pte_placement=<elf|ADDR> Ethos-U: Control if runtime has PTE baked into the elf or if its placed in memory outside of the elf, defaults to ${pte_placement}"
7071
echo " --et_build_root=<FOLDER> Executorch build output root folder to use, defaults to ${et_build_root}"
7172
echo "  --scratch-dir=<FOLDER>                 Path to your Ethos-U scratch dir if you are not using default          ${ethos_u_scratch_dir}"
73+
echo " --qdq_fusion_op=<true/false> Enable/Disable QDQ fusion op"
7274
exit 0
7375
}
7476

@@ -96,6 +98,7 @@ for arg in "$@"; do
9698
--pte_placement=*) pte_placement="${arg#*=}";;
9799
--et_build_root=*) et_build_root="${arg#*=}";;
98100
--scratch-dir=*) ethos_u_scratch_dir="${arg#*=}" ; scratch_dir_set=true ;;
101+
--qdq_fusion_op=*) qdq_fusion_op="${arg#*=}";;
99102
*)
100103
;;
101104
esac
@@ -275,7 +278,7 @@ for i in "${!test_model[@]}"; do
275278
model_compiler_flags="${model_compiler_flags} --model_input=${model_input}"
276279
fi
277280

278-
ARM_AOT_CMD="python3 -m examples.arm.aot_arm_compiler --model_name=${model} --target=${target} ${model_compiler_flags} --intermediate=${output_folder} --output=${pte_file} --system_config=${system_config} --memory_mode=${memory_mode} $bundleio_flag ${etrecord_flag} --config=${config}"
281+
ARM_AOT_CMD="python3 -m examples.arm.aot_arm_compiler --model_name=${model} --target=${target} ${model_compiler_flags} --intermediate=${output_folder} --output=${pte_file} --system_config=${system_config} --memory_mode=${memory_mode} $bundleio_flag ${etrecord_flag} --config=${config} --enable_qdq_fusion_pass=${qdq_fusion_op}"
279282
echo "CALL ${ARM_AOT_CMD}" >&2
280283
${ARM_AOT_CMD} 1>&2
281284

0 commit comments

Comments
 (0)