@@ -32,6 +32,13 @@ void TransposeKernel(const Context& dev_ctx,
32
32
const std::vector<int >& axis,
33
33
phi::DenseTensor* out);
34
34
35
+ template <typename T, typename Context>
36
+ void FullLikeKernel (const Context& dev_ctx,
37
+ const phi::DenseTensor& x,
38
+ const phi::Scalar& val,
39
+ phi::DataType dtype,
40
+ phi::DenseTensor* out);
41
+
35
42
template <typename T, typename Context>
36
43
void AclopBatchNormKernel (const Context& dev_ctx,
37
44
const phi::DenseTensor& x,
@@ -536,18 +543,25 @@ void BatchNormKernel(const Context& dev_ctx,
536
543
aclnnInplaceAdd, dev_ctx, *variance_out, *saved_variance, momentum_p);
537
544
auto stream = dev_ctx.stream ();
538
545
539
- const auto & adds_runner =
540
- NpuOpRunner (" Adds" ,
541
- {*saved_variance},
542
- {*saved_variance},
543
- {{" value" , static_cast <float >(epsilon)}});
544
- adds_runner.Run (stream);
545
- const auto & inv_runner =
546
- NpuOpRunner (" Inv" , {*saved_variance}, {*saved_variance}, {});
547
- inv_runner.Run (stream);
548
- const auto & sqrt_ruuner =
549
- NpuOpRunner (" Sqrt" , {*saved_variance}, {*saved_variance}, {});
550
- sqrt_ruuner.Run (stream);
546
+ phi::Scalar one_scalar = static_cast <float >(1.0 );
547
+
548
+ phi::DenseTensor epsilon_tensor;
549
+ epsilon_tensor.set_meta (saved_variance->meta ());
550
+ custom_kernel::FullLikeKernel<T, Context>(
551
+ dev_ctx,
552
+ *saved_variance,
553
+ phi::Scalar (static_cast <float >(epsilon)),
554
+ saved_variance->dtype (),
555
+ &epsilon_tensor);
556
+
557
+ EXEC_NPU_CMD (aclnnAdd,
558
+ dev_ctx,
559
+ *saved_variance,
560
+ epsilon_tensor,
561
+ one_scalar,
562
+ *saved_variance);
563
+
564
+ EXEC_NPU_CMD (aclnnInplaceRsqrt, dev_ctx, *saved_variance);
551
565
}
552
566
}
553
567
0 commit comments