Skip to content

Commit f6cc262

Browse files
author
sidart
committed
Summary: Initial CMSS-NN Add Op
Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent a1b35e8 commit f6cc262

File tree

4 files changed

+53
-1
lines changed

4 files changed

+53
-1
lines changed

backends/cortex_m/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ include(${EXECUTORCH_ROOT}/tools/cmake/Codegen.cmake)
2727
set(_cortex_m_kernels__srcs
2828
${CMAKE_CURRENT_SOURCE_DIR}/ops/op_quantize_per_tensor.cpp
2929
${CMAKE_CURRENT_SOURCE_DIR}/ops/op_dequantize_per_tensor.cpp
30+
${CMAKE_CURRENT_SOURCE_DIR}/ops/op_add.cpp
3031
)
3132

3233
# Generate C++ bindings to register kernels into Executorch (for runtime).

backends/cortex_m/ops/op_add.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
#include <executorch/runtime/kernel/kernel_includes.h>
2+
#include <cinttypes>
3+
namespace cortex_m {
4+
namespace native {
5+
6+
using Tensor = executorch::aten::Tensor;
7+
using ScalarType = executorch::aten::ScalarType;
8+
using KernelRuntimeContext = torch::executor::KernelRuntimeContext;
9+
10+
Tensor& add_out(
11+
KernelRuntimeContext& ctx,
12+
const Tensor& input1,
13+
const Tensor& input2,
14+
const ScalarType dtype,
15+
Tensor& out) {
16+
17+
// Ensure input is char type
18+
ET_CHECK_MSG(
19+
input1.scalar_type() == ScalarType::Char,
20+
"input1.scalar_type() %" PRId8 " is not char type",
21+
static_cast<int8_t>(input1.scalar_type()));
22+
23+
ET_CHECK_MSG(
24+
input2.scalar_type() == ScalarType::Char,
25+
"input2.scalar_type() %" PRId8 " is not char type",
26+
static_cast<int8_t>(input2.scalar_type()));
27+
28+
// Check output dtype is float
29+
ET_CHECK_MSG(
30+
out.scalar_type() == ScalarType::Float,
31+
"out.scalar_type() %" PRId8 " is not float",
32+
static_cast<int8_t>(out.scalar_type()));
33+
34+
// Check dtype is int8 (Char)
35+
ET_CHECK_MSG(
36+
dtype == ScalarType::Char,
37+
"dtype %" PRId8 " is not int8 (Char)",
38+
static_cast<int8_t>(dtype));
39+
40+
41+
return out;
42+
}
43+
44+
} // namespace native
45+
} // namespace cortex_m

backends/cortex_m/ops/operators.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,9 @@
1515
kernels:
1616
- arg_meta: null
1717
kernel_name: cortex_m::dequantize_per_tensor_out
18+
19+
- func: cortex_m::add.out(Tensor a, Tensor b, Scalar alpha, *, Tensor(a!) out) -> Tensor(a!)
20+
variants: function
21+
kernels:
22+
- arg_meta: null
23+
kernel_name: cortex_m::add_out

runtime/core/named_data_map.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ namespace ET_RUNTIME_NAMESPACE {
2727
* Interface to access and retrieve data via name.
2828
* See executorch/extension/flat_tensor/ for an example.
2929
*/
30-
class ET_EXPERIMENTAL NamedDataMap {
30+
class NamedDataMap {
3131
public:
3232
virtual ~NamedDataMap() = default;
3333
/**

0 commit comments

Comments
 (0)