Skip to content

Commit 579b91e

Browse files
authored
Add support for strongly typed softmax
Differential Revision: D81172654 Pull Request resolved: #13750
1 parent f8156fb commit 579b91e

File tree

4 files changed

+182
-0
lines changed

4 files changed

+182
-0
lines changed

backends/cadence/aot/ops_registrations.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,12 @@
448448
"roi_align_box_processor(Tensor rois, int output_size_h, int output_size_w, "
449449
"int sampling_ratio, bool aligned) -> (Tensor out)"
450450
)
451+
lib.define(
452+
"_softmax_f32_f32(Tensor self, int dim, bool? half_to_float) -> (Tensor out)"
453+
)
454+
lib.define(
455+
"_softmax_f32_f32.out(Tensor self, int dim, bool? half_to_float, *, Tensor(a!) out) -> Tensor(a!)"
456+
)
451457

452458
# Custom ops with aten namespace. Need to specify the lib var as FRAGMENT type as aten library is already defined
453459
aten_lib = Library("aten", "FRAGMENT")
@@ -2075,3 +2081,13 @@ def roi_align_box_processor_meta(
20752081
aligned: bool,
20762082
) -> torch.Tensor:
20772083
return rois.new_empty((rois.shape[0], 80), dtype=torch.uint8)
2084+
2085+
2086+
@register_fake("cadence::_softmax_f32_f32")
2087+
def softmax_f32_f32_meta(
2088+
self: torch.Tensor,
2089+
dim: int,
2090+
dtype: torch.dtype,
2091+
half_to_float: Optional[bool] = None,
2092+
) -> torch.Tensor:
2093+
return self.new_empty(self.size(), dtype=self.dtype)

backends/cadence/aot/type_dispatch.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,13 @@ class CompileTimeTypeDispatchPass(ExportPass):
9393
},
9494
weight_arg_idx=3,
9595
),
96+
exir_ops.edge.aten._softmax.default: OpConfig(
97+
"_softmax",
98+
type_dispatch_suffixes={
99+
(torch.float32,): "f32_f32",
100+
},
101+
variant="default",
102+
),
96103
}
97104

98105
def call_operator(
Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,158 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/cadence/hifi/kernels/kernels.h>
10+
#include <executorch/runtime/kernel/kernel_includes.h>
11+
12+
using executorch::aten::ScalarType;
13+
using executorch::aten::Tensor;
14+
using executorch::runtime::KernelRuntimeContext;
15+
using torch::executor::Error;
16+
17+
namespace cadence {
18+
namespace impl {
19+
namespace HiFi {
20+
namespace native {
21+
22+
inline Tensor& _softmax_f32_f32_out(
23+
KernelRuntimeContext& ctx,
24+
const Tensor& in,
25+
int64_t dim,
26+
::executorch::aten::optional<bool> half_to_float,
27+
Tensor& out) {
28+
constexpr int kNnlibMaxDim = 16;
29+
30+
const std::optional<int64_t>& dim_t = dim;
31+
const size_t d = ET_NORMALIZE_IX(dim_t.value(), in.dim());
32+
const size_t size = in.size(d);
33+
34+
size_t stride = 1, outer_size = 1;
35+
36+
size_t outer_stride = 1;
37+
38+
int* p_inp = (int*)in.const_data_ptr<float>();
39+
int* out_data = (int*)out.mutable_data_ptr<float>();
40+
41+
int num_inp_dims = in.dim();
42+
int num_out_dims = num_inp_dims;
43+
44+
int p_inp_shape[kNnlibMaxDim];
45+
int p_out_shape[kNnlibMaxDim];
46+
int p_permute_vec[kNnlibMaxDim];
47+
48+
for (int i = 0; i < num_inp_dims; i++)
49+
p_inp_shape[i] = in.size(i);
50+
for (int i = 0; i < num_inp_dims; i++) {
51+
if (i == d)
52+
p_permute_vec[i] = num_inp_dims - 1;
53+
else if (i == (num_inp_dims - 1))
54+
p_permute_vec[num_inp_dims - 1] = d;
55+
else
56+
p_permute_vec[i] = i;
57+
58+
p_out_shape[i] = p_inp_shape[p_permute_vec[i]];
59+
60+
if (i != d)
61+
outer_size = outer_size * p_inp_shape[i];
62+
}
63+
64+
outer_stride = size;
65+
66+
WORD32 ret_val = 0;
67+
68+
// Check if the input is permuted. If not, then we don't need to transpose
69+
bool is_permuted = false;
70+
for (int i = 0; i < num_inp_dims; i++) {
71+
if (p_permute_vec[i] != i) {
72+
is_permuted = true;
73+
break;
74+
}
75+
}
76+
77+
if (!is_permuted) {
78+
const float* p_inpf = in.const_data_ptr<float>();
79+
float* out_dataf = out.mutable_data_ptr<float>();
80+
81+
for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
82+
size_t outer = outer_idx * outer_stride;
83+
for (size_t inner_idx = 0; inner_idx < stride; ++inner_idx) {
84+
size_t base = outer + inner_idx;
85+
86+
float* p_in_data = (float*)&p_inpf[base];
87+
float* p_out_data = (float*)&out_dataf[base];
88+
89+
ret_val = xa_nn_vec_softmax_f32_f32(p_out_data, p_in_data, size);
90+
91+
ET_KERNEL_CHECK(ctx, ret_val == 0, Internal, out);
92+
}
93+
}
94+
return out;
95+
}
96+
97+
int* p_out =
98+
(int*)kernels::allocate_temp_memory(ctx, out.numel() * sizeof(int));
99+
100+
ET_KERNEL_CHECK(ctx, p_out != nullptr, MemoryAllocationFailed, out);
101+
102+
int* p_out1 =
103+
(int*)kernels::allocate_temp_memory(ctx, out.numel() * sizeof(int));
104+
105+
ET_KERNEL_CHECK(ctx, p_out1 != nullptr, MemoryAllocationFailed, out);
106+
107+
ret_val = xa_nn_transpose_32_32(
108+
p_out,
109+
p_out_shape,
110+
p_inp,
111+
p_inp_shape,
112+
p_permute_vec,
113+
num_out_dims,
114+
num_inp_dims);
115+
116+
ET_KERNEL_CHECK(ctx, ret_val == 0, Internal, out);
117+
118+
for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
119+
size_t outer = outer_idx * outer_stride;
120+
for (size_t inner_idx = 0; inner_idx < stride; ++inner_idx) {
121+
size_t base = outer + inner_idx;
122+
123+
float* p_in_data = (float*)&p_out[base];
124+
float* p_out_data = (float*)&p_out1[base];
125+
126+
ret_val = xa_nn_vec_softmax_f32_f32(p_out_data, p_in_data, size);
127+
128+
ET_KERNEL_CHECK(ctx, ret_val == 0, Internal, out);
129+
}
130+
}
131+
132+
ret_val = xa_nn_transpose_32_32(
133+
out_data,
134+
p_inp_shape,
135+
p_out1,
136+
p_out_shape,
137+
p_permute_vec,
138+
num_out_dims,
139+
num_inp_dims);
140+
141+
ET_KERNEL_CHECK(ctx, ret_val == 0, Internal, out);
142+
143+
return out;
144+
}
145+
146+
Tensor& softmax_f32_f32_out(
147+
KernelRuntimeContext& ctx,
148+
const Tensor& in,
149+
int64_t dim,
150+
::executorch::aten::optional<bool> half_to_float,
151+
Tensor& out) {
152+
return _softmax_f32_f32_out(ctx, in, dim, half_to_float, out);
153+
}
154+
155+
} // namespace native
156+
} // namespace HiFi
157+
} // namespace impl
158+
} // namespace cadence

backends/cadence/hifi/operators/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ OPERATORS = [
9797
"sigmoid",
9898
"slice_copy",
9999
"softmax",
100+
"softmax_f32_f32",
100101
"split_with_sizes_copy",
101102
"sub",
102103
"tanh",

0 commit comments

Comments
 (0)