@@ -525,31 +525,16 @@ void BatchNormKernel(const Context& dev_ctx,
525
525
if (training) {
526
526
// CANN mean_out/var_out and paddlepaddle-cpu mean_out/var_out are
527
527
// defferent.
528
+ phi::Scalar momentum_f = static_cast <float >(momentum);
529
+ phi::Scalar momentum_p = static_cast <float >(1 - momentum);
530
+ EXEC_NPU_CMD (aclnnMuls, dev_ctx, tmp_running_mean, momentum_f, *mean_out);
531
+ EXEC_NPU_CMD (aclnnInplaceAdd, dev_ctx, *mean_out, *saved_mean, momentum_p);
532
+
533
+ EXEC_NPU_CMD (
534
+ aclnnMuls, dev_ctx, tmp_running_var, momentum_f, *variance_out);
535
+ EXEC_NPU_CMD (
536
+ aclnnInplaceAdd, dev_ctx, *variance_out, *saved_variance, momentum_p);
528
537
auto stream = dev_ctx.stream ();
529
- const auto & mean_muls_runner =
530
- NpuOpRunner (" Muls" ,
531
- {tmp_running_mean},
532
- {*mean_out},
533
- {{" value" , static_cast <float >(momentum)}});
534
- mean_muls_runner.Run (stream);
535
- const auto & mean_axpy_runner =
536
- NpuOpRunner (" Axpy" ,
537
- {*mean_out, *saved_mean},
538
- {*mean_out},
539
- {{" alpha" , static_cast <float >(1 - momentum)}});
540
- mean_axpy_runner.Run (stream);
541
- const auto & var_muls_runner =
542
- NpuOpRunner (" Muls" ,
543
- {tmp_running_var},
544
- {*variance_out},
545
- {{" value" , static_cast <float >(momentum)}});
546
- var_muls_runner.Run (stream);
547
- const auto & var_axpy_runner =
548
- NpuOpRunner (" Axpy" ,
549
- {*variance_out, *saved_variance},
550
- {*variance_out},
551
- {{" alpha" , static_cast <float >(1 - momentum)}});
552
- var_axpy_runner.Run (stream);
553
538
554
539
const auto & adds_runner =
555
540
NpuOpRunner (" Adds" ,
0 commit comments