@@ -25,6 +25,63 @@ struct CastOpTransformFunctor {
2525 HOSTDEVICE OutT operator ()(InT in) const { return static_cast <OutT>(in); }
2626};
2727
28+ template <>
29+ struct CastOpTransformFunctor <::phi::dtype::float8_e5m2, ::phi::complex64> {
30+ HOSTDEVICE ::phi::complex64 operator ()(::phi::dtype::float8_e5m2 in) const {
31+ return ::phi::complex64 (static_cast <float >(in));
32+ }
33+ };
34+
35+ template <>
36+ struct CastOpTransformFunctor <::phi::dtype::float8_e5m2, ::phi::complex128> {
37+ HOSTDEVICE ::phi::complex128 operator ()(::phi::dtype::float8_e5m2 in) const {
38+ return ::phi::complex128 (static_cast <double >(in));
39+ }
40+ };
41+
42+ template <>
43+ struct CastOpTransformFunctor <::phi::dtype::float8_e4m3fn, ::phi::complex64> {
44+ HOSTDEVICE ::phi::complex64 operator ()(::phi::dtype::float8_e4m3fn in) const {
45+ return ::phi::complex64 (static_cast <float >(in));
46+ }
47+ };
48+
49+ template <>
50+ struct CastOpTransformFunctor <::phi::dtype::float8_e4m3fn, ::phi::complex128> {
51+ HOSTDEVICE ::phi::complex128 operator ()(
52+ ::phi::dtype::float8_e4m3fn in) const {
53+ return ::phi::complex128 (static_cast <double >(in));
54+ }
55+ };
56+
57+ template <>
58+ struct CastOpTransformFunctor <::phi::dtype::bfloat16, ::phi::complex64> {
59+ HOSTDEVICE ::phi::complex64 operator ()(::phi::dtype::bfloat16 in) const {
60+ return ::phi::complex64 (static_cast <float >(in));
61+ }
62+ };
63+
64+ template <>
65+ struct CastOpTransformFunctor <::phi::dtype::bfloat16, ::phi::complex128> {
66+ HOSTDEVICE ::phi::complex128 operator ()(::phi::dtype::bfloat16 in) const {
67+ return ::phi::complex128 (static_cast <double >(in));
68+ }
69+ };
70+
71+ template <>
72+ struct CastOpTransformFunctor <::phi::dtype::float16, ::phi::complex64> {
73+ HOSTDEVICE ::phi::complex64 operator ()(::phi::dtype::float16 in) const {
74+ return ::phi::complex64 (static_cast <float >(in));
75+ }
76+ };
77+
78+ template <>
79+ struct CastOpTransformFunctor <::phi::dtype::float16, ::phi::complex128> {
80+ HOSTDEVICE ::phi::complex128 operator ()(::phi::dtype::float16 in) const {
81+ return ::phi::complex128 (static_cast <double >(in));
82+ }
83+ };
84+
2885template <typename InT, typename OutT>
2986void CastKernelImpl (const CPUContext& dev_ctx,
3087 const DenseTensor& x,
0 commit comments