Commit a523a5b

Author: sidart (committed)
Summary: Initial CMSIS-NN custom kernels port (Take #2)
Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent: b440e82

File tree: 10 files changed, +341 / -20 lines


CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -530,6 +530,7 @@ endif()
 
 if(EXECUTORCH_BUILD_CORTEX_M)
   add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/cortex_m)
+  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/backends/cortex_m/cmsis-nn/ops)
 endif()
 
 if(EXECUTORCH_BUILD_DEVTOOLS)
New Python file: cortex_m custom operator registration

Lines changed: 48 additions & 0 deletions

@@ -0,0 +1,48 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.dialects._ops import (
    ops as exir_ops,
)  # To provide the implementation of the operators
from torch.library import impl, Library, register_fake

# New operator library with a custom namespace to allow fusion etc.
lib = Library("cortex_m", "DEF")

###
# add.Tensor
###

lib.define("aten_add_tensor(Tensor self, Tensor other, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)")

@impl(lib, "aten_add_tensor", "CompositeExplicitAutograd")
def aten_add_tensor_impl(input1, input2, dtype, out):
    return exir_ops.edge.cortex_m.aten_add_tensor.default(input1, input2, dtype, dtype)


###
# add.out
###

lib.define(
    "add.out(Tensor input1, Tensor input2, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)"
)

@impl(lib, "add.out", "CompositeExplicitAutograd")
def add_out_impl(
    input1: torch.Tensor,
    input2: torch.Tensor,
    dtype: torch.dtype,
    out: torch.Tensor,
) -> torch.Tensor:
    """
    The implementation of cmsis-nn add.out.
    """
    return exir_ops.edge.cortex_m.add.default(
        input1, input2, dtype, dtype
    )
backends/cortex_m/cmsis-nn/cmsis.yaml (new file)

Lines changed: 15 additions & 0 deletions

@@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

- op: aten::add.out
  kernels:
    - arg_meta: null
      kernel_name: cortex_m::aten_add_tensor

- op: aten::_softmax.out
  kernels:
    - arg_meta: null
      kernel_name: cortex_m::aten_softmax
backends/cortex_m/cmsis-nn/ops/CMakeLists.txt (new file)

Lines changed: 87 additions & 0 deletions

@@ -0,0 +1,87 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

cmake_minimum_required(VERSION 3.19)

set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
if(NOT CMAKE_CXX_STANDARD)
  set(CMAKE_CXX_STANDARD 17)
endif()
set(CMAKE_VERBOSE_MAKEFILE ON)

# Source root directory for executorch.
if(NOT EXECUTORCH_ROOT)
  set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../../../)
endif()

include(${EXECUTORCH_ROOT}/tools/cmake/Utils.cmake)
include(${EXECUTORCH_ROOT}/tools/cmake/Codegen.cmake)

set(EXECUTORCH_ENABLE_LOGGING ON CACHE BOOL "Enable ExecuTorch logging")
set(EXECUTORCH_LOG_LEVEL "DEBUG" CACHE STRING "ExecuTorch log level")

# Path to CMSIS-NN root - adjust as needed
set(CMSIS_NN_ROOT /home/sidart/working/CMSIS-NN)

# Cortex-M CMSIS ops sources
set(_cortex_m_kernels_cmsis__srcs
  "${EXECUTORCH_ROOT}/backends/cortex_m/cmsis-nn/ops/op_aten_add_tensor.cpp"
  "${EXECUTORCH_ROOT}/backends/cortex_m/cmsis-nn/ops/op_aten_softmax.cpp"
)

# Common include directories
set(_common_include_directories
  ${EXECUTORCH_ROOT}/..
  ${EXECUTORCH_ROOT}/runtime/core/portable_type/c10
  ${CMSIS_NN_ROOT}/Include
  ${CMSIS_NN_ROOT} # For any CMake or config includes
)

# Import CMSIS-NN static library as a target
add_library(cmsis_nn STATIC IMPORTED)
set_target_properties(cmsis_nn PROPERTIES
  IMPORTED_LOCATION "${CMSIS_NN_ROOT}/build/libcmsis-nn.a"
  INTERFACE_INCLUDE_DIRECTORIES "${CMSIS_NN_ROOT}/Include"
)

# Build cortex_m_cmsis_kernels static library
add_library(cortex_m_cmsis_kernels ${_cortex_m_kernels_cmsis__srcs})

# Include directories for cortex_m_cmsis_kernels
target_include_directories(cortex_m_cmsis_kernels
  PRIVATE
    ${_common_include_directories}
)

# Link libraries: executorch and CMSIS-NN imported target
target_link_libraries(cortex_m_cmsis_kernels
  PRIVATE
    cmsis_nn
    executorch
)

# Generate C++ bindings for kernels and operators
gen_selected_ops(
  LIB_NAME "cortex_m_cmsis_nn_ops_lib" OPS_SCHEMA_YAML
  "${CMAKE_CURRENT_LIST_DIR}/../cmsis.yaml" "" ""
)
generate_bindings_for_kernels(
  LIB_NAME "cortex_m_cmsis_nn_ops_lib" FUNCTIONS_YAML
  ${CMAKE_CURRENT_SOURCE_DIR}/../cmsis.yaml
)

gen_operators_lib(
  LIB_NAME "cortex_m_cmsis_nn_ops_lib" KERNEL_LIBS cortex_m_cmsis_kernels DEPS executorch
)
set(CMAKE_EXE_LINKER_FLAGS "-Wl,--gc-sections")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -ffunction-sections -fdata-sections")

# Install targets and headers
install(
  TARGETS cortex_m_cmsis_kernels cortex_m_cmsis_nn_ops_lib
  DESTINATION lib
  PUBLIC_HEADER DESTINATION include/executorch/backends/cortex_m/cmsis-nn/ops/
)
backends/cortex_m/cmsis-nn/ops/op_aten_add_tensor.cpp (new file)

Lines changed: 68 additions & 0 deletions

@@ -0,0 +1,68 @@
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/core/portable_type/tensor.h> // for torch::executor::Tensor
#include <executorch/runtime/core/portable_type/scalar.h> // for torch::executor::Scalar
#include <iostream>

namespace cortex_m {
namespace native {

using Tensor = executorch::aten::Tensor;
using ScalarType = executorch::aten::ScalarType;
using Scalar = executorch::aten::Scalar;
using KernelRuntimeContext = torch::executor::KernelRuntimeContext;

torch::executor::Tensor& aten_add_tensor(
    torch::executor::KernelRuntimeContext& ctx,
    const torch::executor::Tensor& input1,
    const torch::executor::Tensor& input2,
    const torch::executor::Scalar& alpha,
    torch::executor::Tensor& out) {
  // Your CMSIS-NN optimized implementation here
  // Return 'out' tensor as per Executorch kernel signature
  std::cout << "add_out kernel called" << std::endl;
  ET_LOG(Info, "xxxxxxxxxx add_out kernel called");

  assert(false);
  assert(true);
  return out;
}

torch::executor::Tensor& add_out(
    torch::executor::KernelRuntimeContext& ctx,
    const torch::executor::Tensor& input1,
    const torch::executor::Tensor& input2,
    const torch::executor::Scalar& alpha,
    torch::executor::Tensor& out) {
  std::cout << "add_out kernel called" << std::endl;
  ET_LOG(Info, "xxxxxxxxxx add_out kernel called");

  // Ensure input is char type
  ET_CHECK_MSG(
      input1.scalar_type() == ScalarType::Char,
      "input1.scalar_type() %" PRId8 " is not char type",
      static_cast<int8_t>(input1.scalar_type()));

  ET_CHECK_MSG(
      input2.scalar_type() == ScalarType::Char,
      "input2.scalar_type() %" PRId8 " is not char type",
      static_cast<int8_t>(input2.scalar_type()));

  // Check output dtype is float
  ET_CHECK_MSG(
      out.scalar_type() == ScalarType::Float,
      "out.scalar_type() %" PRId8 " is not float",
      static_cast<int8_t>(out.scalar_type()));

  // Check dtype is int8 (Char)
  /*ET_CHECK_MSG(
      dtype == ScalarType::Char,
      "dtype %" PRId8 " is not int8 (Char)",
      static_cast<int8_t>(dtype));*/

  assert(false);

  return out;
}

} // namespace native
} // namespace cortex_m
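Both entry points above are still stubs that log and assert(false). For orientation only, here is a minimal hypothetical sketch of how a body might eventually call into CMSIS-NN's arm_elementwise_add_s8. The helper name, the data-pointer accessors, and all quantization parameters (offsets, multipliers, shifts, activation range) are placeholder assumptions, not part of this commit, and the exact arm_elementwise_add_s8 signature should be checked against arm_nnfunctions.h in the CMSIS-NN build actually linked.

// Hypothetical sketch (not in this commit): CMSIS-NN-backed int8 add.
// Requantization parameters below are dummy placeholders, analogous to the
// softmax kernel's "refine later" values; real ones must come from the
// model's quantization scales/zero-points.
#include <executorch/runtime/kernel/kernel_includes.h>
extern "C" {
#include "Include/arm_nnfunctions.h"
}

namespace cortex_m {
namespace native {

torch::executor::Tensor& add_out_sketch(
    torch::executor::KernelRuntimeContext& ctx,
    const torch::executor::Tensor& input1,
    const torch::executor::Tensor& input2,
    const torch::executor::Scalar& alpha,
    torch::executor::Tensor& out) {
  const int8_t* in1 = input1.const_data_ptr<int8_t>();
  const int8_t* in2 = input2.const_data_ptr<int8_t>();
  int8_t* out_data = out.mutable_data_ptr<int8_t>();

  // Dummy requantization parameters (placeholders only).
  const int32_t in_offset = 0, in_mult = 1073741824, in_shift = 0;
  const int32_t out_offset = 0, out_mult = 1073741824, out_shift = 0;
  const int32_t left_shift = 0;

  arm_elementwise_add_s8(
      in1, in2,
      in_offset, in_mult, in_shift,   // input 1 requant params
      in_offset, in_mult, in_shift,   // input 2 requant params
      left_shift,
      out_data,
      out_offset, out_mult, out_shift,
      /*out_activation_min=*/-128,
      /*out_activation_max=*/127,
      static_cast<int32_t>(out.numel()));
  return out;
}

} // namespace native
} // namespace cortex_m

Whatever form the final kernel takes, those offsets and multipliers will need to flow in from the ahead-of-time quantization step rather than being hard-coded.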
backends/cortex_m/cmsis-nn/ops/op_aten_softmax.cpp (new file)

Lines changed: 66 additions & 0 deletions

@@ -0,0 +1,66 @@
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/core/portable_type/tensor.h> // for torch::executor::Tensor
#include <executorch/runtime/core/portable_type/scalar.h> // for torch::executor::Scalar

extern "C" {
#include "Include/arm_nnfunctions.h"
}

namespace cortex_m {
namespace native {

using Tensor = torch::executor::Tensor;
using KernelRuntimeContext = torch::executor::KernelRuntimeContext;

//__attribute__((section(".text_ddr")))
void softmax_wrapper(
    const int8_t* input_data,
    int rows,
    int cols,
    int32_t input_mult,
    int32_t input_shift,
    int32_t diff_min,
    int8_t* output_data) {
  arm_softmax_s8(
      input_data,
      rows,
      cols,
      input_mult,
      input_shift,
      diff_min,
      output_data);
}

torch::executor::Tensor& aten_softmax(
    KernelRuntimeContext& context,
    const Tensor& self,
    int64_t dim,
    bool half_to_float,
    Tensor& out) {

  ET_LOG(Info, "CMSIS-NN softmax kernel called");
  const int8_t* input_data = self.data_ptr<int8_t>();
  int8_t* output_data = out.data_ptr<int8_t>();

  int rows = self.sizes()[0];
  int cols = self.sizes()[1];
  ET_LOG(Info, "Input shape: %d x %d", rows, cols);
  // Quantization params - dummy values for now, refine later
  int32_t input_mult = 1 << 4; // or something from qparams
  int32_t input_shift = 0;
  int32_t diff_min = -128;

  softmax_wrapper(
      input_data,
      rows,
      cols,
      input_mult,
      input_shift,
      diff_min,
      output_data);

  return out;
}

} // namespace native
} // namespace cortex_m
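The softmax kernel above assumes a plain 2-D input (sizes()[0] x sizes()[1]) and hard-codes dummy quantization parameters. As a small hypothetical sketch, not part of this commit, the helper below shows one way to derive the num_rows/row_size arguments for arm_softmax_s8 from an arbitrary-rank tensor, assuming the softmax axis is the innermost dimension; the multiplier, shift, and diff_min values would still have to be computed from the input scale.

// Hypothetical helper (not in this commit): collapse an N-D tensor into the
// (num_rows, row_size) layout expected by arm_softmax_s8, assuming the
// softmax axis is the last dimension.
#include <executorch/runtime/kernel/kernel_includes.h>

namespace cortex_m {
namespace native {

inline void softmax_dims_sketch(
    const executorch::aten::Tensor& self,
    int32_t& num_rows,
    int32_t& row_size) {
  const auto rank = self.dim();
  // Size of the softmax axis (assumed innermost).
  row_size = rank > 0 ? static_cast<int32_t>(self.sizes()[rank - 1]) : 1;
  // Everything else collapses into rows of that length.
  num_rows = row_size > 0 ? static_cast<int32_t>(self.numel() / row_size) : 0;
}

} // namespace native
} // namespace cortex_m

Until something along these lines lands, the committed kernel only behaves sensibly for 2-D inputs with softmax over the last axis.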

examples/arm/ethos-u-setup/arm-none-eabi-gcc.cmake

Lines changed: 2 additions & 2 deletions
@@ -73,11 +73,11 @@ elseif(
   OR CMAKE_SYSTEM_PROCESSOR MATCHES "cortex-m55(\\+|$)"
   OR CMAKE_SYSTEM_PROCESSOR MATCHES "cortex-m85(\\+|$)"
 )
-  set(FLOAT hard)
+  set(FLOAT soft)
 elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "cortex-m4(\\+|$)"
        OR CMAKE_SYSTEM_PROCESSOR MATCHES "cortex-m7(\\+|$)"
 )
-  set(FLOAT hard)
+  set(FLOAT soft)
   set(FPU_CONFIG "fpv4-sp-d16")
   add_compile_options(-mfpu=${FPU_CONFIG})
   add_link_options(-mfpu=${FPU_CONFIG})

examples/arm/executor_runner/CMakeLists.txt

Lines changed: 29 additions & 2 deletions
@@ -13,7 +13,7 @@ option(ET_ATOL "Set atol to use for BundleIO testing" OFF)
 option(ET_RTOL "Set rtol to use for BundleIO testing" OFF)
 option(ET_DUMP_INPUT "Dump input in log" OFF)
 option(ET_DUMP_OUTPUT "Dump output in log" ON)
-option(FETCH_ETHOS_U_CONTENT "Fetch ethos_u dependencies instead of relying on pre-downloads" ON)
+option(FETCH_ETHOS_U_CONTENT "Fetch ethos_u dependencies instead of relying on pre-downloads" OFF)
 
 if(NOT DEFINED ET_PTE_FILE_PATH AND NOT ${SEMIHOSTING})
   message(

@@ -539,6 +539,26 @@ set_property(
   PROPERTY IMPORTED_LOCATION
     "${ET_BUILD_DIR_PATH}/backends/cortex_m/libcortex_m_kernels.a"
 )
+add_library(cortex_m_cmsis_nn_ops_lib STATIC IMPORTED)
+set_property(
+  TARGET cortex_m_cmsis_nn_ops_lib
+  PROPERTY IMPORTED_LOCATION
+    "${ET_BUILD_DIR_PATH}/backends/cortex_m/cmsis-nn/ops/libcortex_m_cmsis_nn_ops_lib.a"
+)
+add_library(cortex_m_cmsis_kernels STATIC IMPORTED)
+set_property(
+  TARGET cortex_m_cmsis_kernels
+  PROPERTY IMPORTED_LOCATION
+    "${ET_BUILD_DIR_PATH}/backends/cortex_m/cmsis-nn/ops/libcortex_m_cmsis_kernels.a"
+)
+
+add_library(cmsis_nn STATIC IMPORTED)
+set_property(
+  TARGET cmsis_nn
+  PROPERTY IMPORTED_LOCATION
+    "/home/sidart/working/CMSIS-NN/build/libcmsis-nn.a"
+)
+
 add_library(extension_runner_util STATIC IMPORTED)
 set_property(
   TARGET extension_runner_util

@@ -580,11 +600,14 @@ list(APPEND arm_executor_runner_link
   "-Wl,--whole-archive"
   executorch_delegate_ethos_u
   cortex_m_ops_lib
+  cortex_m_cmsis_nn_ops_lib
   quantized_ops_lib
   portable_ops_lib
   quantized_kernels
-  cortex_m_kernels
   portable_kernels
+  cortex_m_kernels
+  cortex_m_cmsis_kernels
+  cmsis_nn
   "-Wl,--no-whole-archive"
   -Xlinker -Map=arm_executor_runner.map
 )

@@ -674,6 +697,10 @@ if(ET_DUMP_OUTPUT)
   target_compile_definitions(arm_executor_runner PUBLIC -DET_DUMP_OUTPUT)
 endif()
 
+set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -ffunction-sections -fdata-sections -fno-exceptions -fno-unwind-tables")
+set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -ffunction-sections -fdata-sections -fno-exceptions -fno-unwind-tables")
+set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -Wl,--gc-sections")
+
 # Fixup compilation of retarget.c
 if(SEMIHOSTING)
   # Remove this when MLBEDSW-8910 is closed.
