Skip to content

Commit 9a4b990

Browse files
[NPU] fix BN question (#1340)
1 parent 0148a3d commit 9a4b990

File tree

2 files changed

+9
-25
lines changed

2 files changed

+9
-25
lines changed

backends/npu/kernels/batch_norm_kernel.cc

Lines changed: 9 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -525,31 +525,16 @@ void BatchNormKernel(const Context& dev_ctx,
525525
if (training) {
526526
// CANN mean_out/var_out and paddlepaddle-cpu mean_out/var_out are
527527
// defferent.
528+
phi::Scalar momentum_f = static_cast<float>(momentum);
529+
phi::Scalar momentum_p = static_cast<float>(1 - momentum);
530+
EXEC_NPU_CMD(aclnnMuls, dev_ctx, tmp_running_mean, momentum_f, *mean_out);
531+
EXEC_NPU_CMD(aclnnInplaceAdd, dev_ctx, *mean_out, *saved_mean, momentum_p);
532+
533+
EXEC_NPU_CMD(
534+
aclnnMuls, dev_ctx, tmp_running_var, momentum_f, *variance_out);
535+
EXEC_NPU_CMD(
536+
aclnnInplaceAdd, dev_ctx, *variance_out, *saved_variance, momentum_p);
528537
auto stream = dev_ctx.stream();
529-
const auto& mean_muls_runner =
530-
NpuOpRunner("Muls",
531-
{tmp_running_mean},
532-
{*mean_out},
533-
{{"value", static_cast<float>(momentum)}});
534-
mean_muls_runner.Run(stream);
535-
const auto& mean_axpy_runner =
536-
NpuOpRunner("Axpy",
537-
{*mean_out, *saved_mean},
538-
{*mean_out},
539-
{{"alpha", static_cast<float>(1 - momentum)}});
540-
mean_axpy_runner.Run(stream);
541-
const auto& var_muls_runner =
542-
NpuOpRunner("Muls",
543-
{tmp_running_var},
544-
{*variance_out},
545-
{{"value", static_cast<float>(momentum)}});
546-
var_muls_runner.Run(stream);
547-
const auto& var_axpy_runner =
548-
NpuOpRunner("Axpy",
549-
{*variance_out, *saved_variance},
550-
{*variance_out},
551-
{{"alpha", static_cast<float>(1 - momentum)}});
552-
var_axpy_runner.Run(stream);
553538

554539
const auto& adds_runner =
555540
NpuOpRunner("Adds",

backends/npu/tools/disable_ut_npu_910b

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
disable_ut_npu
2-
test_batch_norm_op_npu
32
test_check_nan_inf_op_npu
43
test_conv3d_op_npu
54
test_contiguous_op_npu

0 commit comments

Comments
 (0)