Skip to content

Commit 70b14ac

Browse files
authored
[Bug Fix] Fix CastKernel for low-precision to complex type conversions (PaddlePaddle#75930)
这些新增的特化让 CPU 版本的 `CastKernel` 在把 `float8_e5m2`、`float8_e4m3fn`、`bfloat16`、`float16` 等低精度类型转换成复数类型 (`complex64`/`complex128`) 时能够直接工作。之前模板的默认实现只会走 `static_cast<OutT>(in)`,对这些自定义浮点类型来说没有直达的构造函数到复数类型,会在编译期或运行期失败。现在通过先把它们显式转换成 `float` 或 `double` 来构造复数,补齐了这些 cast 组合,修复了 cast op 在上面这些输入/输出类型上的缺口。
1 parent 4887335 commit 70b14ac

File tree

1 file changed

+57
-0
lines changed

1 file changed

+57
-0
lines changed

paddle/phi/kernels/cpu/cast_impl.h

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,63 @@ struct CastOpTransformFunctor {
2525
HOSTDEVICE OutT operator()(InT in) const { return static_cast<OutT>(in); }
2626
};
2727

28+
template <>
29+
struct CastOpTransformFunctor<::phi::dtype::float8_e5m2, ::phi::complex64> {
30+
HOSTDEVICE ::phi::complex64 operator()(::phi::dtype::float8_e5m2 in) const {
31+
return ::phi::complex64(static_cast<float>(in));
32+
}
33+
};
34+
35+
template <>
36+
struct CastOpTransformFunctor<::phi::dtype::float8_e5m2, ::phi::complex128> {
37+
HOSTDEVICE ::phi::complex128 operator()(::phi::dtype::float8_e5m2 in) const {
38+
return ::phi::complex128(static_cast<double>(in));
39+
}
40+
};
41+
42+
template <>
43+
struct CastOpTransformFunctor<::phi::dtype::float8_e4m3fn, ::phi::complex64> {
44+
HOSTDEVICE ::phi::complex64 operator()(::phi::dtype::float8_e4m3fn in) const {
45+
return ::phi::complex64(static_cast<float>(in));
46+
}
47+
};
48+
49+
template <>
50+
struct CastOpTransformFunctor<::phi::dtype::float8_e4m3fn, ::phi::complex128> {
51+
HOSTDEVICE ::phi::complex128 operator()(
52+
::phi::dtype::float8_e4m3fn in) const {
53+
return ::phi::complex128(static_cast<double>(in));
54+
}
55+
};
56+
57+
template <>
58+
struct CastOpTransformFunctor<::phi::dtype::bfloat16, ::phi::complex64> {
59+
HOSTDEVICE ::phi::complex64 operator()(::phi::dtype::bfloat16 in) const {
60+
return ::phi::complex64(static_cast<float>(in));
61+
}
62+
};
63+
64+
template <>
65+
struct CastOpTransformFunctor<::phi::dtype::bfloat16, ::phi::complex128> {
66+
HOSTDEVICE ::phi::complex128 operator()(::phi::dtype::bfloat16 in) const {
67+
return ::phi::complex128(static_cast<double>(in));
68+
}
69+
};
70+
71+
template <>
72+
struct CastOpTransformFunctor<::phi::dtype::float16, ::phi::complex64> {
73+
HOSTDEVICE ::phi::complex64 operator()(::phi::dtype::float16 in) const {
74+
return ::phi::complex64(static_cast<float>(in));
75+
}
76+
};
77+
78+
template <>
79+
struct CastOpTransformFunctor<::phi::dtype::float16, ::phi::complex128> {
80+
HOSTDEVICE ::phi::complex128 operator()(::phi::dtype::float16 in) const {
81+
return ::phi::complex128(static_cast<double>(in));
82+
}
83+
};
84+
2885
template <typename InT, typename OutT>
2986
void CastKernelImpl(const CPUContext& dev_ctx,
3087
const DenseTensor& x,

0 commit comments

Comments
 (0)