Skip to content

Commit c96d19d

Browse files
mcremon-meta authored and facebook-github-bot committed
Improve softmax perf when transpose is not needed
Summary: When the supplied dimension is the last dim of the tensor, we don't need to permute anything and can call the nnlib kernel directly.

Reviewed By: zonglinpeng

Differential Revision: D79514231
1 parent ec35f56 commit c96d19d

File tree

1 file changed

+67
-62
lines changed

1 file changed

+67
-62
lines changed

backends/cadence/hifi/operators/op_softmax.cpp

Lines changed: 67 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -70,81 +70,86 @@ Tensor& _softmax_out(
7070
optimized = false;
7171

7272
if (optimized) {
73-
int* p_inp = (int*)in.const_data_ptr<float>();
74-
int* out_data = (int*)out.mutable_data_ptr<float>();
75-
76-
int num_inp_dims = in.dim();
77-
int num_out_dims = num_inp_dims;
78-
79-
int p_inp_shape[kNnlibMaxDim];
80-
int p_out_shape[kNnlibMaxDim];
81-
int p_permute_vec[kNnlibMaxDim];
82-
83-
for (int i = 0; i < num_inp_dims; i++)
84-
p_inp_shape[i] = in.size(i);
85-
86-
for (int i = 0; i < num_inp_dims; i++) {
87-
if (i == d)
88-
p_permute_vec[i] = num_inp_dims - 1;
89-
else if (i == (num_inp_dims - 1))
90-
p_permute_vec[num_inp_dims - 1] = d;
91-
else
92-
p_permute_vec[i] = i;
93-
94-
p_out_shape[i] = p_inp_shape[p_permute_vec[i]];
95-
96-
if (i != d)
97-
outer_size = outer_size * p_inp_shape[i];
98-
}
73+
if (dim == in.dim() - 1) {
74+
const float* p_inp = in.const_data_ptr<float>();
75+
float* out_data = out.mutable_data_ptr<float>();
76+
77+
WORD32 ret_val = xa_nn_vec_softmax_f32_f32(out_data, p_inp, size);
78+
ET_KERNEL_CHECK(ctx, ret_val == 0, Internal, out);
79+
} else {
80+
int* p_inp = (int*)in.const_data_ptr<float>();
81+
int* out_data = (int*)out.mutable_data_ptr<float>();
82+
int num_inp_dims = in.dim();
83+
int num_out_dims = num_inp_dims;
84+
85+
int p_inp_shape[kNnlibMaxDim];
86+
int p_out_shape[kNnlibMaxDim];
87+
int p_permute_vec[kNnlibMaxDim];
88+
89+
for (int i = 0; i < num_inp_dims; i++)
90+
p_inp_shape[i] = in.size(i);
91+
92+
for (int i = 0; i < num_inp_dims; i++) {
93+
if (i == d)
94+
p_permute_vec[i] = num_inp_dims - 1;
95+
else if (i == (num_inp_dims - 1))
96+
p_permute_vec[num_inp_dims - 1] = d;
97+
else
98+
p_permute_vec[i] = i;
99+
100+
p_out_shape[i] = p_inp_shape[p_permute_vec[i]];
101+
102+
if (i != d)
103+
outer_size = outer_size * p_inp_shape[i];
104+
}
99105

100-
outer_stride = size;
106+
outer_stride = size;
101107

102-
int* p_out =
103-
(int*)kernels::allocate_temp_memory(ctx, out.numel() * sizeof(int));
108+
int* p_out =
109+
(int*)kernels::allocate_temp_memory(ctx, out.numel() * sizeof(int));
104110

105-
ET_KERNEL_CHECK(ctx, p_out != nullptr, MemoryAllocationFailed, out);
111+
ET_KERNEL_CHECK(ctx, p_out != nullptr, MemoryAllocationFailed, out);
106112

107-
int* p_out1 =
108-
(int*)kernels::allocate_temp_memory(ctx, out.numel() * sizeof(int));
113+
int* p_out1 =
114+
(int*)kernels::allocate_temp_memory(ctx, out.numel() * sizeof(int));
109115

110-
ET_KERNEL_CHECK(ctx, p_out1 != nullptr, MemoryAllocationFailed, out);
116+
ET_KERNEL_CHECK(ctx, p_out1 != nullptr, MemoryAllocationFailed, out);
111117

112-
WORD32 ret_val = xa_nn_transpose_32_32(
113-
p_out,
114-
p_out_shape,
115-
p_inp,
116-
p_inp_shape,
117-
p_permute_vec,
118-
num_out_dims,
119-
num_inp_dims);
118+
WORD32 ret_val = xa_nn_transpose_32_32(
119+
p_out,
120+
p_out_shape,
121+
p_inp,
122+
p_inp_shape,
123+
p_permute_vec,
124+
num_out_dims,
125+
num_inp_dims);
120126

121-
ET_KERNEL_CHECK(ctx, ret_val == 0, Internal, out);
127+
ET_KERNEL_CHECK(ctx, ret_val == 0, Internal, out);
122128

123-
for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
124-
size_t outer = outer_idx * outer_stride;
125-
for (size_t inner_idx = 0; inner_idx < stride; ++inner_idx) {
126-
size_t base = outer + inner_idx;
129+
for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
130+
size_t outer = outer_idx * outer_stride;
131+
for (size_t inner_idx = 0; inner_idx < stride; ++inner_idx) {
132+
size_t base = outer + inner_idx;
127133

128-
float* p_in_data = (float*)&p_out[base];
129-
float* p_out_data = (float*)&p_out1[base];
134+
float* p_in_data = (float*)&p_out[base];
135+
float* p_out_data = (float*)&p_out1[base];
130136

131-
ret_val = xa_nn_vec_softmax_f32_f32(p_out_data, p_in_data, size);
137+
ret_val = xa_nn_vec_softmax_f32_f32(p_out_data, p_in_data, size);
132138

133-
ET_KERNEL_CHECK(ctx, ret_val == 0, Internal, out);
139+
ET_KERNEL_CHECK(ctx, ret_val == 0, Internal, out);
140+
}
134141
}
135-
}
136-
137-
ret_val = xa_nn_transpose_32_32(
138-
out_data,
139-
p_inp_shape,
140-
p_out1,
141-
p_out_shape,
142-
p_permute_vec,
143-
num_out_dims,
144-
num_inp_dims);
145-
146-
ET_KERNEL_CHECK(ctx, ret_val == 0, Internal, out);
147142

143+
ret_val = xa_nn_transpose_32_32(
144+
out_data,
145+
p_inp_shape,
146+
p_out1,
147+
p_out_shape,
148+
p_permute_vec,
149+
num_out_dims,
150+
num_inp_dims);
151+
ET_KERNEL_CHECK(ctx, ret_val == 0, Internal, out);
152+
}
148153
return out;
149154
}
150155

0 commit comments

Comments
 (0)