@@ -56,7 +56,23 @@ void SoftmaxKernel(const Context& dev_ctx,
      (custom_kernel::AclopSoftmaxKernel<T, Context>(dev_ctx, x, axis, out)));
  dev_ctx.template Alloc<T>(out);
  int64_t dim = static_cast<int64_t>(axis);
-  EXEC_NPU_CMD(aclnnSoftmax, dev_ctx, x, dim, *out);
+
+  phi::DenseTensor x_trans(x);
+  if (x.dims().size() == 0) {
+    phi::DenseTensorMeta meta_1 = {x.dtype(), phi::make_ddim({1})};
+    x_trans.set_meta(meta_1);
+  }
+
+  auto out_dims = out->dims();
+  if (out_dims.size() == 0) {
+    out->Resize(phi::make_ddim({1}));
+  }
+
+  EXEC_NPU_CMD(aclnnSoftmax, dev_ctx, x_trans, dim, *out);
+
+  if (out_dims.size() == 0) {
+    out->Resize(out_dims);
+  }
}

template <typename T, typename Context>
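The forward hunk above apparently works around aclnnSoftmax not handling 0-D (scalar) tensors directly: a shallow copy of the input has its metadata promoted to shape {1}, the output is temporarily resized to {1} for the op call, and its saved dims are restored afterwards. Below is a minimal sketch of the promotion step as a standalone helper, built only from the phi calls the hunk itself uses (dims, set_meta, make_ddim); the helper name PromoteRank0To1D and the include paths are illustrative assumptions, not part of the patch.

```cpp
// Illustrative sketch, not part of the patch. Include paths are assumptions
// about the phi layout; the calls mirror the hunk above.
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/dense_tensor.h"

// Returns a shallow copy of `t` whose metadata is promoted from rank 0 to
// shape {1}; storage is shared with `t`, only the meta is rewritten.
inline phi::DenseTensor PromoteRank0To1D(const phi::DenseTensor& t) {
  phi::DenseTensor view(t);
  if (t.dims().size() == 0) {
    phi::DenseTensorMeta meta = {t.dtype(), phi::make_ddim({1})};
    view.set_meta(meta);
  }
  return view;
}
```

For the output, the hunk instead resizes *out in place to {1} before the call and restores the saved out_dims afterwards, so the tensor handed to EXEC_NPU_CMD carries a 1-D shape for the duration of the op.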
@@ -148,31 +164,52 @@ void SoftmaxGradKernel(const Context& dev_ctx,
  dev_ctx.template Alloc<T>(x_grad);
  int64_t dim = static_cast<int64_t>(axis);

+  phi::DenseTensor x_trans(out_grad);
+  if (out_grad.dims().size() == 0) {
+    phi::DenseTensorMeta meta_1 = {out_grad.dtype(), phi::make_ddim({1})};
+    x_trans.set_meta(meta_1);
+  }
+
  phi::DenseTensor cast_x;
-  if (out_grad.dtype() == phi::DataType::FLOAT64) {
-    phi::DenseTensorMeta meta(out_grad.meta());
+  if (x_trans.dtype() == phi::DataType::FLOAT64) {
+    phi::DenseTensorMeta meta(x_trans.meta());
    meta.dtype = phi::DataType::FLOAT32;
    cast_x.set_meta(meta);

    custom_kernel::CastKernel<T, Context>(
-        dev_ctx, out_grad, phi::DataType::FLOAT32, &cast_x);
+        dev_ctx, x_trans, phi::DataType::FLOAT32, &cast_x);
  } else {
-    cast_x = out_grad;
+    cast_x = x_trans;
+  }
+
+  phi::DenseTensor y_trans(out);
+  if (out.dims().size() == 0) {
+    phi::DenseTensorMeta meta_1 = {out.dtype(), phi::make_ddim({1})};
+    y_trans.set_meta(meta_1);
  }

  phi::DenseTensor cast_y;
-  if (out.dtype() == phi::DataType::FLOAT64) {
-    phi::DenseTensorMeta meta(out.meta());
+  if (y_trans.dtype() == phi::DataType::FLOAT64) {
+    phi::DenseTensorMeta meta(y_trans.meta());
    meta.dtype = phi::DataType::FLOAT32;
    cast_y.set_meta(meta);

    custom_kernel::CastKernel<T, Context>(
-        dev_ctx, out, phi::DataType::FLOAT32, &cast_y);
+        dev_ctx, y_trans, phi::DataType::FLOAT32, &cast_y);
  } else {
-    cast_y = out;
+    cast_y = y_trans;
+  }
+
+  auto x_grad_dims = x_grad->dims();
+  if (x_grad_dims.size() == 0) {
+    x_grad->Resize(phi::make_ddim({1}));
  }

  EXEC_NPU_CMD(aclnnSoftmaxBackward, dev_ctx, cast_x, cast_y, dim, *x_grad);
+
+  if (x_grad_dims.size() == 0) {
+    x_grad->Resize(x_grad_dims);
+  }
}

} // namespace custom_kernel
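The backward hunk applies the same rank-0 promotion to out_grad and out (as x_trans and y_trans) ahead of the existing FLOAT64-to-FLOAT32 casts, and brackets the aclnnSoftmaxBackward call with the same save/Resize/restore bookkeeping on x_grad. The sketch below shows, purely as an illustration and not as part of the patch, how that temporary-resize step could be captured once in an RAII guard built only from calls the patch already uses (dims, Resize, make_ddim); the class name and include paths are assumptions.

```cpp
// Illustrative only: an RAII guard for the "resize a rank-0 tensor to {1},
// run the aclnn op, restore the original dims" pattern repeated above.
// Not part of the patch; include paths are assumptions.
#include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/dense_tensor.h"

class ScopedRank0Resize {
 public:
  explicit ScopedRank0Resize(phi::DenseTensor* t)
      : tensor_(t), saved_dims_(t->dims()) {
    if (saved_dims_.size() == 0) {
      tensor_->Resize(phi::make_ddim({1}));
    }
  }
  ~ScopedRank0Resize() {
    if (saved_dims_.size() == 0) {
      tensor_->Resize(saved_dims_);
    }
  }
  ScopedRank0Resize(const ScopedRank0Resize&) = delete;
  ScopedRank0Resize& operator=(const ScopedRank0Resize&) = delete;

 private:
  phi::DenseTensor* tensor_;
  phi::DDim saved_dims_;
};
```

With such a guard, the explicit out_dims and x_grad_dims bookkeeping around each EXEC_NPU_CMD call would collapse to constructing the guard on the output tensor; the patch keeps the explicit form.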