Skip to content

Commit b4a5392

Browse files
AlexMaclean authored and Debadri Basak committed
[NVPTX] Add ex2.approx bf16 support and cleanup intrinsic definition (llvm#165446)
1 parent da92dcc commit b4a5392

File tree

10 files changed

+117
-43
lines changed

10 files changed

+117
-43
lines changed

clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp

Lines changed: 39 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -375,28 +375,28 @@ static Value *MakeCpAsync(unsigned IntrinsicID, unsigned IntrinsicIDS,
375375
CGF.EmitScalarExpr(E->getArg(1))});
376376
}
377377

378-
static Value *MakeHalfType(unsigned IntrinsicID, unsigned BuiltinID,
379-
const CallExpr *E, CodeGenFunction &CGF) {
378+
static bool EnsureNativeHalfSupport(unsigned BuiltinID, const CallExpr *E,
379+
CodeGenFunction &CGF) {
380380
auto &C = CGF.CGM.getContext();
381-
if (!(C.getLangOpts().NativeHalfType ||
382-
!C.getTargetInfo().useFP16ConversionIntrinsics())) {
381+
if (!C.getLangOpts().NativeHalfType &&
382+
C.getTargetInfo().useFP16ConversionIntrinsics()) {
383383
CGF.CGM.Error(E->getExprLoc(), C.BuiltinInfo.getQuotedName(BuiltinID) +
384384
" requires native half type support.");
385-
return nullptr;
385+
return false;
386386
}
387+
return true;
388+
}
387389

388-
if (BuiltinID == NVPTX::BI__nvvm_ldg_h || BuiltinID == NVPTX::BI__nvvm_ldg_h2)
389-
return MakeLdg(CGF, E);
390-
391-
if (IntrinsicID == Intrinsic::nvvm_ldu_global_f)
392-
return MakeLdu(IntrinsicID, CGF, E);
390+
static Value *MakeHalfType(Function *Intrinsic, unsigned BuiltinID,
391+
const CallExpr *E, CodeGenFunction &CGF) {
392+
if (!EnsureNativeHalfSupport(BuiltinID, E, CGF))
393+
return nullptr;
393394

394395
SmallVector<Value *, 16> Args;
395-
auto *F = CGF.CGM.getIntrinsic(IntrinsicID);
396-
auto *FTy = F->getFunctionType();
396+
auto *FTy = Intrinsic->getFunctionType();
397397
unsigned ICEArguments = 0;
398398
ASTContext::GetBuiltinTypeError Error;
399-
C.GetBuiltinType(BuiltinID, Error, &ICEArguments);
399+
CGF.CGM.getContext().GetBuiltinType(BuiltinID, Error, &ICEArguments);
400400
assert(Error == ASTContext::GE_None && "Should not codegen an error");
401401
for (unsigned i = 0, e = E->getNumArgs(); i != e; ++i) {
402402
assert((ICEArguments & (1 << i)) == 0);
@@ -407,8 +407,14 @@ static Value *MakeHalfType(unsigned IntrinsicID, unsigned BuiltinID,
407407
Args.push_back(ArgValue);
408408
}
409409

410-
return CGF.Builder.CreateCall(F, Args);
410+
return CGF.Builder.CreateCall(Intrinsic, Args);
411411
}
412+
413+
static Value *MakeHalfType(unsigned IntrinsicID, unsigned BuiltinID,
414+
const CallExpr *E, CodeGenFunction &CGF) {
415+
return MakeHalfType(CGF.CGM.getIntrinsic(IntrinsicID), BuiltinID, E, CGF);
416+
}
417+
412418
} // namespace
413419

414420
Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
@@ -913,9 +919,14 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
913919
}
914920
// The following builtins require half type support
915921
case NVPTX::BI__nvvm_ex2_approx_f16:
916-
return MakeHalfType(Intrinsic::nvvm_ex2_approx_f16, BuiltinID, E, *this);
922+
return MakeHalfType(
923+
CGM.getIntrinsic(Intrinsic::nvvm_ex2_approx, Builder.getHalfTy()),
924+
BuiltinID, E, *this);
917925
case NVPTX::BI__nvvm_ex2_approx_f16x2:
918-
return MakeHalfType(Intrinsic::nvvm_ex2_approx_f16x2, BuiltinID, E, *this);
926+
return MakeHalfType(
927+
CGM.getIntrinsic(Intrinsic::nvvm_ex2_approx,
928+
FixedVectorType::get(Builder.getHalfTy(), 2)),
929+
BuiltinID, E, *this);
919930
case NVPTX::BI__nvvm_ff2f16x2_rn:
920931
return MakeHalfType(Intrinsic::nvvm_ff2f16x2_rn, BuiltinID, E, *this);
921932
case NVPTX::BI__nvvm_ff2f16x2_rn_relu:
@@ -1049,12 +1060,22 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
10491060
case NVPTX::BI__nvvm_fabs_d:
10501061
return Builder.CreateUnaryIntrinsic(Intrinsic::fabs,
10511062
EmitScalarExpr(E->getArg(0)));
1063+
case NVPTX::BI__nvvm_ex2_approx_d:
1064+
case NVPTX::BI__nvvm_ex2_approx_f:
1065+
return Builder.CreateUnaryIntrinsic(Intrinsic::nvvm_ex2_approx,
1066+
EmitScalarExpr(E->getArg(0)));
1067+
case NVPTX::BI__nvvm_ex2_approx_ftz_f:
1068+
return Builder.CreateUnaryIntrinsic(Intrinsic::nvvm_ex2_approx_ftz,
1069+
EmitScalarExpr(E->getArg(0)));
10521070
case NVPTX::BI__nvvm_ldg_h:
10531071
case NVPTX::BI__nvvm_ldg_h2:
1054-
return MakeHalfType(Intrinsic::not_intrinsic, BuiltinID, E, *this);
1072+
return EnsureNativeHalfSupport(BuiltinID, E, *this) ? MakeLdg(*this, E)
1073+
: nullptr;
10551074
case NVPTX::BI__nvvm_ldu_h:
10561075
case NVPTX::BI__nvvm_ldu_h2:
1057-
return MakeHalfType(Intrinsic::nvvm_ldu_global_f, BuiltinID, E, *this);
1076+
return EnsureNativeHalfSupport(BuiltinID, E, *this)
1077+
? MakeLdu(Intrinsic::nvvm_ldu_global_f, *this, E)
1078+
: nullptr;
10581079
case NVPTX::BI__nvvm_cp_async_ca_shared_global_4:
10591080
return MakeCpAsync(Intrinsic::nvvm_cp_async_ca_shared_global_4,
10601081
Intrinsic::nvvm_cp_async_ca_shared_global_4_s, *this, E,

clang/test/CodeGen/builtins-nvptx-native-half-type-native.c

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
88
typedef __fp16 __fp16v2 __attribute__((ext_vector_type(2)));
99

1010
// CHECK: call half @llvm.nvvm.ex2.approx.f16(half {{.*}})
11-
// CHECK: call <2 x half> @llvm.nvvm.ex2.approx.f16x2(<2 x half> {{.*}})
11+
// CHECK: call <2 x half> @llvm.nvvm.ex2.approx.v2f16(<2 x half> {{.*}})
1212
// CHECK: call half @llvm.nvvm.fma.rn.relu.f16(half {{.*}}, half {{.*}}, half {{.*}})
1313
// CHECK: call half @llvm.nvvm.fma.rn.ftz.relu.f16(half {{.*}}, half {{.*}}, half {{.*}})
1414
// CHECK: call <2 x half> @llvm.nvvm.fma.rn.relu.f16x2(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}})

clang/test/CodeGen/builtins-nvptx-native-half-type.c

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -41,7 +41,7 @@ __device__ void nvvm_ex2_sm75() {
4141
#if __CUDA_ARCH__ >= 750
4242
// CHECK_PTX70_SM75: call half @llvm.nvvm.ex2.approx.f16
4343
__nvvm_ex2_approx_f16(0.1f16);
44-
// CHECK_PTX70_SM75: call <2 x half> @llvm.nvvm.ex2.approx.f16x2
44+
// CHECK_PTX70_SM75: call <2 x half> @llvm.nvvm.ex2.approx.v2f16
4545
__nvvm_ex2_approx_f16x2({0.1f16, 0.7f16});
4646
#endif
4747
// CHECK: ret void

llvm/include/llvm/IR/IntrinsicsNVVM.td

Lines changed: 2 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -1334,15 +1334,8 @@ let TargetPrefix = "nvvm" in {
13341334
//
13351335
let IntrProperties = [IntrNoMem] in {
13361336
foreach ftz = ["", "_ftz"] in
1337-
def int_nvvm_ex2_approx # ftz # _f : NVVMBuiltin,
1338-
DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty]>;
1339-
1340-
def int_nvvm_ex2_approx_d : NVVMBuiltin,
1341-
DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty]>;
1342-
def int_nvvm_ex2_approx_f16 :
1343-
DefaultAttrsIntrinsic<[llvm_half_ty], [llvm_half_ty]>;
1344-
def int_nvvm_ex2_approx_f16x2 :
1345-
DefaultAttrsIntrinsic<[llvm_v2f16_ty], [llvm_v2f16_ty]>;
1337+
def int_nvvm_ex2_approx # ftz :
1338+
DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
13461339

13471340
foreach ftz = ["", "_ftz"] in
13481341
def int_nvvm_lg2_approx # ftz # _f : NVVMBuiltin,

llvm/lib/IR/AutoUpgrade.cpp

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1504,6 +1504,10 @@ static bool upgradeIntrinsicFunction1(Function *F, Function *&NewFn,
15041504
else if (Name.consume_front("fabs."))
15051505
// nvvm.fabs.{f,ftz.f,d}
15061506
Expand = Name == "f" || Name == "ftz.f" || Name == "d";
1507+
else if (Name.consume_front("ex2.approx."))
1508+
// nvvm.ex2.approx.{f,ftz.f,d,f16x2}
1509+
Expand =
1510+
Name == "f" || Name == "ftz.f" || Name == "d" || Name == "f16x2";
15071511
else if (Name.consume_front("max.") || Name.consume_front("min."))
15081512
// nvvm.{min,max}.{i,ii,ui,ull}
15091513
Expand = Name == "s" || Name == "i" || Name == "ll" || Name == "us" ||
@@ -2550,6 +2554,11 @@ static Value *upgradeNVVMIntrinsicCall(StringRef Name, CallBase *CI,
25502554
Intrinsic::ID IID = (Name == "fabs.ftz.f") ? Intrinsic::nvvm_fabs_ftz
25512555
: Intrinsic::nvvm_fabs;
25522556
Rep = Builder.CreateUnaryIntrinsic(IID, CI->getArgOperand(0));
2557+
} else if (Name.consume_front("ex2.approx.")) {
2558+
// nvvm.ex2.approx.{f,ftz.f,d,f16x2}
2559+
Intrinsic::ID IID = Name.starts_with("ftz") ? Intrinsic::nvvm_ex2_approx_ftz
2560+
: Intrinsic::nvvm_ex2_approx;
2561+
Rep = Builder.CreateUnaryIntrinsic(IID, CI->getArgOperand(0));
25532562
} else if (Name.starts_with("atomic.load.add.f32.p") ||
25542563
Name.starts_with("atomic.load.add.f64.p")) {
25552564
Value *Ptr = CI->getArgOperand(0);

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1562,12 +1562,17 @@ def : Pat<(int_nvvm_saturate_d f64:$a), (CVT_f64_f64 $a, CvtSAT)>;
15621562
// Exp2 Log2
15631563
//
15641564

1565-
def : Pat<(int_nvvm_ex2_approx_ftz_f f32:$a), (EX2_APPROX_f32 $a, FTZ)>;
1566-
def : Pat<(int_nvvm_ex2_approx_f f32:$a), (EX2_APPROX_f32 $a, NoFTZ)>;
1565+
def : Pat<(f32 (int_nvvm_ex2_approx_ftz f32:$a)), (EX2_APPROX_f32 $a, FTZ)>;
1566+
def : Pat<(f32 (int_nvvm_ex2_approx f32:$a)), (EX2_APPROX_f32 $a, NoFTZ)>;
15671567

15681568
let Predicates = [hasPTX<70>, hasSM<75>] in {
1569-
def : Pat<(int_nvvm_ex2_approx_f16 f16:$a), (EX2_APPROX_f16 $a)>;
1570-
def : Pat<(int_nvvm_ex2_approx_f16x2 v2f16:$a), (EX2_APPROX_f16x2 $a)>;
1569+
def : Pat<(f16 (int_nvvm_ex2_approx f16:$a)), (EX2_APPROX_f16 $a)>;
1570+
def : Pat<(v2f16 (int_nvvm_ex2_approx v2f16:$a)), (EX2_APPROX_f16x2 $a)>;
1571+
}
1572+
1573+
let Predicates = [hasPTX<78>, hasSM<90>] in {
1574+
def : Pat<(bf16 (int_nvvm_ex2_approx_ftz bf16:$a)), (EX2_APPROX_bf16 $a)>;
1575+
def : Pat<(v2bf16 (int_nvvm_ex2_approx_ftz v2bf16:$a)), (EX2_APPROX_bf16x2 $a)>;
15711576
}
15721577

15731578
def LG2_APPROX_f32 :

llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -318,7 +318,7 @@ static Instruction *convertNvvmIntrinsicToLlvm(InstCombiner &IC,
318318
// answer. These include:
319319
//
320320
// - nvvm_cos_approx_{f,ftz_f}
321-
// - nvvm_ex2_approx_{d,f,ftz_f}
321+
// - nvvm_ex2_approx(_ftz)
322322
// - nvvm_lg2_approx_{d,f,ftz_f}
323323
// - nvvm_sin_approx_{f,ftz_f}
324324
// - nvvm_sqrt_approx_{f,ftz_f}

llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -87,6 +87,11 @@ declare void @llvm.nvvm.barrier(i32, i32)
8787
declare void @llvm.nvvm.barrier.sync(i32)
8888
declare void @llvm.nvvm.barrier.sync.cnt(i32, i32)
8989

90+
declare float @llvm.nvvm.ex2.approx.f(float)
91+
declare double @llvm.nvvm.ex2.approx.d(double)
92+
declare <2 x half> @llvm.nvvm.ex2.approx.f16x2(<2 x half>)
93+
declare float @llvm.nvvm.ex2.approx.ftz.f(float)
94+
9095
; CHECK-LABEL: @simple_upgrade
9196
define void @simple_upgrade(i32 %a, i64 %b, i16 %c) {
9297
; CHECK: call i32 @llvm.bitreverse.i32(i32 %a)
@@ -355,3 +360,15 @@ define void @cta_barriers(i32 %x, i32 %y) {
355360
call void @llvm.nvvm.barrier.sync.cnt(i32 %x, i32 %y)
356361
ret void
357362
}
363+
364+
define void @nvvm_ex2_approx(float %a, double %b, half %c, <2 x half> %d) {
365+
; CHECK: call float @llvm.nvvm.ex2.approx.f32(float %a)
366+
; CHECK: call double @llvm.nvvm.ex2.approx.f64(double %b)
367+
; CHECK: call <2 x half> @llvm.nvvm.ex2.approx.v2f16(<2 x half> %d)
368+
; CHECK: call float @llvm.nvvm.ex2.approx.ftz.f32(float %a)
369+
%r1 = call float @llvm.nvvm.ex2.approx.f(float %a)
370+
%r2 = call double @llvm.nvvm.ex2.approx.d(double %b)
371+
%r3 = call <2 x half> @llvm.nvvm.ex2.approx.f16x2(<2 x half> %d)
372+
%r4 = call float @llvm.nvvm.ex2.approx.ftz.f(float %a)
373+
ret void
374+
}

llvm/test/CodeGen/NVPTX/f16-ex2.ll

Lines changed: 34 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,13 @@
11
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2-
; RUN: llc < %s -mcpu=sm_75 -mattr=+ptx70 | FileCheck --check-prefixes=CHECK-FP16 %s
3-
; RUN: %if ptxas-sm_75 && ptxas-isa-7.0 %{ llc < %s -mcpu=sm_75 -mattr=+ptx70 | %ptxas-verify -arch=sm_75 %}
2+
; RUN: llc < %s -mcpu=sm_90 -mattr=+ptx78 | FileCheck --check-prefixes=CHECK-FP16 %s
3+
; RUN: %if ptxas-sm_90 && ptxas-isa-7.8 %{ llc < %s -mcpu=sm_90 -mattr=+ptx78 | %ptxas-verify -arch=sm_90 %}
44
target triple = "nvptx64-nvidia-cuda"
55

66
declare half @llvm.nvvm.ex2.approx.f16(half)
7-
declare <2 x half> @llvm.nvvm.ex2.approx.f16x2(<2 x half>)
7+
declare <2 x half> @llvm.nvvm.ex2.approx.v2f16(<2 x half>)
8+
declare bfloat @llvm.nvvm.ex2.approx.ftz.bf16(bfloat)
9+
declare <2 x bfloat> @llvm.nvvm.ex2.approx.ftz.v2bf16(<2 x bfloat>)
810

9-
; CHECK-LABEL: ex2_half
1011
define half @ex2_half(half %0) {
1112
; CHECK-FP16-LABEL: ex2_half(
1213
; CHECK-FP16: {
@@ -21,7 +22,6 @@ define half @ex2_half(half %0) {
2122
ret half %res
2223
}
2324

24-
; CHECK-LABEL: ex2_2xhalf
2525
define <2 x half> @ex2_2xhalf(<2 x half> %0) {
2626
; CHECK-FP16-LABEL: ex2_2xhalf(
2727
; CHECK-FP16: {
@@ -32,6 +32,34 @@ define <2 x half> @ex2_2xhalf(<2 x half> %0) {
3232
; CHECK-FP16-NEXT: ex2.approx.f16x2 %r2, %r1;
3333
; CHECK-FP16-NEXT: st.param.b32 [func_retval0], %r2;
3434
; CHECK-FP16-NEXT: ret;
35-
%res = call <2 x half> @llvm.nvvm.ex2.approx.f16x2(<2 x half> %0)
35+
%res = call <2 x half> @llvm.nvvm.ex2.approx.v2f16(<2 x half> %0)
3636
ret <2 x half> %res
3737
}
38+
39+
define bfloat @ex2_bfloat(bfloat %0) {
40+
; CHECK-FP16-LABEL: ex2_bfloat(
41+
; CHECK-FP16: {
42+
; CHECK-FP16-NEXT: .reg .b16 %rs<3>;
43+
; CHECK-FP16-EMPTY:
44+
; CHECK-FP16-NEXT: // %bb.0:
45+
; CHECK-FP16-NEXT: ld.param.b16 %rs1, [ex2_bfloat_param_0];
46+
; CHECK-FP16-NEXT: ex2.approx.ftz.bf16 %rs2, %rs1;
47+
; CHECK-FP16-NEXT: st.param.b16 [func_retval0], %rs2;
48+
; CHECK-FP16-NEXT: ret;
49+
%res = call bfloat @llvm.nvvm.ex2.approx.ftz.bf16(bfloat %0)
50+
ret bfloat %res
51+
}
52+
53+
define <2 x bfloat> @ex2_2xbfloat(<2 x bfloat> %0) {
54+
; CHECK-FP16-LABEL: ex2_2xbfloat(
55+
; CHECK-FP16: {
56+
; CHECK-FP16-NEXT: .reg .b32 %r<3>;
57+
; CHECK-FP16-EMPTY:
58+
; CHECK-FP16-NEXT: // %bb.0:
59+
; CHECK-FP16-NEXT: ld.param.b32 %r1, [ex2_2xbfloat_param_0];
60+
; CHECK-FP16-NEXT: ex2.approx.ftz.bf16x2 %r2, %r1;
61+
; CHECK-FP16-NEXT: st.param.b32 [func_retval0], %r2;
62+
; CHECK-FP16-NEXT: ret;
63+
%res = call <2 x bfloat> @llvm.nvvm.ex2.approx.ftz.v2bf16(<2 x bfloat> %0)
64+
ret <2 x bfloat> %res
65+
}

llvm/test/CodeGen/NVPTX/f32-ex2.ll

Lines changed: 4 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,8 @@
33
; RUN: %if ptxas-sm_50 && ptxas-isa-3.2 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_50 -mattr=+ptx32 | %ptxas-verify -arch=sm_50 %}
44
target triple = "nvptx-nvidia-cuda"
55

6-
declare float @llvm.nvvm.ex2.approx.f(float)
6+
declare float @llvm.nvvm.ex2.approx.f32(float)
7+
declare float @llvm.nvvm.ex2.approx.ftz.f32(float)
78

89
; CHECK-LABEL: ex2_float
910
define float @ex2_float(float %0) {
@@ -16,7 +17,7 @@ define float @ex2_float(float %0) {
1617
; CHECK-NEXT: ex2.approx.f32 %r2, %r1;
1718
; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
1819
; CHECK-NEXT: ret;
19-
%res = call float @llvm.nvvm.ex2.approx.f(float %0)
20+
%res = call float @llvm.nvvm.ex2.approx.f32(float %0)
2021
ret float %res
2122
}
2223

@@ -31,6 +32,6 @@ define float @ex2_float_ftz(float %0) {
3132
; CHECK-NEXT: ex2.approx.ftz.f32 %r2, %r1;
3233
; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
3334
; CHECK-NEXT: ret;
34-
%res = call float @llvm.nvvm.ex2.approx.ftz.f(float %0)
35+
%res = call float @llvm.nvvm.ex2.approx.ftz.f32(float %0)
3536
ret float %res
3637
}

0 commit comments

Comments
 (0)