Skip to content

Commit f599ce7

Browse files
authored
[MLU] Fix reduce_mean bug (#1354)
1 parent 62340bd commit f599ce7

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

backends/mlu/kernels/reduce_mean_kernel.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ void MeanRawKernel(const Context& dev_ctx,
2323
bool keep_dim,
2424
bool reduce_all,
2525
phi::DenseTensor* out) {
26+
if (x.dims().size() == 0) {
27+
TensorCopy(dev_ctx, x, false, out);
28+
return;
29+
}
2630
MLUReduceOp<T>(
2731
dev_ctx, x, axes.GetData(), keep_dim, reduce_all, "reduce_mean", out);
2832
}
@@ -47,6 +51,11 @@ void MeanGradKernel(const Context& dev_ctx,
4751
phi::DenseTensor* x_grad) {
4852
dev_ctx.template Alloc<T>(x_grad);
4953

54+
if (x.dims().size() == 0) {
55+
TensorCopy(dev_ctx, out_grad, false, x_grad);
56+
return;
57+
}
58+
5059
auto reduce_dims = axes.GetData();
5160
auto input_dims = phi::vectorize(x.dims());
5261

0 commit comments

Comments
 (0)