@@ -72,7 +72,6 @@ Tensor& _softmax_out(
   if (optimized) {
     int* p_inp = (int*)in.const_data_ptr<float>();
     int* out_data = (int*)out.mutable_data_ptr<float>();
-
     int num_inp_dims = in.dim();
     int num_out_dims = num_inp_dims;
 
@@ -99,6 +98,37 @@ Tensor& _softmax_out(
 
     outer_stride = size;
 
+    WORD32 ret_val = 0;
+
+    // Check if the input is permuted. If not, then we don't need to transpose
+    bool is_permuted = false;
+    for (int i = 0; i < num_inp_dims; i++) {
+      if (p_permute_vec[i] != i) {
+        is_permuted = true;
+        break;
+      }
+    }
+
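+    // Fast path: the permutation is the identity, i.e. the softmax dim is
+    // already the innermost one, so the vectorized softmax can be applied
+    // directly to each slice of `size` elements, skipping the transpose.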
+    if (!is_permuted) {
+      const float* p_inpf = in.const_data_ptr<float>();
+      float* out_dataf = out.mutable_data_ptr<float>();
+
+      for (size_t outer_idx = 0; outer_idx < outer_size; ++outer_idx) {
+        size_t outer = outer_idx * outer_stride;
+        for (size_t inner_idx = 0; inner_idx < stride; ++inner_idx) {
+          size_t base = outer + inner_idx;
+
+          float* p_in_data = (float*)&p_inpf[base];
+          float* p_out_data = (float*)&out_dataf[base];
+
+          ret_val = xa_nn_vec_softmax_f32_f32(p_out_data, p_in_data, size);
+
+          ET_KERNEL_CHECK(ctx, ret_val == 0, Internal, out);
+        }
+      }
+      return out;
+    }
+
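+    // Otherwise, transpose the input (per p_permute_vec) into a temporary
+    // buffer first.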
     int* p_out =
         (int*)kernels::allocate_temp_memory(ctx, out.numel() * sizeof(int));
 
@@ -109,7 +139,7 @@ Tensor& _softmax_out(
 
     ET_KERNEL_CHECK(ctx, p_out1 != nullptr, MemoryAllocationFailed, out);
 
-    WORD32 ret_val = xa_nn_transpose_32_32(
+    ret_val = xa_nn_transpose_32_32(
         p_out,
         p_out_shape,
         p_inp,
@@ -142,9 +172,7 @@ Tensor& _softmax_out(
         p_permute_vec,
         num_out_dims,
         num_inp_dims);
-
     ET_KERNEL_CHECK(ctx, ret_val == 0, Internal, out);
-
     return out;
   }
 