@@ -70,81 +70,86 @@ Tensor& _softmax_out(
7070 optimized = false ;
7171
7272 if (optimized) {
73- int * p_inp = (int *)in.const_data_ptr <float >();
74- int * out_data = (int *)out.mutable_data_ptr <float >();
73+ if (dim == in.dim () - 1 ) {
74+ const float * p_inp = in.const_data_ptr <float >();
75+ float * out_data = out.mutable_data_ptr <float >();
7576
76- int num_inp_dims = in.dim ();
77- int num_out_dims = num_inp_dims;
77+ WORD32 ret_val = xa_nn_vec_softmax_f32_f32 (out_data, p_inp, size);
78+ ET_KERNEL_CHECK (ctx, ret_val == 0 , Internal, out);
79+ } else {
80+ int * p_inp = (int *)in.const_data_ptr <float >();
81+ int * out_data = (int *)out.mutable_data_ptr <float >();
82+ int num_inp_dims = in.dim ();
83+ int num_out_dims = num_inp_dims;
7884
79- int p_inp_shape[kNnlibMaxDim ];
80- int p_out_shape[kNnlibMaxDim ];
81- int p_permute_vec[kNnlibMaxDim ];
85+ int p_inp_shape[kNnlibMaxDim ];
86+ int p_out_shape[kNnlibMaxDim ];
87+ int p_permute_vec[kNnlibMaxDim ];
8288
83- for (int i = 0 ; i < num_inp_dims; i++)
84- p_inp_shape[i] = in.size (i);
89+ for (int i = 0 ; i < num_inp_dims; i++)
90+ p_inp_shape[i] = in.size (i);
8591
86- for (int i = 0 ; i < num_inp_dims; i++) {
87- if (i == d)
88- p_permute_vec[i] = num_inp_dims - 1 ;
89- else if (i == (num_inp_dims - 1 ))
90- p_permute_vec[num_inp_dims - 1 ] = d;
91- else
92- p_permute_vec[i] = i;
92+ for (int i = 0 ; i < num_inp_dims; i++) {
93+ if (i == d)
94+ p_permute_vec[i] = num_inp_dims - 1 ;
95+ else if (i == (num_inp_dims - 1 ))
96+ p_permute_vec[num_inp_dims - 1 ] = d;
97+ else
98+ p_permute_vec[i] = i;
9399
94- p_out_shape[i] = p_inp_shape[p_permute_vec[i]];
100+ p_out_shape[i] = p_inp_shape[p_permute_vec[i]];
95101
96- if (i != d)
97- outer_size = outer_size * p_inp_shape[i];
98- }
102+ if (i != d)
103+ outer_size = outer_size * p_inp_shape[i];
104+ }
99105
100- outer_stride = size;
106+ outer_stride = size;
101107
102- int * p_out =
103- (int *)kernels::allocate_temp_memory (ctx, out.numel () * sizeof (int ));
108+ int * p_out =
109+ (int *)kernels::allocate_temp_memory (ctx, out.numel () * sizeof (int ));
104110
105- ET_KERNEL_CHECK (ctx, p_out != nullptr , MemoryAllocationFailed, out);
111+ ET_KERNEL_CHECK (ctx, p_out != nullptr , MemoryAllocationFailed, out);
106112
107- int * p_out1 =
108- (int *)kernels::allocate_temp_memory (ctx, out.numel () * sizeof (int ));
113+ int * p_out1 =
114+ (int *)kernels::allocate_temp_memory (ctx, out.numel () * sizeof (int ));
109115
110- ET_KERNEL_CHECK (ctx, p_out1 != nullptr , MemoryAllocationFailed, out);
116+ ET_KERNEL_CHECK (ctx, p_out1 != nullptr , MemoryAllocationFailed, out);
111117
112- WORD32 ret_val = xa_nn_transpose_32_32 (
113- p_out,
114- p_out_shape,
115- p_inp,
116- p_inp_shape,
117- p_permute_vec,
118- num_out_dims,
119- num_inp_dims);
118+ WORD32 ret_val = xa_nn_transpose_32_32 (
119+ p_out,
120+ p_out_shape,
121+ p_inp,
122+ p_inp_shape,
123+ p_permute_vec,
124+ num_out_dims,
125+ num_inp_dims);
126+
127+ ET_KERNEL_CHECK (ctx, ret_val == 0 , Internal, out);
120128
121- ET_KERNEL_CHECK (ctx, ret_val == 0 , Internal, out);
129+ for (size_t outer_idx = 0 ; outer_idx < outer_size; ++outer_idx) {
130+ size_t outer = outer_idx * outer_stride;
131+ for (size_t inner_idx = 0 ; inner_idx < stride; ++inner_idx) {
132+ size_t base = outer + inner_idx;
122133
123- for (size_t outer_idx = 0 ; outer_idx < outer_size; ++outer_idx) {
124- size_t outer = outer_idx * outer_stride;
125- for (size_t inner_idx = 0 ; inner_idx < stride; ++inner_idx) {
126- size_t base = outer + inner_idx;
134+ float * p_in_data = (float *)&p_out[base];
135+ float * p_out_data = (float *)&p_out1[base];
127136
128- float * p_in_data = (float *)&p_out[base];
129- float * p_out_data = (float *)&p_out1[base];
137+ ret_val = xa_nn_vec_softmax_f32_f32 (p_out_data, p_in_data, size);
130138
131- ret_val = xa_nn_vec_softmax_f32_f32 (p_out_data, p_in_data, size);
139+ ET_KERNEL_CHECK (ctx, ret_val == 0 , Internal, out);
140+ }
141+ }
132142
143+ ret_val = xa_nn_transpose_32_32 (
144+ out_data,
145+ p_inp_shape,
146+ p_out1,
147+ p_out_shape,
148+ p_permute_vec,
149+ num_out_dims,
150+ num_inp_dims);
133151 ET_KERNEL_CHECK (ctx, ret_val == 0 , Internal, out);
134- }
135152 }
136-
137- ret_val = xa_nn_transpose_32_32 (
138- out_data,
139- p_inp_shape,
140- p_out1,
141- p_out_shape,
142- p_permute_vec,
143- num_out_dims,
144- num_inp_dims);
145-
146- ET_KERNEL_CHECK (ctx, ret_val == 0 , Internal, out);
147-
148153 return out;
149154 }
150155
0 commit comments