Skip to content

Commit 9bf079a

Browse files
committed
[Backport to 19] Update OCL nan builtin translation to work with bfloat type (#3558)
To avoid name clash between builtins with `half` and `bfloat` return types, introduce return type postfix for `bfloat nan` builtin in reverse translation to SPIR-V Friendly IR.
1 parent fd1ed03 commit 9bf079a

File tree

2 files changed

+43
-1
lines changed

2 files changed

+43
-1
lines changed

lib/SPIRV/SPIRVUtil.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2617,7 +2617,7 @@ class OpenCLStdToSPIRVFriendlyIRMangleInfo : public BuiltinFuncMangleInfo {
26172617
public:
26182618
OpenCLStdToSPIRVFriendlyIRMangleInfo(OCLExtOpKind ExtOpId,
26192619
ArrayRef<Type *> ArgTys, Type *RetTy)
2620-
: ExtOpId(ExtOpId), ArgTys(ArgTys) {
2620+
: ExtOpId(ExtOpId), ArgTys(ArgTys), RetTy(RetTy) {
26212621

26222622
std::string Postfix = "";
26232623
if (needRetTypePostfix())
@@ -2633,6 +2633,11 @@ class OpenCLStdToSPIRVFriendlyIRMangleInfo : public BuiltinFuncMangleInfo {
26332633
case OpenCLLIB::Vloada_halfn:
26342634
case OpenCLLIB::Vloadn:
26352635
return true;
2636+
case OpenCLLIB::Nan:
2637+
// Only add return type mangling for bfloat16 to disambiguate from half
2638+
// (both are represented as i16 in LLVM). Float and half use traditional
2639+
// naming for backward compatibility.
2640+
return RetTy->getScalarType()->isBFloatTy();
26362641
default:
26372642
return false;
26382643
}
@@ -2685,6 +2690,7 @@ class OpenCLStdToSPIRVFriendlyIRMangleInfo : public BuiltinFuncMangleInfo {
26852690
private:
26862691
OCLExtOpKind ExtOpId;
26872692
ArrayRef<Type *> ArgTys;
2693+
Type *RetTy;
26882694
};
26892695
} // namespace
26902696

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
; RUN: llvm-as %s -o %t.bc
2+
; RUN: llvm-spirv %t.bc --spirv-ext=+SPV_KHR_bfloat16 -o %t.spv
3+
; RUN: llvm-spirv %t.spv -to-text -o %t.spt
4+
; RUN: FileCheck < %t.spt %s --check-prefix=CHECK-SPIRV
5+
6+
; RUN: llvm-spirv -r %t.spv --spirv-target-env=SPV-IR -o - | llvm-dis -o %t.rev.ll
7+
; RUN: FileCheck < %t.rev.ll %s --check-prefixes=CHECK-SPV-IR
8+
9+
; Check OpenCL built-in nan translation.
10+
; Verify it's possible to distinguish between bfloat and half versions.
11+
12+
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64-G1"
13+
target triple = "spir64"
14+
15+
; CHECK-SPIRV: TypeFloat [[#BFLOAT:]] 16 0 {{$}}
16+
; CHECK-SPIRV: TypeFloat [[#HALF:]] 16 {{$}}
17+
; CHECK-SPIRV: ExtInst [[#BFLOAT]] [[#]] [[#]] nan
18+
; CHECK-SPIRV: ExtInst [[#HALF]] [[#]] [[#]] nan
19+
20+
; CHECK-SPV-IR: call spir_func bfloat @_Z22__spirv_ocl_nan_RDF16bt(
21+
; CHECK-SPV-IR: call spir_func half @_Z15__spirv_ocl_nant(
22+
23+
define dso_local spir_kernel void @test_bfloat(ptr addrspace(1) align 2 %a, i16 %b) {
24+
entry:
25+
%call = tail call spir_func bfloat @_Z23__spirv_ocl_nan__RDF16bt(i16 %b)
26+
%call2 = tail call spir_func half @_Z22__spirv_ocl_nan__Rhalft(i16 %b)
27+
ret void
28+
}
29+
30+
declare spir_func bfloat @_Z23__spirv_ocl_nan__RDF16bt(i16)
31+
declare spir_func half @_Z22__spirv_ocl_nan__Rhalft(i16)
32+
33+
34+
!opencl.ocl.version = !{!0}
35+
36+
!0 = !{i32 3, i32 0}

0 commit comments

Comments
 (0)