@@ -1330,58 +1330,46 @@ def FDIV32ri_prec :
13301330// FMA
13311331//
13321332
1333- multiclass FMA<string OpcStr, RegisterClass RC, Operand ImmCls, Predicate Pred > {
1333+ multiclass FMA<string OpcStr, RegTyInfo t, list<Predicate> Preds = [] > {
13341334 defvar asmstr = OpcStr # " \t$dst, $a, $b, $c;";
1335- def rrr : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
1335+ def rrr : NVPTXInst<(outs t. RC:$dst), (ins t. RC:$a, t. RC:$b, t. RC:$c),
13361336 asmstr,
1337- [(set RC:$dst, (fma RC:$a, RC:$b, RC:$c))]>,
1338- Requires<[Pred]>;
1339- def rri : NVPTXInst<(outs RC:$dst),
1340- (ins RC:$a, RC:$b, ImmCls:$c),
1341- asmstr,
1342- [(set RC:$dst, (fma RC:$a, RC:$b, fpimm:$c))]>,
1343- Requires<[Pred]>;
1344- def rir : NVPTXInst<(outs RC:$dst),
1345- (ins RC:$a, ImmCls:$b, RC:$c),
1346- asmstr,
1347- [(set RC:$dst, (fma RC:$a, fpimm:$b, RC:$c))]>,
1348- Requires<[Pred]>;
1349- def rii : NVPTXInst<(outs RC:$dst),
1350- (ins RC:$a, ImmCls:$b, ImmCls:$c),
1351- asmstr,
1352- [(set RC:$dst, (fma RC:$a, fpimm:$b, fpimm:$c))]>,
1353- Requires<[Pred]>;
1354- def iir : NVPTXInst<(outs RC:$dst),
1355- (ins ImmCls:$a, ImmCls:$b, RC:$c),
1356- asmstr,
1357- [(set RC:$dst, (fma fpimm:$a, fpimm:$b, RC:$c))]>,
1358- Requires<[Pred]>;
1359-
1360- }
1361-
1362- multiclass FMA_F16<string OpcStr, ValueType T, RegisterClass RC, Predicate Pred> {
1363- def rrr : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
1364- !strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
1365- [(set T:$dst, (fma T:$a, T:$b, T:$c))]>,
1366- Requires<[useFP16Math, Pred]>;
1367- }
1368-
1369- multiclass FMA_BF16<string OpcStr, ValueType T, RegisterClass RC, Predicate Pred> {
1370- def rrr : NVPTXInst<(outs RC:$dst), (ins RC:$a, RC:$b, RC:$c),
1371- !strconcat(OpcStr, " \t$dst, $a, $b, $c;"),
1372- [(set T:$dst, (fma T:$a, T:$b, T:$c))]>,
1373- Requires<[hasBF16Math, Pred]>;
1337+ [(set t.Ty:$dst, (fma t.Ty:$a, t.Ty:$b, t.Ty:$c))]>,
1338+ Requires<Preds>;
1339+
1340+ if t.SupportsImm then {
1341+ def rri : NVPTXInst<(outs t.RC:$dst),
1342+ (ins t.RC:$a, t.RC:$b, t.Imm:$c),
1343+ asmstr,
1344+ [(set t.Ty:$dst, (fma t.Ty:$a, t.Ty:$b, fpimm:$c))]>,
1345+ Requires<Preds>;
1346+ def rir : NVPTXInst<(outs t.RC:$dst),
1347+ (ins t.RC:$a, t.Imm:$b, t.RC:$c),
1348+ asmstr,
1349+ [(set t.Ty:$dst, (fma t.Ty:$a, fpimm:$b, t.Ty:$c))]>,
1350+ Requires<Preds>;
1351+ def rii : NVPTXInst<(outs t.RC:$dst),
1352+ (ins t.RC:$a, t.Imm:$b, t.Imm:$c),
1353+ asmstr,
1354+ [(set t.Ty:$dst, (fma t.Ty:$a, fpimm:$b, fpimm:$c))]>,
1355+ Requires<Preds>;
1356+ def iir : NVPTXInst<(outs t.RC:$dst),
1357+ (ins t.Imm:$a, t.Imm:$b, t.RC:$c),
1358+ asmstr,
1359+ [(set t.Ty:$dst, (fma fpimm:$a, fpimm:$b, t.Ty:$c))]>,
1360+ Requires<Preds>;
1361+ }
13741362}
13751363
1376- defm FMA16_ftz : FMA_F16 <"fma.rn.ftz.f16", f16, Int16Regs , doF32FTZ>;
1377- defm FMA16 : FMA_F16 <"fma.rn.f16", f16, Int16Regs, True >;
1378- defm FMA16x2_ftz : FMA_F16 <"fma.rn.ftz.f16x2", v2f16, Int32Regs , doF32FTZ>;
1379- defm FMA16x2 : FMA_F16 <"fma.rn.f16x2", v2f16, Int32Regs, True >;
1380- defm BFMA16 : FMA_BF16 <"fma.rn.bf16", bf16, Int16Regs, True >;
1381- defm BFMA16x2 : FMA_BF16 <"fma.rn.bf16x2", v2bf16, Int32Regs, True >;
1382- defm FMA32_ftz : FMA<"fma.rn.ftz.f32", Float32Regs, f32imm, doF32FTZ>;
1383- defm FMA32 : FMA<"fma.rn.f32", Float32Regs, f32imm, True >;
1384- defm FMA64 : FMA<"fma.rn.f64", Float64Regs, f64imm, True >;
1364+ defm FMA16_ftz : FMA <"fma.rn.ftz.f16", F16RT, [useFP16Math , doF32FTZ] >;
1365+ defm FMA16 : FMA <"fma.rn.f16", F16RT, [useFP16Math] >;
1366+ defm FMA16x2_ftz : FMA <"fma.rn.ftz.f16x2", F16X2RT, [useFP16Math , doF32FTZ] >;
1367+ defm FMA16x2 : FMA <"fma.rn.f16x2", F16X2RT, [useFP16Math] >;
1368+ defm BFMA16 : FMA <"fma.rn.bf16", BF16RT, [hasBF16Math] >;
1369+ defm BFMA16x2 : FMA <"fma.rn.bf16x2", BF16X2RT, [hasBF16Math] >;
1370+ defm FMA32_ftz : FMA<"fma.rn.ftz.f32", F32RT, [ doF32FTZ] >;
1371+ defm FMA32 : FMA<"fma.rn.f32", F32RT >;
1372+ defm FMA64 : FMA<"fma.rn.f64", F64RT >;
13851373
13861374// sin/cos
13871375
@@ -1999,7 +1987,7 @@ multiclass FSET_FORMAT<PatFrag OpNode, PatLeaf Mode, PatLeaf ModeFTZ> {
19991987 Requires<[doF32FTZ]>;
20001988 def : Pat<(i1 (OpNode f32:$a, f32:$b)),
20011989 (SETP_f32rr $a, $b, Mode)>;
2002- def : Pat<(i1 (OpNode Float32Regs :$a, fpimm:$b)),
1990+ def : Pat<(i1 (OpNode f32 :$a, fpimm:$b)),
20031991 (SETP_f32ri $a, fpimm:$b, ModeFTZ)>,
20041992 Requires<[doF32FTZ]>;
20051993 def : Pat<(i1 (OpNode f32:$a, fpimm:$b)),
@@ -2056,7 +2044,7 @@ def SDTStoreParamProfile : SDTypeProfile<0, 3, [SDTCisInt<0>, SDTCisInt<1>]>;
20562044def SDTStoreParamV2Profile : SDTypeProfile<0, 4, [SDTCisInt<0>, SDTCisInt<1>]>;
20572045def SDTStoreParamV4Profile : SDTypeProfile<0, 6, [SDTCisInt<0>, SDTCisInt<1>]>;
20582046def SDTStoreParam32Profile : SDTypeProfile<0, 3, [SDTCisInt<0>, SDTCisInt<1>]>;
2059- def SDTCallArgProfile : SDTypeProfile<0, 2, [SDTCisInt<0 >]>;
2047+ def SDTCallArgProfile : SDTypeProfile<0, 2, [SDTCisVT<0, i32>, SDTCisVT<1, i32 >]>;
20602048def SDTCallArgMarkProfile : SDTypeProfile<0, 0, []>;
20612049def SDTCallVoidProfile : SDTypeProfile<0, 1, []>;
20622050def SDTCallValProfile : SDTypeProfile<1, 0, []>;
@@ -2352,42 +2340,10 @@ def CallArgEndInst1 : NVPTXInst<(outs), (ins), ");", [(CallArgEnd (i32 1))]>;
23522340def CallArgEndInst0 : NVPTXInst<(outs), (ins), ")", [(CallArgEnd (i32 0))]>;
23532341def RETURNInst : NVPTXInst<(outs), (ins), "ret;", [(RETURNNode)]>;
23542342
2355- class CallArgInst<NVPTXRegClass regclass> :
2356- NVPTXInst<(outs), (ins regclass:$a), "$a, ",
2357- [(CallArg (i32 0), regclass:$a)]>;
2358-
2359- class CallArgInstVT<NVPTXRegClass regclass, ValueType vt> :
2360- NVPTXInst<(outs), (ins regclass:$a), "$a, ",
2361- [(CallArg (i32 0), vt:$a)]>;
2362-
2363- class LastCallArgInst<NVPTXRegClass regclass> :
2364- NVPTXInst<(outs), (ins regclass:$a), "$a",
2365- [(LastCallArg (i32 0), regclass:$a)]>;
2366- class LastCallArgInstVT<NVPTXRegClass regclass, ValueType vt> :
2367- NVPTXInst<(outs), (ins regclass:$a), "$a",
2368- [(LastCallArg (i32 0), vt:$a)]>;
2369-
2370- def CallArgI64 : CallArgInst<Int64Regs>;
2371- def CallArgI32 : CallArgInstVT<Int32Regs, i32>;
2372- def CallArgI16 : CallArgInstVT<Int16Regs, i16>;
2373- def CallArgF64 : CallArgInst<Float64Regs>;
2374- def CallArgF32 : CallArgInst<Float32Regs>;
2375-
2376- def LastCallArgI64 : LastCallArgInst<Int64Regs>;
2377- def LastCallArgI32 : LastCallArgInstVT<Int32Regs, i32>;
2378- def LastCallArgI16 : LastCallArgInstVT<Int16Regs, i16>;
2379- def LastCallArgF64 : LastCallArgInst<Float64Regs>;
2380- def LastCallArgF32 : LastCallArgInst<Float32Regs>;
2381-
2382- def CallArgI32imm : NVPTXInst<(outs), (ins i32imm:$a), "$a, ",
2383- [(CallArg (i32 0), (i32 imm:$a))]>;
2384- def LastCallArgI32imm : NVPTXInst<(outs), (ins i32imm:$a), "$a",
2385- [(LastCallArg (i32 0), (i32 imm:$a))]>;
2386-
23872343def CallArgParam : NVPTXInst<(outs), (ins i32imm:$a), "param$a, ",
2388- [(CallArg (i32 1), (i32 imm:$a) )]>;
2344+ [(CallArg 1, imm:$a)]>;
23892345def LastCallArgParam : NVPTXInst<(outs), (ins i32imm:$a), "param$a",
2390- [(LastCallArg (i32 1), (i32 imm:$a) )]>;
2346+ [(LastCallArg 1, imm:$a)]>;
23912347
23922348def CallVoidInst : NVPTXInst<(outs), (ins ADDR_base:$addr), "$addr, ",
23932349 [(CallVoid (Wrapper tglobaladdr:$addr))]>;
0 commit comments