Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 51 additions & 14 deletions llvm/lib/Target/NVPTX/NVPTXInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -418,25 +418,13 @@ multiclass F3_fma_component<string OpcStr, SDNode OpNode> {
!strconcat(OpcStr, ".f16x2 \t$dst, $a, $b;"),
[(set Int32Regs:$dst, (OpNode (v2f16 Int32Regs:$a), (v2f16 Int32Regs:$b)))]>,
Requires<[useFP16Math, allowFMA]>;
def bf16rr_ftz :
NVPTXInst<(outs Int16Regs:$dst),
(ins Int16Regs:$a, Int16Regs:$b),
!strconcat(OpcStr, ".ftz.bf16 \t$dst, $a, $b;"),
[(set Int16Regs:$dst, (OpNode (bf16 Int16Regs:$a), (bf16 Int16Regs:$b)))]>,
Requires<[hasBF16Math, allowFMA, doF32FTZ]>;
def bf16rr :
NVPTXInst<(outs Int16Regs:$dst),
(ins Int16Regs:$a, Int16Regs:$b),
!strconcat(OpcStr, ".bf16 \t$dst, $a, $b;"),
[(set Int16Regs:$dst, (OpNode (bf16 Int16Regs:$a), (bf16 Int16Regs:$b)))]>,
Requires<[hasBF16Math, allowFMA]>;

def bf16x2rr_ftz :
NVPTXInst<(outs Int32Regs:$dst),
(ins Int32Regs:$a, Int32Regs:$b),
!strconcat(OpcStr, ".ftz.bf16x2 \t$dst, $a, $b;"),
[(set (v2bf16 Int32Regs:$dst), (OpNode (v2bf16 Int32Regs:$a), (v2bf16 Int32Regs:$b)))]>,
Requires<[hasBF16Math, allowFMA, doF32FTZ]>;
def bf16x2rr :
NVPTXInst<(outs Int32Regs:$dst),
(ins Int32Regs:$a, Int32Regs:$b),
Expand Down Expand Up @@ -1423,9 +1411,7 @@ defm FMA16_ftz : FMA_F16<"fma.rn.ftz.f16", f16, Int16Regs, doF32FTZ>;
defm FMA16 : FMA_F16<"fma.rn.f16", f16, Int16Regs, True>;
defm FMA16x2_ftz : FMA_F16<"fma.rn.ftz.f16x2", v2f16, Int32Regs, doF32FTZ>;
defm FMA16x2 : FMA_F16<"fma.rn.f16x2", v2f16, Int32Regs, True>;
defm BFMA16_ftz : FMA_BF16<"fma.rn.ftz.bf16", bf16, Int16Regs, doF32FTZ>;
defm BFMA16 : FMA_BF16<"fma.rn.bf16", bf16, Int16Regs, True>;
defm BFMA16x2_ftz : FMA_BF16<"fma.rn.ftz.bf16x2", v2bf16, Int32Regs, doF32FTZ>;
defm BFMA16x2 : FMA_BF16<"fma.rn.bf16x2", v2bf16, Int32Regs, True>;
defm FMA32_ftz : FMA<"fma.rn.ftz.f32", Float32Regs, f32imm, doF32FTZ>;
defm FMA32 : FMA<"fma.rn.f32", Float32Regs, f32imm, True>;
Expand Down Expand Up @@ -3959,3 +3945,54 @@ def atomic_thread_fence_seq_cst_cta :
def atomic_thread_fence_acq_rel_cta :
NVPTXInst<(outs), (ins), "fence.acq_rel.cta;", []>,
Requires<[hasPTX<60>, hasSM<70>]>;

def fpimm_any_zero : FPImmLeaf<fAny, [{
return Imm.isZero();
}]>;

def fpimm_positive_zero_v2f16 : PatFrag<(ops), (v2f16 (bitconvert (i32 0)))>;
def fpimm_positive_zero_v2bf16 : PatFrag<(ops), (v2bf16 (bitconvert (i32 0)))>;

// Perform substitution if fma only has one use, and also if instruction has
// nnan instruction flag or if the TM has NoNaNsFPMath
def NVPTX_fma_oneuse_and_nnan : PatFrag<(ops node:$a, node:$b, node:$c),
(fma node:$a, node:$b, node:$c), [{
return N->hasOneUse() &&
(N->getFlags().hasNoNaNs() || TM.Options.NoNaNsFPMath);
}]>;
// fmaxnum will differentiate between signed and unsigned zeros soon, so this
// PatFrag is for a fmaxnum node with nsz
def NVPTX_fmaxnum_nsz : PatFrag<(ops node:$a, node:$b),
(fmaxnum node:$a, node:$b), [{
return N->getFlags().hasNoSignedZeros() || TM.Options.NoSignedZerosFPMath;
}]>;

class NVPTXInst_rrr<RegisterClass RC, string Instruction, list<Predicate> Preds>
: NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
!strconcat(Instruction, "\t$dst, $a, $b, $c;"), []>,
Requires<Preds>;

def FMARELU_F16 : NVPTXInst_rrr<Int16Regs, "fma.rn.relu.f16", [useFP16Math, hasPTX<70>, hasSM<80>]>;
def FMARELU_F16_FTZ : NVPTXInst_rrr<Int16Regs, "fma.rn.ftz.relu.f16", [useFP16Math, hasPTX<70>, hasSM<80>]>;
def FMARELU_BF16 : NVPTXInst_rrr<Int16Regs, "fma.rn.relu.bf16", [hasBF16Math, hasPTX<70>, hasSM<80>]>;
def FMARELU_F16X2 : NVPTXInst_rrr<Int32Regs, "fma.rn.relu.f16x2", [useFP16Math, hasPTX<70>, hasSM<80>]>;
def FMARELU_F16X2_FTZ : NVPTXInst_rrr<Int32Regs, "fma.rn.ftz.relu.f16x2", [useFP16Math, hasPTX<70>, hasSM<80>]>;
def FMARELU_BF16X2 : NVPTXInst_rrr<Int32Regs, "fma.rn.relu.bf16x2", [hasBF16Math, hasPTX<70>, hasSM<80>]>;

// FTZ
def : Pat<(f16 (NVPTX_fmaxnum_nsz (NVPTX_fma_oneuse_and_nnan Int16Regs:$a, Int16Regs:$b, Int16Regs:$c), fpimm_any_zero)),
(FMARELU_F16_FTZ Int16Regs:$a, Int16Regs:$b, Int16Regs:$c)>,
Requires<[doF32FTZ]>;
def : Pat<(v2f16 (NVPTX_fmaxnum_nsz (NVPTX_fma_oneuse_and_nnan Int32Regs:$a, Int32Regs:$b, Int32Regs:$c), fpimm_positive_zero_v2f16)),
(FMARELU_F16X2_FTZ Int32Regs:$a, Int32Regs:$b, Int32Regs:$c)>,
Requires<[doF32FTZ]>;

// NO FTZ
def : Pat<(f16 (NVPTX_fmaxnum_nsz (NVPTX_fma_oneuse_and_nnan Int16Regs:$a, Int16Regs:$b, Int16Regs:$c), fpimm_any_zero)),
(FMARELU_F16 Int16Regs:$a, Int16Regs:$b, Int16Regs:$c)>;
def : Pat<(bf16 (NVPTX_fmaxnum_nsz (NVPTX_fma_oneuse_and_nnan Int16Regs:$a, Int16Regs:$b, Int16Regs:$c), fpimm_any_zero)),
(FMARELU_BF16 Int16Regs:$a, Int16Regs:$b, Int16Regs:$c)>;
def : Pat<(v2f16 (NVPTX_fmaxnum_nsz (NVPTX_fma_oneuse_and_nnan Int32Regs:$a, Int32Regs:$b, Int32Regs:$c), fpimm_positive_zero_v2f16)),
(FMARELU_F16X2 Int32Regs:$a, Int32Regs:$b, Int32Regs:$c)>;
def : Pat<(v2bf16 (NVPTX_fmaxnum_nsz (NVPTX_fma_oneuse_and_nnan Int32Regs:$a, Int32Regs:$b, Int32Regs:$c), fpimm_positive_zero_v2bf16)),
(FMARELU_BF16X2 Int32Regs:$a, Int32Regs:$b, Int32Regs:$c)>;
Loading