Skip to content

Commit 9817aad

Browse files
authored
[NPU]speed up bn and momentum (#1487)
1 parent 7a27667 commit 9817aad

File tree

2 files changed

+33
-18
lines changed

2 files changed

+33
-18
lines changed

backends/npu/kernels/batch_norm_kernel.cc

Lines changed: 26 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -32,6 +32,13 @@ void TransposeKernel(const Context& dev_ctx,
3232
const std::vector<int>& axis,
3333
phi::DenseTensor* out);
3434

35+
template <typename T, typename Context>
36+
void FullLikeKernel(const Context& dev_ctx,
37+
const phi::DenseTensor& x,
38+
const phi::Scalar& val,
39+
phi::DataType dtype,
40+
phi::DenseTensor* out);
41+
3542
template <typename T, typename Context>
3643
void AclopBatchNormKernel(const Context& dev_ctx,
3744
const phi::DenseTensor& x,
@@ -536,18 +543,25 @@ void BatchNormKernel(const Context& dev_ctx,
536543
aclnnInplaceAdd, dev_ctx, *variance_out, *saved_variance, momentum_p);
537544
auto stream = dev_ctx.stream();
538545

539-
const auto& adds_runner =
540-
NpuOpRunner("Adds",
541-
{*saved_variance},
542-
{*saved_variance},
543-
{{"value", static_cast<float>(epsilon)}});
544-
adds_runner.Run(stream);
545-
const auto& inv_runner =
546-
NpuOpRunner("Inv", {*saved_variance}, {*saved_variance}, {});
547-
inv_runner.Run(stream);
548-
const auto& sqrt_ruuner =
549-
NpuOpRunner("Sqrt", {*saved_variance}, {*saved_variance}, {});
550-
sqrt_ruuner.Run(stream);
546+
phi::Scalar one_scalar = static_cast<float>(1.0);
547+
548+
phi::DenseTensor epsilon_tensor;
549+
epsilon_tensor.set_meta(saved_variance->meta());
550+
custom_kernel::FullLikeKernel<T, Context>(
551+
dev_ctx,
552+
*saved_variance,
553+
phi::Scalar(static_cast<float>(epsilon)),
554+
saved_variance->dtype(),
555+
&epsilon_tensor);
556+
557+
EXEC_NPU_CMD(aclnnAdd,
558+
dev_ctx,
559+
*saved_variance,
560+
epsilon_tensor,
561+
one_scalar,
562+
*saved_variance);
563+
564+
EXEC_NPU_CMD(aclnnInplaceRsqrt, dev_ctx, *saved_variance);
551565
}
552566
}
553567

backends/npu/kernels/momentum_kernel.cc

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -48,12 +48,13 @@ void MomentumKernel(const Context& dev_ctx,
4848
regularized_grad.Resize(grad.dims());
4949
dev_ctx.template Alloc<T>(&regularized_grad);
5050

51-
const auto& runner1 = NpuOpRunner(
52-
"Muls", {param}, {regularized_grad}, {{"value", regularization_coeff}});
53-
runner1.Run(dev_ctx.stream());
54-
const auto& runner2 =
55-
NpuOpRunner("Add", {regularized_grad, grad}, {regularized_grad}, {});
56-
runner2.Run(dev_ctx.stream());
51+
phi::Scalar regularization_coeff_scalar = regularization_coeff;
52+
EXEC_NPU_CMD(aclnnAdd,
53+
dev_ctx,
54+
grad,
55+
param,
56+
regularization_coeff_scalar,
57+
regularized_grad);
5758
} else {
5859
regularized_grad = grad;
5960
}

0 commit comments

Comments
 (0)