diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp index 0551954444e57..67dc7904a91ae 100644 --- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp +++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.cpp @@ -209,9 +209,7 @@ bool NVPTXInstrInfo::isSchedulingBoundary(const MachineInstr &MI, switch (MI.getOpcode()) { case NVPTX::CallUniPrintCallRetInst1: case NVPTX::CallArgBeginInst: - case NVPTX::CallArgI32imm: case NVPTX::CallArgParam: - case NVPTX::LastCallArgI32imm: case NVPTX::LastCallArgParam: case NVPTX::CallArgEndInst1: return true; diff --git a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td index 043da14bcb236..11d77599d4ac3 100644 --- a/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td +++ b/llvm/lib/Target/NVPTX/NVPTXInstrInfo.td @@ -1330,58 +1330,46 @@ def FDIV32ri_prec : // FMA // -multiclass FMA { +multiclass FMA Preds = []> { defvar asmstr = OpcStr # " \t$dst, $a, $b, $c;"; - def rrr : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c), + def rrr : NVPTXInst<(outs t.RC:$dst), (ins t.RC:$a, t.RC:$b, t.RC:$c), asmstr, - [(set RC:$dst, (fma RC:$a, RC:$b, RC:$c))]>, - Requires<[Pred]>; - def rri : NVPTXInst<(outs RC:$dst), - (ins RC:$a, RC:$b, ImmCls:$c), - asmstr, - [(set RC:$dst, (fma RC:$a, RC:$b, fpimm:$c))]>, - Requires<[Pred]>; - def rir : NVPTXInst<(outs RC:$dst), - (ins RC:$a, ImmCls:$b, RC:$c), - asmstr, - [(set RC:$dst, (fma RC:$a, fpimm:$b, RC:$c))]>, - Requires<[Pred]>; - def rii : NVPTXInst<(outs RC:$dst), - (ins RC:$a, ImmCls:$b, ImmCls:$c), - asmstr, - [(set RC:$dst, (fma RC:$a, fpimm:$b, fpimm:$c))]>, - Requires<[Pred]>; - def iir : NVPTXInst<(outs RC:$dst), - (ins ImmCls:$a, ImmCls:$b, RC:$c), - asmstr, - [(set RC:$dst, (fma fpimm:$a, fpimm:$b, RC:$c))]>, - Requires<[Pred]>; - -} - -multiclass FMA_F16 { - def rrr : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c), - !strconcat(OpcStr, " \t$dst, $a, $b, $c;"), - [(set T:$dst, (fma T:$a, T:$b, T:$c))]>, - Requires<[useFP16Math, Pred]>; -} - -multiclass FMA_BF16 { - def rrr : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c), - !strconcat(OpcStr, " \t$dst, $a, $b, $c;"), - [(set T:$dst, (fma T:$a, T:$b, T:$c))]>, - Requires<[hasBF16Math, Pred]>; + [(set t.Ty:$dst, (fma t.Ty:$a, t.Ty:$b, t.Ty:$c))]>, + Requires; + + if t.SupportsImm then { + def rri : NVPTXInst<(outs t.RC:$dst), + (ins t.RC:$a, t.RC:$b, t.Imm:$c), + asmstr, + [(set t.Ty:$dst, (fma t.Ty:$a, t.Ty:$b, fpimm:$c))]>, + Requires; + def rir : NVPTXInst<(outs t.RC:$dst), + (ins t.RC:$a, t.Imm:$b, t.RC:$c), + asmstr, + [(set t.Ty:$dst, (fma t.Ty:$a, fpimm:$b, t.Ty:$c))]>, + Requires; + def rii : NVPTXInst<(outs t.RC:$dst), + (ins t.RC:$a, t.Imm:$b, t.Imm:$c), + asmstr, + [(set t.Ty:$dst, (fma t.Ty:$a, fpimm:$b, fpimm:$c))]>, + Requires; + def iir : NVPTXInst<(outs t.RC:$dst), + (ins t.Imm:$a, t.Imm:$b, t.RC:$c), + asmstr, + [(set t.Ty:$dst, (fma fpimm:$a, fpimm:$b, t.Ty:$c))]>, + Requires; + } } -defm FMA16_ftz : FMA_F16<"fma.rn.ftz.f16", f16, Int16Regs, doF32FTZ>; -defm FMA16 : FMA_F16<"fma.rn.f16", f16, Int16Regs, True>; -defm FMA16x2_ftz : FMA_F16<"fma.rn.ftz.f16x2", v2f16, Int32Regs, doF32FTZ>; -defm FMA16x2 : FMA_F16<"fma.rn.f16x2", v2f16, Int32Regs, True>; -defm BFMA16 : FMA_BF16<"fma.rn.bf16", bf16, Int16Regs, True>; -defm BFMA16x2 : FMA_BF16<"fma.rn.bf16x2", v2bf16, Int32Regs, True>; -defm FMA32_ftz : FMA<"fma.rn.ftz.f32", Float32Regs, f32imm, doF32FTZ>; -defm FMA32 : FMA<"fma.rn.f32", Float32Regs, f32imm, True>; -defm FMA64 : FMA<"fma.rn.f64", Float64Regs, f64imm, True>; +defm FMA16_ftz : FMA<"fma.rn.ftz.f16", F16RT, [useFP16Math, doF32FTZ]>; +defm FMA16 : FMA<"fma.rn.f16", F16RT, [useFP16Math]>; +defm FMA16x2_ftz : FMA<"fma.rn.ftz.f16x2", F16X2RT, [useFP16Math, doF32FTZ]>; +defm FMA16x2 : FMA<"fma.rn.f16x2", F16X2RT, [useFP16Math]>; +defm BFMA16 : FMA<"fma.rn.bf16", BF16RT, [hasBF16Math]>; +defm BFMA16x2 : FMA<"fma.rn.bf16x2", BF16X2RT, [hasBF16Math]>; +defm FMA32_ftz : FMA<"fma.rn.ftz.f32", F32RT, [doF32FTZ]>; +defm FMA32 : FMA<"fma.rn.f32", F32RT>; +defm FMA64 : FMA<"fma.rn.f64", F64RT>; // sin/cos @@ -1999,7 +1987,7 @@ multiclass FSET_FORMAT { Requires<[doF32FTZ]>; def : Pat<(i1 (OpNode f32:$a, f32:$b)), (SETP_f32rr $a, $b, Mode)>; - def : Pat<(i1 (OpNode Float32Regs:$a, fpimm:$b)), + def : Pat<(i1 (OpNode f32:$a, fpimm:$b)), (SETP_f32ri $a, fpimm:$b, ModeFTZ)>, Requires<[doF32FTZ]>; def : Pat<(i1 (OpNode f32:$a, fpimm:$b)), @@ -2056,7 +2044,7 @@ def SDTStoreParamProfile : SDTypeProfile<0, 3, [SDTCisInt<0>, SDTCisInt<1>]>; def SDTStoreParamV2Profile : SDTypeProfile<0, 4, [SDTCisInt<0>, SDTCisInt<1>]>; def SDTStoreParamV4Profile : SDTypeProfile<0, 6, [SDTCisInt<0>, SDTCisInt<1>]>; def SDTStoreParam32Profile : SDTypeProfile<0, 3, [SDTCisInt<0>, SDTCisInt<1>]>; -def SDTCallArgProfile : SDTypeProfile<0, 2, [SDTCisInt<0>]>; +def SDTCallArgProfile : SDTypeProfile<0, 2, [SDTCisVT<0, i32>, SDTCisVT<1, i32>]>; def SDTCallArgMarkProfile : SDTypeProfile<0, 0, []>; def SDTCallVoidProfile : SDTypeProfile<0, 1, []>; def SDTCallValProfile : SDTypeProfile<1, 0, []>; @@ -2352,42 +2340,10 @@ def CallArgEndInst1 : NVPTXInst<(outs), (ins), ");", [(CallArgEnd (i32 1))]>; def CallArgEndInst0 : NVPTXInst<(outs), (ins), ")", [(CallArgEnd (i32 0))]>; def RETURNInst : NVPTXInst<(outs), (ins), "ret;", [(RETURNNode)]>; -class CallArgInst : - NVPTXInst<(outs), (ins regclass:$a), "$a, ", - [(CallArg (i32 0), regclass:$a)]>; - -class CallArgInstVT : - NVPTXInst<(outs), (ins regclass:$a), "$a, ", - [(CallArg (i32 0), vt:$a)]>; - -class LastCallArgInst : - NVPTXInst<(outs), (ins regclass:$a), "$a", - [(LastCallArg (i32 0), regclass:$a)]>; -class LastCallArgInstVT : - NVPTXInst<(outs), (ins regclass:$a), "$a", - [(LastCallArg (i32 0), vt:$a)]>; - -def CallArgI64 : CallArgInst; -def CallArgI32 : CallArgInstVT; -def CallArgI16 : CallArgInstVT; -def CallArgF64 : CallArgInst; -def CallArgF32 : CallArgInst; - -def LastCallArgI64 : LastCallArgInst; -def LastCallArgI32 : LastCallArgInstVT; -def LastCallArgI16 : LastCallArgInstVT; -def LastCallArgF64 : LastCallArgInst; -def LastCallArgF32 : LastCallArgInst; - -def CallArgI32imm : NVPTXInst<(outs), (ins i32imm:$a), "$a, ", - [(CallArg (i32 0), (i32 imm:$a))]>; -def LastCallArgI32imm : NVPTXInst<(outs), (ins i32imm:$a), "$a", - [(LastCallArg (i32 0), (i32 imm:$a))]>; - def CallArgParam : NVPTXInst<(outs), (ins i32imm:$a), "param$a, ", - [(CallArg (i32 1), (i32 imm:$a))]>; + [(CallArg 1, imm:$a)]>; def LastCallArgParam : NVPTXInst<(outs), (ins i32imm:$a), "param$a", - [(LastCallArg (i32 1), (i32 imm:$a))]>; + [(LastCallArg 1, imm:$a)]>; def CallVoidInst : NVPTXInst<(outs), (ins ADDR_base:$addr), "$addr, ", [(CallVoid (Wrapper tglobaladdr:$addr))]>;