36 changes: 32 additions & 4 deletions backends/cadence/hifi/operators/op_softmax.cpp
@@ -72,7 +72,6 @@ Tensor& _softmax_out(
   if (optimized) {
     int* p_inp = (int*)in.const_data_ptr<float>();
     int* out_data = (int*)out.mutable_data_ptr<float>();
-
     int num_inp_dims = in.dim();
     int num_out_dims = num_inp_dims;
@@ -99,6 +98,37 @@ Tensor& _softmax_out(

     outer_stride = size;

+    WORD32 ret_val = 0;
+
+    // Check if the input is permuted. If not, then we don't need to transpose
+    bool is_permuted = false;
+    for (int i = 0; i < num_inp_dims; i++) {
+      if (p_permute_vec[i] != i) {
+        is_permuted = true;
+        break;
+      }
+    }
+
+    if (!is_permuted) {
+      const float* p_inpf = in.const_data_ptr<float>();
+      float* out_dataf = out.mutable_data_ptr<float>();
+
+      for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
+        size_t outer = outer_idx * outer_stride;
+        for (size_t inner_idx = 0; inner_idx < stride; ++inner_idx) {
+          size_t base = outer + inner_idx;
+
+          float* p_in_data = (float*)&p_inpf[base];
+          float* p_out_data = (float*)&out_dataf[base];
+
+          ret_val = xa_nn_vec_softmax_f32_f32(p_out_data, p_in_data, size);
+
+          ET_KERNEL_CHECK(ctx, ret_val == 0, Internal, out);
+        }
+      }
+      return out;
+    }
+
     int* p_out =
         (int*)kernels::allocate_temp_memory(ctx, out.numel() * sizeof(int));

@@ -109,7 +139,7 @@ Tensor& _softmax_out(

     ET_KERNEL_CHECK(ctx, p_out1 != nullptr, MemoryAllocationFailed, out);

-    WORD32 ret_val = xa_nn_transpose_32_32(
+    ret_val = xa_nn_transpose_32_32(
         p_out,
         p_out_shape,
         p_inp,
@@ -142,9 +172,7 @@ Tensor& _softmax_out(
         p_permute_vec,
         num_out_dims,
         num_inp_dims);
-
     ET_KERNEL_CHECK(ctx, ret_val == 0, Internal, out);
-
     return out;
   }

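For readers skimming the diff: the new branch skips the transpose round-trip whenever `p_permute_vec` is the identity permutation, which (assuming, as the surrounding code not shown in these hunks suggests, that `p_permute_vec` moves the softmax axis to the innermost position) means each run of `size` elements is already contiguous and can be handed straight to `xa_nn_vec_softmax_f32_f32`. The standalone sketch below mimics that loop structure with a plain scalar softmax. It is not the PR's code: `softmax_f32_slice` is a hypothetical stand-in for the NNLib kernel, assumed to return 0 on success to match the `ret_val == 0` checks, and the shape constants are invented for the example (the diff's inner loop over `stride` is kept general there; in the identity case `stride` is presumably 1, so the sketch folds it away).

```cpp
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// Numerically stable softmax over one contiguous slice of `size` floats.
// Hypothetical stand-in for xa_nn_vec_softmax_f32_f32; returns 0 on success.
static int softmax_f32_slice(float* out, const float* in, size_t size) {
  if (size == 0) {
    return -1;  // nonzero error code, mirroring the diff's ret_val checks
  }
  float max_val = in[0];
  for (size_t i = 1; i < size; ++i) {
    max_val = std::fmax(max_val, in[i]);
  }
  float sum = 0.0f;
  for (size_t i = 0; i < size; ++i) {
    out[i] = std::exp(in[i] - max_val);  // shift by max for stability
    sum += out[i];
  }
  for (size_t i = 0; i < size; ++i) {
    out[i] /= sum;
  }
  return 0;
}

int main() {
  // A 2x2x4 tensor with softmax over the innermost axis: no transpose is
  // needed, every slice of `size` elements is already contiguous.
  const size_t outer_size = 2 * 2;   // product of dims before the softmax axis
  const size_t size = 4;             // extent of the softmax axis
  const size_t outer_stride = size;  // as in the diff: outer_stride = size

  std::vector<float> in(outer_size * size);
  std::vector<float> out(in.size());
  for (size_t i = 0; i < in.size(); ++i) {
    in[i] = float(i % size);  // every row is 0, 1, 2, 3
  }

  // Same loop shape as the fast path: one kernel call per contiguous row.
  for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
    const size_t base = outer_idx * outer_stride;
    if (softmax_f32_slice(&out[base], &in[base], size) != 0) {
      return 1;
    }
  }

  printf("row 0: %f %f %f %f\n", out[0], out[1], out[2], out[3]);
  return 0;
}
```

The real kernel presumably vectorizes the exp and sum on the HiFi DSP, but the control flow matches the diff: one call per contiguous slice and an early `return out`, so the transpose path's temp-memory allocations never run in the common unpermuted case.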