@@ -238,6 +238,48 @@ def F16X2RT : RegTyInfo<v2f16, Int32Regs, ?, ?, supports_imm = 0>;
238238def BF16X2RT : RegTyInfo<v2bf16, Int32Regs, ?, ?, supports_imm = 0>;
239239
240240
241+ // This class provides a basic wrapper around an NVPTXInst that abstracts the
242+ // specific syntax of most PTX instructions. It automatically handles the
243+ // construction of the asm string based on the provided dag arguments.
244+ // For example, the following asm-strings would be computed:
245+ //
246+ // * BasicFlagsNVPTXInst<(outs Int32Regs:$dst),
247+ // (ins Int32Regs:$a, Int32Regs:$b), (ins),
248+ // "add.s32">;
249+ // ---> "add.s32 \t$dst, $a, $b;"
250+ //
251+ // * BasicFlagsNVPTXInst<(outs Int32Regs:$d),
252+ // (ins Int32Regs:$a, Int32Regs:$b, Hexu32imm:$c),
253+ // (ins PrmtMode:$mode),
254+ // "prmt.b32${mode}">;
255+ // ---> "prmt.b32${mode} \t$d, $a, $b, $c;"
256+ //
257+ class BasicFlagsNVPTXInst<dag outs_dag, dag ins_dag, dag flags_dag, string asmstr,
258+ list<dag> pattern = []>
259+ : NVPTXInst<
260+ outs_dag,
261+ !con(ins_dag, flags_dag),
262+ !strconcat(
263+ asmstr,
264+ !if(!and(!empty(ins_dag), !empty(outs_dag)), "",
265+ !strconcat(
266+ " \t",
267+ !interleave(
268+ !foreach(i, !range(!size(outs_dag)),
269+ "$" # !getdagname(outs_dag, i)),
270+ "|"),
271+ !if(!or(!empty(ins_dag), !empty(outs_dag)), "", ", "),
272+ !interleave(
273+ !foreach(i, !range(!size(ins_dag)),
274+ "$" # !getdagname(ins_dag, i)),
275+ ", "))),
276+ ";"),
277+ pattern>;
278+
279+ class BasicNVPTXInst<dag outs, dag insv, string asmstr, list<dag> pattern = []>
280+ : BasicFlagsNVPTXInst<outs, insv, (ins), asmstr, pattern>;
281+
282+
241283multiclass I3Inst<string op_str, SDPatternOperator op_node, RegTyInfo t,
242284 bit commutative, list<Predicate> requires = []> {
243285 defvar asmstr = op_str # " \t$dst, $a, $b;";
@@ -1581,24 +1623,6 @@ def Hexu32imm : Operand<i32> {
15811623 let PrintMethod = "printHexu32imm";
15821624}
15831625
1584- multiclass PRMT<ValueType T, RegisterClass RC> {
1585- def rrr
1586- : NVPTXInst<(outs RC:$d),
1587- (ins RC:$a, Int32Regs:$b, Int32Regs:$c, PrmtMode:$mode),
1588- !strconcat("prmt.b32${mode}", " \t$d, $a, $b, $c;"),
1589- [(set T:$d, (prmt T:$a, T:$b, i32:$c, imm:$mode))]>;
1590- def rri
1591- : NVPTXInst<(outs RC:$d),
1592- (ins RC:$a, Int32Regs:$b, Hexu32imm:$c, PrmtMode:$mode),
1593- !strconcat("prmt.b32${mode}", " \t$d, $a, $b, $c;"),
1594- [(set T:$d, (prmt T:$a, T:$b, imm:$c, imm:$mode))]>;
1595- def rii
1596- : NVPTXInst<(outs RC:$d),
1597- (ins RC:$a, i32imm:$b, Hexu32imm:$c, PrmtMode:$mode),
1598- !strconcat("prmt.b32${mode}", " \t$d, $a, $b, $c;"),
1599- [(set T:$d, (prmt T:$a, imm:$b, imm:$c, imm:$mode))]>;
1600- }
1601-
16021626let hasSideEffects = false in {
16031627 // order is somewhat important here. signed/unsigned variants match
16041628 // the same patterns, so the first one wins. Having unsigned byte extraction
@@ -1612,7 +1636,31 @@ let hasSideEffects = false in {
16121636 defm BFI_B32 : BFI<"bfi.b32", i32, Int32Regs, i32imm>;
16131637 defm BFI_B64 : BFI<"bfi.b64", i64, Int64Regs, i64imm>;
16141638
1615- defm PRMT_B32 : PRMT<i32, Int32Regs>;
1639+ def PRMT_B32rrr
1640+ : BasicFlagsNVPTXInst<(outs Int32Regs:$d),
1641+ (ins Int32Regs:$a, Int32Regs:$b, Int32Regs:$c),
1642+ (ins PrmtMode:$mode),
1643+ "prmt.b32$mode",
1644+ [(set i32:$d, (prmt i32:$a, i32:$b, i32:$c, imm:$mode))]>;
1645+ def PRMT_B32rri
1646+ : BasicFlagsNVPTXInst<(outs Int32Regs:$d),
1647+ (ins Int32Regs:$a, Int32Regs:$b, Hexu32imm:$c),
1648+ (ins PrmtMode:$mode),
1649+ "prmt.b32$mode",
1650+ [(set i32:$d, (prmt i32:$a, i32:$b, imm:$c, imm:$mode))]>;
1651+ def PRMT_B32rii
1652+ : BasicFlagsNVPTXInst<(outs Int32Regs:$d),
1653+ (ins Int32Regs:$a, i32imm:$b, Hexu32imm:$c),
1654+ (ins PrmtMode:$mode),
1655+ "prmt.b32$mode",
1656+ [(set i32:$d, (prmt i32:$a, imm:$b, imm:$c, imm:$mode))]>;
1657+ def PRMT_B32rir
1658+ : BasicFlagsNVPTXInst<(outs Int32Regs:$d),
1659+ (ins Int32Regs:$a, i32imm:$b, Int32Regs:$c),
1660+ (ins PrmtMode:$mode),
1661+ "prmt.b32$mode",
1662+ [(set i32:$d, (prmt i32:$a, imm:$b, i32:$c, imm:$mode))]>;
1663+
16161664}
16171665
16181666
@@ -3265,25 +3313,25 @@ include "NVPTXIntrinsics.td"
32653313
32663314def : Pat <
32673315 (i32 (bswap i32:$a)),
3268- (INT_NVVM_PRMT $a, (i32 0), (i32 0x0123))>;
3316+ (PRMT_B32rii $a, (i32 0), (i32 0x0123), PrmtNONE )>;
32693317
32703318def : Pat <
32713319 (v2i16 (bswap v2i16:$a)),
3272- (INT_NVVM_PRMT $a, (i32 0), (i32 0x2301))>;
3320+ (PRMT_B32rii $a, (i32 0), (i32 0x2301), PrmtNONE )>;
32733321
32743322def : Pat <
32753323 (i64 (bswap i64:$a)),
32763324 (V2I32toI64
3277- (INT_NVVM_PRMT (I64toI32H_Sink $a), (i32 0), (i32 0x0123)),
3278- (INT_NVVM_PRMT (I64toI32L_Sink $a), (i32 0), (i32 0x0123)))>,
3325+ (PRMT_B32rii (I64toI32H_Sink $a), (i32 0), (i32 0x0123), PrmtNONE ),
3326+ (PRMT_B32rii (I64toI32L_Sink $a), (i32 0), (i32 0x0123), PrmtNONE ))>,
32793327 Requires<[hasPTX<71>]>;
32803328
32813329// Fall back to the old way if we don't have PTX 7.1.
32823330def : Pat <
32833331 (i64 (bswap i64:$a)),
32843332 (V2I32toI64
3285- (INT_NVVM_PRMT (I64toI32H $a), (i32 0), (i32 0x0123)),
3286- (INT_NVVM_PRMT (I64toI32L $a), (i32 0), (i32 0x0123)))>;
3333+ (PRMT_B32rii (I64toI32H $a), (i32 0), (i32 0x0123), PrmtNONE ),
3334+ (PRMT_B32rii (I64toI32L $a), (i32 0), (i32 0x0123), PrmtNONE ))>;
32873335
32883336
32893337////////////////////////////////////////////////////////////////////////////////
0 commit comments