@@ -32,6 +32,78 @@ struct CastDataTypeFunctor {
3232 }
3333};
3434
35+ template <>
36+ struct CastDataTypeFunctor <::phi::dtype::float8_e5m2,
37+ ::phi::dtype::complex <float >> {
38+ HOSTDEVICE ::phi::dtype::complex <float > operator ()(
39+ ::phi::dtype::float8_e5m2 in) const {
40+ return ::phi::dtype::complex <float >(static_cast <float >(in));
41+ }
42+ };
43+
44+ template <>
45+ struct CastDataTypeFunctor <::phi::dtype::float8_e5m2,
46+ ::phi::dtype::complex <double >> {
47+ HOSTDEVICE ::phi::dtype::complex <double > operator ()(
48+ ::phi::dtype::float8_e5m2 in) const {
49+ return ::phi::dtype::complex <double >(static_cast <double >(in));
50+ }
51+ };
52+
53+ template <>
54+ struct CastDataTypeFunctor <::phi::dtype::float8_e4m3fn,
55+ ::phi::dtype::complex <float >> {
56+ HOSTDEVICE ::phi::dtype::complex <float > operator ()(
57+ ::phi::dtype::float8_e4m3fn in) const {
58+ return ::phi::dtype::complex <float >(static_cast <float >(in));
59+ }
60+ };
61+
62+ template <>
63+ struct CastDataTypeFunctor <::phi::dtype::float8_e4m3fn,
64+ ::phi::dtype::complex <double >> {
65+ HOSTDEVICE ::phi::dtype::complex <double > operator ()(
66+ ::phi::dtype::float8_e4m3fn in) const {
67+ return ::phi::dtype::complex <double >(static_cast <double >(in));
68+ }
69+ };
70+
71+ template <>
72+ struct CastDataTypeFunctor <::phi::dtype::float16,
73+ ::phi::dtype::complex <float >> {
74+ HOSTDEVICE ::phi::dtype::complex <float > operator ()(
75+ ::phi::dtype::float16 in) const {
76+ return ::phi::dtype::complex <float >(static_cast <float >(in));
77+ }
78+ };
79+
80+ template <>
81+ struct CastDataTypeFunctor <::phi::dtype::float16,
82+ ::phi::dtype::complex <double >> {
83+ HOSTDEVICE ::phi::dtype::complex <double > operator ()(
84+ ::phi::dtype::float16 in) const {
85+ return ::phi::dtype::complex <double >(static_cast <double >(in));
86+ }
87+ };
88+
89+ template <>
90+ struct CastDataTypeFunctor <::phi::dtype::bfloat16,
91+ ::phi::dtype::complex <float >> {
92+ HOSTDEVICE ::phi::dtype::complex <float > operator ()(
93+ ::phi::dtype::bfloat16 in) const {
94+ return ::phi::dtype::complex <float >(static_cast <float >(in));
95+ }
96+ };
97+
98+ template <>
99+ struct CastDataTypeFunctor <::phi::dtype::bfloat16,
100+ ::phi::dtype::complex <double >> {
101+ HOSTDEVICE ::phi::dtype::complex <double > operator ()(
102+ ::phi::dtype::bfloat16 in) const {
103+ return ::phi::dtype::complex <double >(static_cast <double >(in));
104+ }
105+ };
106+
35107#if defined(PADDLE_WITH_XPU)
36108
37109template <typename InType, typename OutType>
0 commit comments