Skip to content

Commit 4aff56e

Browse files
mjulian31mahesh-attarde
authored andcommitted
[LLVM][NVPTX] Upstream tanh intrinsic for libdevice (llvm#149596)
Currently __nv_fast_tanhf() in libdevice maps to an nvvm intrinsic that has not been upstreamed, which is causing issues when using the NVPTX backend from upstream. Instead of upstreaming the intrinsic, we can instead use the existing Intrinsic::tanh with the afn flag. This change adds NVPTX backend support for ISD::TANH, adds auto-upgrade for the old tanh_approx intrinsic to @llvm.tanh.f32 with afn flag so that libdevice works properly upstream, and adds a basic codegen test and a case to the auto-upgrade test.
1 parent 5db0584 commit 4aff56e

File tree

5 files changed

+66
-3
lines changed

5 files changed

+66
-3
lines changed

llvm/lib/IR/AutoUpgrade.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1450,6 +1450,7 @@ static bool upgradeIntrinsicFunction1(Function *F, Function *&NewFn,
14501450
.Case("popc.ll", true)
14511451
.Case("h2f", true)
14521452
.Case("swap.lo.hi.b64", true)
1453+
.Case("tanh.approx.f32", true)
14531454
.Default(false);
14541455

14551456
if (Expand) {
@@ -2543,6 +2544,12 @@ static Value *upgradeNVVMIntrinsicCall(StringRef Name, CallBase *CI,
25432544
MDNode *MD = MDNode::get(Builder.getContext(), {});
25442545
LD->setMetadata(LLVMContext::MD_invariant_load, MD);
25452546
return LD;
2547+
} else if (Name == "tanh.approx.f32") {
2548+
// nvvm.tanh.approx.f32 -> afn llvm.tanh.f32
2549+
FastMathFlags FMF;
2550+
FMF.setApproxFunc();
2551+
Rep = Builder.CreateUnaryIntrinsic(Intrinsic::tanh, CI->getArgOperand(0),
2552+
FMF);
25462553
} else if (Name == "barrier0" || Name == "barrier.n" || Name == "bar.sync") {
25472554
Value *Arg =
25482555
Name.ends_with('0') ? Builder.getInt32(0) : CI->getArgOperand(0);

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -952,10 +952,13 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
952952
// promoted to f32. v2f16 is expanded to f16, which is then promoted
953953
// to f32.
954954
for (const auto &Op :
955-
{ISD::FDIV, ISD::FREM, ISD::FSQRT, ISD::FSIN, ISD::FCOS}) {
955+
{ISD::FDIV, ISD::FREM, ISD::FSQRT, ISD::FSIN, ISD::FCOS, ISD::FTANH}) {
956956
setOperationAction(Op, MVT::f16, Promote);
957957
setOperationAction(Op, MVT::f32, Legal);
958-
setOperationAction(Op, MVT::f64, Legal);
958+
// only div/rem/sqrt are legal for f64
959+
if (Op == ISD::FDIV || Op == ISD::FREM || Op == ISD::FSQRT) {
960+
setOperationAction(Op, MVT::f64, Legal);
961+
}
959962
setOperationAction(Op, {MVT::v2f16, MVT::v2bf16, MVT::v2f32}, Expand);
960963
setOperationAction(Op, MVT::bf16, Promote);
961964
AddPromotedToType(Op, MVT::bf16, MVT::f32);

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1234,7 +1234,7 @@ defm FMA_F32 : FMA<F32RT, allow_ftz = true>;
12341234
defm FMA_F32x2 : FMA<F32X2RT, allow_ftz = true, preds = [hasF32x2Instructions]>;
12351235
defm FMA_F64 : FMA<F64RT, allow_ftz = false>;
12361236

1237-
// sin/cos
1237+
// sin/cos/tanh
12381238

12391239
class UnaryOpAllowsApproxFn<SDPatternOperator operator>
12401240
: PatFrag<(ops node:$A),
@@ -1250,6 +1250,10 @@ def COS_APPROX_f32 :
12501250
BasicFlagsNVPTXInst<(outs B32:$dst), (ins B32:$src), (ins FTZFlag:$ftz),
12511251
"cos.approx$ftz.f32",
12521252
[(set f32:$dst, (UnaryOpAllowsApproxFn<fcos> f32:$src))]>;
1253+
def TANH_APPROX_f32 :
1254+
BasicNVPTXInst<(outs B32:$dst), (ins B32:$src), "tanh.approx.f32",
1255+
[(set f32:$dst, (UnaryOpAllowsApproxFn<ftanh> f32:$src))]>,
1256+
Requires<[hasPTX<70>, hasSM<75>]>;
12531257

12541258
//-----------------------------------
12551259
// Bitwise operations

llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ declare float @llvm.nvvm.fabs.f(float)
1717
declare float @llvm.nvvm.fabs.ftz.f(float)
1818
declare double @llvm.nvvm.fabs.d(double)
1919

20+
declare float @llvm.nvvm.tanh.approx.f32(float)
21+
2022
declare i16 @llvm.nvvm.max.s(i16, i16)
2123
declare i32 @llvm.nvvm.max.i(i32, i32)
2224
declare i64 @llvm.nvvm.max.ll(i64, i64)
@@ -138,6 +140,13 @@ define void @fabs(float %a, double %b) {
138140
ret void
139141
}
140142

143+
; CHECK-LABEL: @tanh
144+
define void @tanh(float %a) {
145+
; CHECK: call afn float @llvm.tanh.f32(float %a)
146+
%r1 = call float @llvm.nvvm.tanh.approx.f32(float %a)
147+
ret void
148+
}
149+
141150
; CHECK-LABEL: @min_max
142151
define void @min_max(i16 %a1, i16 %a2, i32 %b1, i32 %b2, i64 %c1, i64 %c2) {
143152
; CHECK: [[maxs:%[a-zA-Z0-9.]+]] = icmp sge i16 %a1, %a2

llvm/test/CodeGen/NVPTX/tanhf.ll

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc < %s -mcpu=sm_75 -mattr=+ptx70 | FileCheck %s
3+
; RUN: %if ptxas-11.0 %{ llc < %s -mcpu=sm_75 -mattr=+ptx70 | %ptxas-verify -arch=sm_75 %}
4+
5+
target triple = "nvptx64-nvidia-cuda"
6+
7+
define float @test1(float %in) local_unnamed_addr {
8+
; CHECK-LABEL: test1(
9+
; CHECK: {
10+
; CHECK-NEXT: .reg .b32 %r<3>;
11+
; CHECK-EMPTY:
12+
; CHECK-NEXT: // %bb.0:
13+
; CHECK-NEXT: ld.param.b32 %r1, [test1_param_0];
14+
; CHECK-NEXT: tanh.approx.f32 %r2, %r1;
15+
; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
16+
; CHECK-NEXT: ret;
17+
%call = call afn float @llvm.tanh.f32(float %in)
18+
ret float %call
19+
}
20+
21+
define half @test2(half %in) local_unnamed_addr {
22+
; CHECK-LABEL: test2(
23+
; CHECK: {
24+
; CHECK-NEXT: .reg .b16 %rs<3>;
25+
; CHECK-NEXT: .reg .b32 %r<3>;
26+
; CHECK-EMPTY:
27+
; CHECK-NEXT: // %bb.0:
28+
; CHECK-NEXT: ld.param.b16 %rs1, [test2_param_0];
29+
; CHECK-NEXT: cvt.f32.f16 %r1, %rs1;
30+
; CHECK-NEXT: tanh.approx.f32 %r2, %r1;
31+
; CHECK-NEXT: cvt.rn.f16.f32 %rs2, %r2;
32+
; CHECK-NEXT: st.param.b16 [func_retval0], %rs2;
33+
; CHECK-NEXT: ret;
34+
%call = call afn half @llvm.tanh.f16(half %in)
35+
ret half %call
36+
}
37+
38+
declare float @llvm.tanh.f32(float)
39+
declare half @llvm.tanh.f16(half)
40+

0 commit comments

Comments
 (0)