Skip to content

Commit f29f693

Browse files
authored
[Bug Fix] Support isfinite/isnan/isinf for float16/bfloat16 on CUDA/HIP (PaddlePaddle#75933)
- 在 isfinite_kernel_impl.h 的 GPU 侧 `Isfinite/Isnan/Isinf` 核函数里,把 “通用浮点” 模板拆成两支:一支只接受标准 `float/double`,另一支专门匹配 `phi::float16` 和 `phi::bfloat16`。这避免了 `std::is_floating_point` 对这两种自定义半精度类型返回 `false` 而导致完全没有匹配内核的情况,从而补齐了半精度在 CUDA/HIP 上的 `isfinite/isnan/isinf` 支持。 - 由于有了独立分支,调用的仍是对应的 `isfinite/isnan/isinf` 设备实现,逻辑保持一致,但现在 `float16/bfloat16` 会正确走到实际内核里,不再出现链接缺符号或运行时报 “未注册该数据类型” 的问题。 - 去掉三个模板 `IsfiniteKernel/IsinfKernel/IsnanKernel` 的 `PADDLE_API` 修饰,避免在头文件模板定义上做符号导出,引起重复导出或 Windows 下的装饰冲突。
1 parent fbe99bc commit f29f693

File tree

1 file changed

+54
-6
lines changed

1 file changed

+54
-6
lines changed

paddle/phi/kernels/impl/isfinite_kernel_impl.h

Lines changed: 54 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -301,7 +301,23 @@ __global__ void IsfiniteCUDAKernel(
301301
const T* in_data,
302302
IndexType num,
303303
bool* out_data,
304-
typename std::enable_if<std::is_floating_point<T>::value>::type* = 0) {
304+
typename std::enable_if<std::is_floating_point<T>::value &&
305+
!std::is_same<T, phi::bfloat16>::value &&
306+
!std::is_same<T, phi::float16>::value>::type* = 0) {
307+
IndexType idx = threadIdx.x + blockIdx.x * blockDim.x;
308+
for (IndexType i = idx; i < num; i += blockDim.x * gridDim.x) {
309+
const T& a = in_data[i];
310+
out_data[i] = isfinite(a);
311+
}
312+
}
313+
314+
template <typename T, typename IndexType>
315+
__global__ void IsfiniteCUDAKernel(
316+
const T* in_data,
317+
IndexType num,
318+
bool* out_data,
319+
typename std::enable_if<std::is_same<T, phi::bfloat16>::value ||
320+
std::is_same<T, phi::float16>::value>::type* = 0) {
305321
IndexType idx = threadIdx.x + blockIdx.x * blockDim.x;
306322
for (IndexType i = idx; i < num; i += blockDim.x * gridDim.x) {
307323
const T& a = in_data[i];
@@ -340,7 +356,23 @@ __global__ void IsnanCUDAKernel(
340356
const T* in_data,
341357
IndexType num,
342358
bool* out_data,
343-
typename std::enable_if<std::is_floating_point<T>::value>::type* = 0) {
359+
typename std::enable_if<std::is_floating_point<T>::value &&
360+
!std::is_same<T, phi::bfloat16>::value &&
361+
!std::is_same<T, phi::float16>::value>::type* = 0) {
362+
IndexType idx = threadIdx.x + blockIdx.x * blockDim.x;
363+
for (IndexType i = idx; i < num; i += blockDim.x * gridDim.x) {
364+
const T& a = in_data[i];
365+
out_data[i] = isnan(a);
366+
}
367+
}
368+
369+
template <typename T, typename IndexType>
370+
__global__ void IsnanCUDAKernel(
371+
const T* in_data,
372+
IndexType num,
373+
bool* out_data,
374+
typename std::enable_if<std::is_same<T, phi::bfloat16>::value ||
375+
std::is_same<T, phi::float16>::value>::type* = 0) {
344376
IndexType idx = threadIdx.x + blockIdx.x * blockDim.x;
345377
for (IndexType i = idx; i < num; i += blockDim.x * gridDim.x) {
346378
const T& a = in_data[i];
@@ -379,7 +411,23 @@ __global__ void IsinfCUDAKernel(
379411
const T* in_data,
380412
IndexType num,
381413
bool* out_data,
382-
typename std::enable_if<std::is_floating_point<T>::value>::type* = 0) {
414+
typename std::enable_if<std::is_floating_point<T>::value &&
415+
!std::is_same<T, phi::bfloat16>::value &&
416+
!std::is_same<T, phi::float16>::value>::type* = 0) {
417+
IndexType idx = threadIdx.x + blockIdx.x * blockDim.x;
418+
for (IndexType i = idx; i < num; i += blockDim.x * gridDim.x) {
419+
const T& a = in_data[i];
420+
out_data[i] = isinf(a);
421+
}
422+
}
423+
424+
template <typename T, typename IndexType>
425+
__global__ void IsinfCUDAKernel(
426+
const T* in_data,
427+
IndexType num,
428+
bool* out_data,
429+
typename std::enable_if<std::is_same<T, phi::bfloat16>::value ||
430+
std::is_same<T, phi::float16>::value>::type* = 0) {
383431
IndexType idx = threadIdx.x + blockIdx.x * blockDim.x;
384432
for (IndexType i = idx; i < num; i += blockDim.x * gridDim.x) {
385433
const T& a = in_data[i];
@@ -477,9 +525,9 @@ struct IsinfFunctor<phi::GPUContext, T> {
477525
#endif
478526

479527
template <typename T, typename Context>
480-
PADDLE_API void IsfiniteKernel(const Context& dev_ctx,
481-
const DenseTensor& x,
482-
DenseTensor* out) {
528+
void IsfiniteKernel(const Context& dev_ctx,
529+
const DenseTensor& x,
530+
DenseTensor* out) {
483531
if (out && out->numel() == 0) {
484532
dev_ctx.template Alloc<bool>(out);
485533
return;

0 commit comments

Comments (0)