Skip to content

Commit 2b9441b

Browse files
author
Hugh Delaney
committed
Add patterns for fma.relu.{f16|bf16}
Add patterns to lower fma(a, b, c) > 0 ? fma(a, b, c) : 0 for f16 and bf16 types.
1 parent 05b6c2e commit 2b9441b

File tree

2 files changed

+96
-0
lines changed

2 files changed

+96
-0
lines changed

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3917,3 +3917,22 @@ def atomic_thread_fence_seq_cst_cta :
39173917
def atomic_thread_fence_acq_rel_cta :
39183918
NVPTXInst<(outs), (ins), "fence.acq_rel.cta;", []>,
39193919
Requires<[hasPTX<60>, hasSM<70>]>;
3920+
3921+
def fpimm0 : FPImmLeaf<fAny, [{
3922+
return Imm.isExactlyValue(+0.0);
3923+
}]>;
3924+
3925+
def FMARELU_F16 :
3926+
NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a, Int16Regs:$b, Int16Regs:$c),
3927+
"fma.rn.relu.f16 \t$dst, $a, $b, $c;", []>;
3928+
def FMARELU_BF16 :
3929+
NVPTXInst<(outs Int16Regs:$dst), (ins Int16Regs:$a, Int16Regs:$b, Int16Regs:$c),
3930+
"fma.rn.relu.bf16 \t$dst, $a, $b, $c;", []>;
3931+
3932+
def : Pat<(f16 (fmaxnum (fma Int16Regs:$a, Int16Regs:$b, Int16Regs:$c), fpimm0)),
3933+
(FMARELU_F16 Int16Regs:$a, Int16Regs:$b, Int16Regs:$c)>,
3934+
Requires<[useFP16Math, allowFMA, allowUnsafeFPMath, hasPTX<60>, hasSM<70>]>;
3935+
3936+
def : Pat<(bf16 (fmaxnum (fma Int16Regs:$a, Int16Regs:$b, Int16Regs:$c), fpimm0)),
3937+
(FMARELU_BF16 Int16Regs:$a, Int16Regs:$b, Int16Regs:$c)>,
3938+
Requires<[hasBF16Math, allowFMA, allowUnsafeFPMath, hasPTX<60>, hasSM<70>]>;
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc < %s -march=nvptx64 --enable-unsafe-fp-math -mcpu=sm_80 -mattr=ptx70 -verify-machineinstrs -fp-contract=fast -nvptx-fma-level=2 | FileCheck %s
3+
; RUN: %if ptxas %{ llc < %s -march=nvptx64 -mcpu=sm_80 -mattr=ptx70 -verify-machineinstrs -fp-contract=fast -nvptx-fma-level=2 | %ptxas-verify -arch=sm_80 %}
4+
5+
define half @fma_f16(half %a, half %b, half %c) {
6+
; CHECK-LABEL: fma_f16(
7+
; CHECK: {
8+
; CHECK-NEXT: .reg .b16 %rs<5>;
9+
; CHECK-EMPTY:
10+
; CHECK-NEXT: // %bb.0:
11+
; CHECK-NEXT: ld.param.b16 %rs1, [fma_f16_param_0];
12+
; CHECK-NEXT: ld.param.b16 %rs2, [fma_f16_param_1];
13+
; CHECK-NEXT: ld.param.b16 %rs3, [fma_f16_param_2];
14+
; CHECK-NEXT: fma.rn.relu %rs4, %rs1, %rs2, %rs3;
15+
; CHECK-NEXT: st.param.b16 [func_retval0], %rs4;
16+
; CHECK-NEXT: ret;
17+
%1 = call half @llvm.fma.f16(half %a, half %b, half %c)
18+
%2 = fcmp ogt half %1, 0.0
19+
%3 = select i1 %2, half %1, half 0.0
20+
ret half %3
21+
}
22+
23+
define half @fma_f16_expanded(half %a, half %b, half %c) {
24+
; CHECK-LABEL: fma_f16_expanded(
25+
; CHECK: {
26+
; CHECK-NEXT: .reg .b16 %rs<5>;
27+
; CHECK-EMPTY:
28+
; CHECK-NEXT: // %bb.0:
29+
; CHECK-NEXT: ld.param.b16 %rs1, [fma_f16_expanded_param_0];
30+
; CHECK-NEXT: ld.param.b16 %rs2, [fma_f16_expanded_param_1];
31+
; CHECK-NEXT: ld.param.b16 %rs3, [fma_f16_expanded_param_2];
32+
; CHECK-NEXT: fma.rn.relu %rs4, %rs1, %rs2, %rs3;
33+
; CHECK-NEXT: st.param.b16 [func_retval0], %rs4;
34+
; CHECK-NEXT: ret;
35+
%1 = fmul half %a, %b
36+
%2 = fadd half %1, %c
37+
%3 = fcmp ogt half %2, 0.0
38+
%4 = select i1 %3, half %2, half 0.0
39+
ret half %4
40+
}
41+
42+
define bfloat @fma_bf16(bfloat %a, bfloat %b, bfloat %c) {
43+
; CHECK-LABEL: fma_bf16(
44+
; CHECK: {
45+
; CHECK-NEXT: .reg .b16 %rs<5>;
46+
; CHECK-EMPTY:
47+
; CHECK-NEXT: // %bb.0:
48+
; CHECK-NEXT: ld.param.b16 %rs1, [fma_bf16_param_0];
49+
; CHECK-NEXT: ld.param.b16 %rs2, [fma_bf16_param_1];
50+
; CHECK-NEXT: ld.param.b16 %rs3, [fma_bf16_param_2];
51+
; CHECK-NEXT: fma.rn.relu %rs4, %rs1, %rs2, %rs3;
52+
; CHECK-NEXT: st.param.b16 [func_retval0], %rs4;
53+
; CHECK-NEXT: ret;
54+
%1 = call bfloat @llvm.fma.bf16(bfloat %a, bfloat %b, bfloat %c)
55+
%2 = fcmp ogt bfloat %1, 0.0
56+
%3 = select i1 %2, bfloat %1, bfloat 0.0
57+
ret bfloat %3
58+
}
59+
60+
define bfloat @fma_bf16_expanded(bfloat %a, bfloat %b, bfloat %c) {
61+
; CHECK-LABEL: fma_bf16_expanded(
62+
; CHECK: {
63+
; CHECK-NEXT: .reg .b16 %rs<5>;
64+
; CHECK-EMPTY:
65+
; CHECK-NEXT: // %bb.0:
66+
; CHECK-NEXT: ld.param.b16 %rs1, [fma_bf16_expanded_param_0];
67+
; CHECK-NEXT: ld.param.b16 %rs2, [fma_bf16_expanded_param_1];
68+
; CHECK-NEXT: ld.param.b16 %rs3, [fma_bf16_expanded_param_2];
69+
; CHECK-NEXT: fma.rn.relu %rs4, %rs1, %rs2, %rs3;
70+
; CHECK-NEXT: st.param.b16 [func_retval0], %rs4;
71+
; CHECK-NEXT: ret;
72+
%1 = fmul bfloat %a, %b
73+
%2 = fadd bfloat %1, %c
74+
%3 = fcmp ogt bfloat %2, 0.0
75+
%4 = select i1 %3, bfloat %2, bfloat 0.0
76+
ret bfloat %4
77+
}

0 commit comments

Comments
 (0)