Skip to content

Commit 39c676e

Browse files
committed
initial commit
1 parent 3f5705c commit 39c676e

File tree

3 files changed

+6
-4
lines changed

3 files changed

+6
-4
lines changed

paddle/fluid/operators/batch_norm_op.cu.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -270,9 +270,9 @@ class BatchNormGradKernel<platform::CUDADeviceContext, T>
270270
} // namespace paddle
271271

272272
namespace ops = paddle::operators;
273+
namespace plat = paddle::platform;
273274
REGISTER_OP_CUDA_KERNEL(
274-
batch_norm,
275-
ops::BatchNormKernel<paddle::platform::CUDADeviceContext, float>);
275+
batch_norm, ops::BatchNormKernel<plat::CUDADeviceContext, float>,
276+
ops::BatchNormKernel<plat::CUDADeviceContext, plat::float16>);
276277
REGISTER_OP_CUDA_KERNEL(
277-
batch_norm_grad,
278-
ops::BatchNormGradKernel<paddle::platform::CUDADeviceContext, float>);
278+
batch_norm_grad, ops::BatchNormGradKernel<plat::CUDADeviceContext, float>);

paddle/fluid/operators/math/math_function.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,7 @@ void axpy<platform::CPUDeviceContext, double>(
278278
cblas_daxpy(n, alpha, x, 1, y, 1);
279279
}
280280

281+
template struct SetConstant<platform::CPUDeviceContext, platform::float16>;
281282
template struct SetConstant<platform::CPUDeviceContext, float>;
282283
template struct SetConstant<platform::CPUDeviceContext, double>;
283284
template struct SetConstant<platform::CPUDeviceContext, int>;

paddle/fluid/operators/math/math_function.cu

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,7 @@ void axpy<platform::CUDADeviceContext, double>(
348348
&alpha, x, 1, y, 1));
349349
}
350350

351+
template struct SetConstant<platform::CUDADeviceContext, platform::float16>;
351352
template struct SetConstant<platform::CUDADeviceContext, float>;
352353
template struct SetConstant<platform::CUDADeviceContext, double>;
353354
template struct SetConstant<platform::CUDADeviceContext, int>;

0 commit comments

Comments
 (0)