Skip to content

Commit 01671f4

Browse files
authored
[NPU] Add bf16 register (#1403)
1 parent a7f726a commit 01671f4

File tree

4 files changed

+15
-7
lines changed

4 files changed

+15
-7
lines changed

backends/npu/kernels/activation_kernel.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1916,14 +1916,16 @@ PD_REGISTER_PLUGIN_KERNEL(gelu,
19161916
ALL_LAYOUT,
19171917
custom_kernel::GeluKernel,
19181918
float,
1919-
phi::dtype::float16) {}
1919+
phi::dtype::float16,
1920+
phi::dtype::bfloat16) {}
19201921

19211922
PD_REGISTER_PLUGIN_KERNEL(gelu_grad,
19221923
npu,
19231924
ALL_LAYOUT,
19241925
custom_kernel::GeluGradKernel,
19251926
float,
1926-
phi::dtype::float16) {}
1927+
phi::dtype::float16,
1928+
phi::dtype::bfloat16) {}
19271929

19281930
PD_REGISTER_PLUGIN_KERNEL(tanh,
19291931
npu,
@@ -1945,6 +1947,7 @@ PD_REGISTER_PLUGIN_KERNEL(sigmoid,
19451947
custom_kernel::SigmoidKernel,
19461948
float,
19471949
phi::dtype::float16,
1950+
phi::dtype::bfloat16,
19481951
double) {}
19491952

19501953
PD_REGISTER_PLUGIN_KERNEL(sigmoid_grad,
@@ -1953,6 +1956,7 @@ PD_REGISTER_PLUGIN_KERNEL(sigmoid_grad,
19531956
custom_kernel::SigmoidGradKernel,
19541957
float,
19551958
phi::dtype::float16,
1959+
phi::dtype::bfloat16,
19561960
double) {}
19571961

19581962
PD_REGISTER_PLUGIN_KERNEL(sqrt,

backends/npu/kernels/conv2d_kernel.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -492,11 +492,13 @@ PD_REGISTER_PLUGIN_KERNEL(conv2d,
492492
ALL_LAYOUT,
493493
custom_kernel::Conv2dKernel,
494494
float,
495-
phi::dtype::float16) {}
495+
phi::dtype::float16,
496+
phi::dtype::bfloat16) {}
496497

497498
PD_REGISTER_PLUGIN_KERNEL(conv2d_grad,
498499
npu,
499500
ALL_LAYOUT,
500501
custom_kernel::Conv2DGradKernel,
501502
float,
502-
phi::dtype::float16) {}
503+
phi::dtype::float16,
504+
phi::dtype::bfloat16) {}

backends/npu/kernels/grid_sample_kernel.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,7 @@ void GridSampleGradKernel(const Context& dev_ctx,
301301
if (find_it != paddlingModeMap.end()) {
302302
paddle_mode_to_int = find_it->second;
303303
}
304-
EXEC_NPU_CMD(aclnnGridSamplerd2DBackward,
304+
EXEC_NPU_CMD(aclnnGridSampler2DBackward,
305305
dev_ctx,
306306
out_grad,
307307
x,

backends/npu/kernels/scatter_kernel.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,8 @@ PD_REGISTER_PLUGIN_KERNEL(scatter,
150150
float,
151151
int64_t,
152152
int,
153-
phi::dtype::float16) {}
153+
phi::dtype::float16,
154+
phi::dtype::bfloat16) {}
154155

155156
PD_REGISTER_PLUGIN_KERNEL(scatter_grad,
156157
npu,
@@ -159,4 +160,5 @@ PD_REGISTER_PLUGIN_KERNEL(scatter_grad,
159160
float,
160161
int64_t,
161162
int,
162-
phi::dtype::float16) {}
163+
phi::dtype::float16,
164+
phi::dtype::bfloat16) {}

0 commit comments

Comments
 (0)