Skip to content

Commit f5eea93

Browse files
author
Hugh Delaney
committed
Add patterns for fma.relu.{f16|bf16}
Add patterns to lower `fma(a, b, c) > 0 ? fma(a, b, c) : 0` to the PTX `fma.rn.relu` instruction for the f16 and bf16 types.
1 parent 05b6c2e commit f5eea93

File tree

2 files changed

+957
-0
lines changed

2 files changed

+957
-0
lines changed

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3917,3 +3917,40 @@ def atomic_thread_fence_seq_cst_cta :
39173917
// Operand-less instruction that prints "fence.acq_rel.cta;".
// Predicated on hasPTX<60> and hasSM<70>.
let Predicates = [hasPTX<60>, hasSM<70>] in
def atomic_thread_fence_acq_rel_cta :
  NVPTXInst<(outs), (ins), "fence.acq_rel.cta;", []>;
3920+
3921+
// Floating-point immediate leaf (any FP type) that matches only the value
// +0.0; negative zero is rejected.
def fpimm0 : FPImmLeaf<fAny, [{ return Imm.isPosZero(); }]>;
3924+
3925+
// f16 fused multiply-add with ReLU: prints "fma.rn.relu.f16" with one 16-bit
// register result and three 16-bit register inputs.
// Predicated on useFP16Math, hasPTX<70> and hasSM<80>.
let Predicates = [useFP16Math, hasPTX<70>, hasSM<80>] in
def FMARELU_F16 :
  NVPTXInst<(outs Int16Regs:$dst),
            (ins Int16Regs:$a, Int16Regs:$b, Int16Regs:$c),
            "fma.rn.relu.f16 \t$dst, $a, $b, $c;", []>;
3929+
// bf16 fused multiply-add with ReLU: prints "fma.rn.relu.bf16" with one
// 16-bit register result and three 16-bit register inputs.
// Predicated on hasBF16Math, hasPTX<70> and hasSM<80>.
let Predicates = [hasBF16Math, hasPTX<70>, hasSM<80>] in
def FMARELU_BF16 :
  NVPTXInst<(outs Int16Regs:$dst),
            (ins Int16Regs:$a, Int16Regs:$b, Int16Regs:$c),
            "fma.rn.relu.bf16 \t$dst, $a, $b, $c;", []>;
3933+
// f16 fused multiply-add with ReLU, flush-to-zero form.
// The PTX grammar for this instruction is fma.rnd{.ftz}.relu.f16, i.e. the
// .ftz modifier must come before .relu; "fma.rn.relu.ftz.f16" is rejected by
// the PTX assembler.
// Predicated on useFP16Math, hasPTX<70> and hasSM<80>.
def FMARELU_F16_FTZ :
  NVPTXInst<(outs Int16Regs:$dst),
            (ins Int16Regs:$a, Int16Regs:$b, Int16Regs:$c),
            "fma.rn.ftz.relu.f16 \t$dst, $a, $b, $c;", []>,
  Requires<[useFP16Math, hasPTX<70>, hasSM<80>]>;
3937+
// bf16 fused multiply-add with ReLU, selected when FTZ mode is requested.
// PTX defines no .ftz modifier for bf16 fma (the grammar is
// fma.rnd{.relu}.bf16), so "fma.rn.relu.ftz.bf16" is not valid PTX. The
// record name is kept for existing users, but it prints the plain bf16 form.
// Predicated on hasBF16Math, hasPTX<70> and hasSM<80>.
def FMARELU_BF16_FTZ :
  NVPTXInst<(outs Int16Regs:$dst),
            (ins Int16Regs:$a, Int16Regs:$b, Int16Regs:$c),
            "fma.rn.relu.bf16 \t$dst, $a, $b, $c;", []>,
  Requires<[hasBF16Math, hasPTX<70>, hasSM<80>]>;
3941+
3942+
3943+
// FTZ variants
3944+
// Select FMARELU_F16_FTZ for fmaxnum(fma(a, b, c), +0.0) on f16 when
// doF32FTZ and allowUnsafeFPMath hold.
// An anonymous Pat does not inherit the Requires<> list of the instruction it
// selects, so the instruction's subtarget predicates are repeated here;
// without them the pattern could match on targets that lack fma.relu.
def : Pat<(f16 (fmaxnum (fma Int16Regs:$a, Int16Regs:$b, Int16Regs:$c), fpimm0)),
          (FMARELU_F16_FTZ Int16Regs:$a, Int16Regs:$b, Int16Regs:$c)>,
      Requires<[useFP16Math, hasPTX<70>, hasSM<80>, doF32FTZ, allowUnsafeFPMath]>;
3947+
// Select FMARELU_BF16_FTZ for fmaxnum(fma(a, b, c), +0.0) on bf16 when
// doF32FTZ and allowUnsafeFPMath hold.
// An anonymous Pat does not inherit the Requires<> list of the instruction it
// selects, so the instruction's subtarget predicates are repeated here;
// without them the pattern could match on targets that lack fma.relu.
def : Pat<(bf16 (fmaxnum (fma Int16Regs:$a, Int16Regs:$b, Int16Regs:$c), fpimm0)),
          (FMARELU_BF16_FTZ Int16Regs:$a, Int16Regs:$b, Int16Regs:$c)>,
      Requires<[hasBF16Math, hasPTX<70>, hasSM<80>, doF32FTZ, allowUnsafeFPMath]>;
3950+
// No FTZ
3951+
// Select FMARELU_F16 for fmaxnum(fma(a, b, c), +0.0) on f16 when
// allowUnsafeFPMath holds.
// An anonymous Pat does not inherit the Requires<> list of the instruction it
// selects, so the instruction's subtarget predicates are repeated here;
// without them the pattern could match on targets that lack fma.relu.
// NOTE(review): this pattern does not exclude FTZ mode and appears to rely on
// the FTZ pattern being tried first -- confirm the intended priority.
def : Pat<(f16 (fmaxnum (fma Int16Regs:$a, Int16Regs:$b, Int16Regs:$c), fpimm0)),
          (FMARELU_F16 Int16Regs:$a, Int16Regs:$b, Int16Regs:$c)>,
      Requires<[useFP16Math, hasPTX<70>, hasSM<80>, allowUnsafeFPMath]>;
3954+
// Select FMARELU_BF16 for fmaxnum(fma(a, b, c), +0.0) on bf16 when
// allowUnsafeFPMath holds.
// An anonymous Pat does not inherit the Requires<> list of the instruction it
// selects, so the instruction's subtarget predicates are repeated here;
// without them the pattern could match on targets that lack fma.relu.
def : Pat<(bf16 (fmaxnum (fma Int16Regs:$a, Int16Regs:$b, Int16Regs:$c), fpimm0)),
          (FMARELU_BF16 Int16Regs:$a, Int16Regs:$b, Int16Regs:$c)>,
      Requires<[hasBF16Math, hasPTX<70>, hasSM<80>, allowUnsafeFPMath]>;

0 commit comments

Comments
 (0)