@@ -70,81 +70,86 @@ Tensor& _softmax_out(
7070 optimized = false ;
7171
7272 if (optimized) {
73- int * p_inp = (int *)in.const_data_ptr <float >();
74- int * out_data = (int *)out.mutable_data_ptr <float >();
75-
76- int num_inp_dims = in.dim ();
77- int num_out_dims = num_inp_dims;
78-
79- int p_inp_shape[kNnlibMaxDim ];
80- int p_out_shape[kNnlibMaxDim ];
81- int p_permute_vec[kNnlibMaxDim ];
82-
83- for (int i = 0 ; i < num_inp_dims; i++)
84- p_inp_shape[i] = in.size (i);
85-
86- for (int i = 0 ; i < num_inp_dims; i++) {
87- if (i == d)
88- p_permute_vec[i] = num_inp_dims - 1 ;
89- else if (i == (num_inp_dims - 1 ))
90- p_permute_vec[num_inp_dims - 1 ] = d;
91- else
92- p_permute_vec[i] = i;
93-
94- p_out_shape[i] = p_inp_shape[p_permute_vec[i]];
95-
96- if (i != d)
97- outer_size = outer_size * p_inp_shape[i];
98- }
73+ if (dim == in.dim () - 1 ) {
74+ const float * p_inp = in.const_data_ptr <float >();
75+ float * out_data = out.mutable_data_ptr <float >();
76+
77+ WORD32 ret_val = xa_nn_vec_softmax_f32_f32 (out_data, p_inp, size);
78+ ET_KERNEL_CHECK (ctx, ret_val == 0 , Internal, out);
79+ } else {
80+ int * p_inp = (int *)in.const_data_ptr <float >();
81+ int * out_data = (int *)out.mutable_data_ptr <float >();
82+ int num_inp_dims = in.dim ();
83+ int num_out_dims = num_inp_dims;
84+
85+ int p_inp_shape[kNnlibMaxDim ];
86+ int p_out_shape[kNnlibMaxDim ];
87+ int p_permute_vec[kNnlibMaxDim ];
88+
89+ for (int i = 0 ; i < num_inp_dims; i++)
90+ p_inp_shape[i] = in.size (i);
91+
92+ for (int i = 0 ; i < num_inp_dims; i++) {
93+ if (i == d)
94+ p_permute_vec[i] = num_inp_dims - 1 ;
95+ else if (i == (num_inp_dims - 1 ))
96+ p_permute_vec[num_inp_dims - 1 ] = d;
97+ else
98+ p_permute_vec[i] = i;
99+
100+ p_out_shape[i] = p_inp_shape[p_permute_vec[i]];
101+
102+ if (i != d)
103+ outer_size = outer_size * p_inp_shape[i];
104+ }
99105
100- outer_stride = size;
106+ outer_stride = size;
101107
102- int * p_out =
103- (int *)kernels::allocate_temp_memory (ctx, out.numel () * sizeof (int ));
108+ int * p_out =
109+ (int *)kernels::allocate_temp_memory (ctx, out.numel () * sizeof (int ));
104110
105- ET_KERNEL_CHECK (ctx, p_out != nullptr , MemoryAllocationFailed, out);
111+ ET_KERNEL_CHECK (ctx, p_out != nullptr , MemoryAllocationFailed, out);
106112
107- int * p_out1 =
108- (int *)kernels::allocate_temp_memory (ctx, out.numel () * sizeof (int ));
113+ int * p_out1 =
114+ (int *)kernels::allocate_temp_memory (ctx, out.numel () * sizeof (int ));
109115
110- ET_KERNEL_CHECK (ctx, p_out1 != nullptr , MemoryAllocationFailed, out);
116+ ET_KERNEL_CHECK (ctx, p_out1 != nullptr , MemoryAllocationFailed, out);
111117
112- WORD32 ret_val = xa_nn_transpose_32_32 (
113- p_out,
114- p_out_shape,
115- p_inp,
116- p_inp_shape,
117- p_permute_vec,
118- num_out_dims,
119- num_inp_dims);
118+ WORD32 ret_val = xa_nn_transpose_32_32 (
119+ p_out,
120+ p_out_shape,
121+ p_inp,
122+ p_inp_shape,
123+ p_permute_vec,
124+ num_out_dims,
125+ num_inp_dims);
120126
121- ET_KERNEL_CHECK (ctx, ret_val == 0 , Internal, out);
127+ ET_KERNEL_CHECK (ctx, ret_val == 0 , Internal, out);
122128
123- for (size_t outer_idx = 0 ; outer_idx < outer_size; ++outer_idx) {
124- size_t outer = outer_idx * outer_stride;
125- for (size_t inner_idx = 0 ; inner_idx < stride; ++inner_idx) {
126- size_t base = outer + inner_idx;
129+ for (size_t outer_idx = 0 ; outer_idx < outer_size; ++outer_idx) {
130+ size_t outer = outer_idx * outer_stride;
131+ for (size_t inner_idx = 0 ; inner_idx < stride; ++inner_idx) {
132+ size_t base = outer + inner_idx;
127133
128- float * p_in_data = (float *)&p_out[base];
129- float * p_out_data = (float *)&p_out1[base];
134+ float * p_in_data = (float *)&p_out[base];
135+ float * p_out_data = (float *)&p_out1[base];
130136
131- ret_val = xa_nn_vec_softmax_f32_f32 (p_out_data, p_in_data, size);
137+ ret_val = xa_nn_vec_softmax_f32_f32 (p_out_data, p_in_data, size);
132138
133- ET_KERNEL_CHECK (ctx, ret_val == 0 , Internal, out);
139+ ET_KERNEL_CHECK (ctx, ret_val == 0 , Internal, out);
140+ }
134141 }
135- }
136-
137- ret_val = xa_nn_transpose_32_32 (
138- out_data,
139- p_inp_shape,
140- p_out1,
141- p_out_shape,
142- p_permute_vec,
143- num_out_dims,
144- num_inp_dims);
145-
146- ET_KERNEL_CHECK (ctx, ret_val == 0 , Internal, out);
147142
143+ ret_val = xa_nn_transpose_32_32 (
144+ out_data,
145+ p_inp_shape,
146+ p_out1,
147+ p_out_shape,
148+ p_permute_vec,
149+ num_out_dims,
150+ num_inp_dims);
151+ ET_KERNEL_CHECK (ctx, ret_val == 0 , Internal, out);
152+ }
148153 return out;
149154 }
150155
0 commit comments