@@ -72,7 +72,6 @@ Tensor& _softmax_out(
72
72
if (optimized) {
73
73
int * p_inp = (int *)in.const_data_ptr <float >();
74
74
int * out_data = (int *)out.mutable_data_ptr <float >();
75
-
76
75
int num_inp_dims = in.dim ();
77
76
int num_out_dims = num_inp_dims;
78
77
@@ -99,6 +98,37 @@ Tensor& _softmax_out(
99
98
100
99
outer_stride = size;
101
100
101
+ WORD32 ret_val = 0 ;
102
+
103
+ // Check if the input is permuted. If not, then we don't need to transpose
104
+ bool is_permuted = false ;
105
+ for (int i = 0 ; i < num_inp_dims; i++) {
106
+ if (p_permute_vec[i] != i) {
107
+ is_permuted = true ;
108
+ break ;
109
+ }
110
+ }
111
+
112
+ if (!is_permuted) {
113
+ const float * p_inpf = in.const_data_ptr <float >();
114
+ float * out_dataf = out.mutable_data_ptr <float >();
115
+
116
+ for (size_t outer_idx = 0 ; outer_idx < outer_size; ++outer_idx) {
117
+ size_t outer = outer_idx * outer_stride;
118
+ for (size_t inner_idx = 0 ; inner_idx < stride; ++inner_idx) {
119
+ size_t base = outer + inner_idx;
120
+
121
+ float * p_in_data = (float *)&p_inpf[base];
122
+ float * p_out_data = (float *)&out_dataf[base];
123
+
124
+ ret_val = xa_nn_vec_softmax_f32_f32 (p_out_data, p_in_data, size);
125
+
126
+ ET_KERNEL_CHECK (ctx, ret_val == 0 , Internal, out);
127
+ }
128
+ }
129
+ return out;
130
+ }
131
+
102
132
int * p_out =
103
133
(int *)kernels::allocate_temp_memory (ctx, out.numel () * sizeof (int ));
104
134
@@ -109,7 +139,7 @@ Tensor& _softmax_out(
109
139
110
140
ET_KERNEL_CHECK (ctx, p_out1 != nullptr , MemoryAllocationFailed, out);
111
141
112
- WORD32 ret_val = xa_nn_transpose_32_32 (
142
+ ret_val = xa_nn_transpose_32_32 (
113
143
p_out,
114
144
p_out_shape,
115
145
p_inp,
@@ -142,9 +172,7 @@ Tensor& _softmax_out(
142
172
p_permute_vec,
143
173
num_out_dims,
144
174
num_inp_dims);
145
-
146
175
ET_KERNEL_CHECK (ctx, ret_val == 0 , Internal, out);
147
-
148
176
return out;
149
177
}
150
178
0 commit comments