Skip to content

Commit a9e8b29

Browse files
register complex64 for remainder op (PaddlePaddle#2215)
1 parent 40130a5 commit a9e8b29

File tree

1 file changed

+21
-20
lines changed

1 file changed

+21
-20
lines changed

backends/iluvatar_gpu/kernels/cuda_kernels/elementwise_kernel_register.cc

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,19 @@
1515
#include "paddle/phi/core/kernel_registry.h"
1616
#include "paddle/phi/kernels/kps/elementwise_kernel.cu" // NOLINT
1717

18+
using float16 = phi::dtype::float16;
19+
using bfloat16 = phi::dtype::bfloat16;
20+
using complex64 = ::phi::dtype::complex<float>;
21+
1822
PD_CUSTOM_KERNEL_REGISTER(maximum,
1923
iluvatar_gpu,
2024
ALL_LAYOUT,
2125
phi::MaximumKernel,
2226
float,
2327
int,
2428
int64_t,
25-
phi::dtype::float16,
26-
phi::dtype::bfloat16) {}
29+
float16,
30+
bfloat16) {}
2731

2832
PD_CUSTOM_KERNEL_REGISTER(minimum,
2933
iluvatar_gpu,
@@ -32,8 +36,8 @@ PD_CUSTOM_KERNEL_REGISTER(minimum,
3236
float,
3337
int,
3438
int64_t,
35-
phi::dtype::float16,
36-
phi::dtype::bfloat16) {}
39+
float16,
40+
bfloat16) {}
3741

3842
PD_CUSTOM_KERNEL_REGISTER(remainder,
3943
iluvatar_gpu,
@@ -42,8 +46,9 @@ PD_CUSTOM_KERNEL_REGISTER(remainder,
4246
float,
4347
int,
4448
int64_t,
45-
phi::dtype::float16,
46-
phi::dtype::bfloat16) {}
49+
float16,
50+
bfloat16,
51+
complex64) {}
4752

4853
PD_CUSTOM_KERNEL_REGISTER(floor_divide,
4954
iluvatar_gpu,
@@ -55,8 +60,8 @@ PD_CUSTOM_KERNEL_REGISTER(floor_divide,
5560
int,
5661
int64_t,
5762
float,
58-
phi::dtype::float16,
59-
phi::dtype::bfloat16) {}
63+
float16,
64+
bfloat16) {}
6065

6166
PD_CUSTOM_KERNEL_REGISTER(elementwise_pow,
6267
iluvatar_gpu,
@@ -65,8 +70,8 @@ PD_CUSTOM_KERNEL_REGISTER(elementwise_pow,
6570
float,
6671
int,
6772
int64_t,
68-
phi::dtype::float16,
69-
phi::dtype::bfloat16) {}
73+
float16,
74+
bfloat16) {}
7075

7176
PD_CUSTOM_KERNEL_REGISTER(copysign,
7277
iluvatar_gpu,
@@ -79,12 +84,8 @@ PD_CUSTOM_KERNEL_REGISTER(copysign,
7984
int,
8085
int64_t,
8186
float,
82-
phi::dtype::float16,
83-
phi::dtype::bfloat16) {}
84-
85-
using float16 = phi::dtype::float16;
86-
using bfloat16 = phi::dtype::bfloat16;
87-
using complex64 = ::phi::dtype::complex<float>;
87+
float16,
88+
bfloat16) {}
8889

8990
PD_CUSTOM_KERNEL_REGISTER(fmax,
9091
iluvatar_gpu,
@@ -127,8 +128,8 @@ PD_CUSTOM_KERNEL_REGISTER(add,
127128
uint8_t,
128129
int8_t,
129130
int64_t,
130-
phi::dtype::float16,
131-
phi::dtype::bfloat16,
131+
float16,
132+
bfloat16,
132133
complex64) {}
133134

134135
PD_CUSTOM_KERNEL_REGISTER(grad_add,
@@ -142,8 +143,8 @@ PD_CUSTOM_KERNEL_REGISTER(grad_add,
142143
uint8_t,
143144
int8_t,
144145
int64_t,
145-
phi::dtype::float16,
146-
phi::dtype::bfloat16,
146+
float16,
147+
bfloat16,
147148
complex64) {}
148149

149150
PD_CUSTOM_KERNEL_REGISTER(divide,

0 commit comments

Comments
 (0)