Skip to content

Commit 4887335

Browse files
authored
[Bug Fix] Fix CastDataTypeFunctor for low-precision floats to complex types (PaddlePaddle#75934)
这些改动新增了 `CastDataTypeFunctor` 针对 `float8_e5m2`、`float8_e4m3fn`、`float16` 和 `bfloat16` 到 `complex<float>`/`complex<double>` 的专门实现。原本的泛型实现只做 `static_cast<OutType>(in)`,而 `phi::dtype::complex<>` 没有接受这些低精度实数类型的隐式转换,实际调用 `TransDataType` 做类型转换时会缺失对应路径或触发编译/运行错误。现在在转换时先显式转成 `float`/`double` 再构造复数,就能让这些低精度实数张量安全地转换成复数张量,确保在需要复数输出(例如调用复数 kernel)时类型升级可以顺利执行。
1 parent f0747d3 commit 4887335

File tree

1 file changed

+72
-0
lines changed

1 file changed

+72
-0
lines changed

paddle/fluid/framework/data_type_transform.cc

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,78 @@ struct CastDataTypeFunctor {
3232
}
3333
};
3434

35+
template <>
36+
struct CastDataTypeFunctor<::phi::dtype::float8_e5m2,
37+
::phi::dtype::complex<float>> {
38+
HOSTDEVICE ::phi::dtype::complex<float> operator()(
39+
::phi::dtype::float8_e5m2 in) const {
40+
return ::phi::dtype::complex<float>(static_cast<float>(in));
41+
}
42+
};
43+
44+
template <>
45+
struct CastDataTypeFunctor<::phi::dtype::float8_e5m2,
46+
::phi::dtype::complex<double>> {
47+
HOSTDEVICE ::phi::dtype::complex<double> operator()(
48+
::phi::dtype::float8_e5m2 in) const {
49+
return ::phi::dtype::complex<double>(static_cast<double>(in));
50+
}
51+
};
52+
53+
template <>
54+
struct CastDataTypeFunctor<::phi::dtype::float8_e4m3fn,
55+
::phi::dtype::complex<float>> {
56+
HOSTDEVICE ::phi::dtype::complex<float> operator()(
57+
::phi::dtype::float8_e4m3fn in) const {
58+
return ::phi::dtype::complex<float>(static_cast<float>(in));
59+
}
60+
};
61+
62+
template <>
63+
struct CastDataTypeFunctor<::phi::dtype::float8_e4m3fn,
64+
::phi::dtype::complex<double>> {
65+
HOSTDEVICE ::phi::dtype::complex<double> operator()(
66+
::phi::dtype::float8_e4m3fn in) const {
67+
return ::phi::dtype::complex<double>(static_cast<double>(in));
68+
}
69+
};
70+
71+
template <>
72+
struct CastDataTypeFunctor<::phi::dtype::float16,
73+
::phi::dtype::complex<float>> {
74+
HOSTDEVICE ::phi::dtype::complex<float> operator()(
75+
::phi::dtype::float16 in) const {
76+
return ::phi::dtype::complex<float>(static_cast<float>(in));
77+
}
78+
};
79+
80+
template <>
81+
struct CastDataTypeFunctor<::phi::dtype::float16,
82+
::phi::dtype::complex<double>> {
83+
HOSTDEVICE ::phi::dtype::complex<double> operator()(
84+
::phi::dtype::float16 in) const {
85+
return ::phi::dtype::complex<double>(static_cast<double>(in));
86+
}
87+
};
88+
89+
template <>
90+
struct CastDataTypeFunctor<::phi::dtype::bfloat16,
91+
::phi::dtype::complex<float>> {
92+
HOSTDEVICE ::phi::dtype::complex<float> operator()(
93+
::phi::dtype::bfloat16 in) const {
94+
return ::phi::dtype::complex<float>(static_cast<float>(in));
95+
}
96+
};
97+
98+
template <>
99+
struct CastDataTypeFunctor<::phi::dtype::bfloat16,
100+
::phi::dtype::complex<double>> {
101+
HOSTDEVICE ::phi::dtype::complex<double> operator()(
102+
::phi::dtype::bfloat16 in) const {
103+
return ::phi::dtype::complex<double>(static_cast<double>(in));
104+
}
105+
};
106+
35107
#if defined(PADDLE_WITH_XPU)
36108

37109
template <typename InType, typename OutType>

0 commit comments

Comments
 (0)