Skip to content

Commit 778fdb1

Browse files
committed
[NVPTX] Add support for specialized prmt variants
1 parent e72d8b2 commit 778fdb1

File tree

6 files changed

+326
-48
lines changed

6 files changed

+326
-48
lines changed

llvm/docs/NVPTXUsage.rst

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -624,6 +624,126 @@ all bits set to 0 except for %b bits starting at bit position %a. For the
624624
'``clamp``' variants, the values of %a and %b are clamped to the range [0, 32],
625625
which in practice is equivalent to using them as is.
626626

627+
'``llvm.nvvm.prmt``' Intrinsic
628+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
629+
630+
Syntax:
631+
"""""""
632+
633+
.. code-block:: llvm
634+
635+
declare i32 @llvm.nvvm.prmt(i32 %a, i32 %b, i32 %c)
636+
637+
Overview:
638+
"""""""""
639+
640+
The '``llvm.nvvm.prmt``' constructs a permutation of the bytes of the first two
641+
operands, selecting based on the third operand.
642+
643+
Semantics:
644+
""""""""""
645+
646+
The bytes in the first two source operands are numbered from 0 to 7:
647+
{%b, %a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}. For each byte in the target
648+
register, a 4-bit selection value is defined.
649+
650+
The 3 lsbs of the selection value specify which of the 8 source bytes should be
651+
moved into the target position. The msb defines if the byte value should be
652+
copied, or if the sign (msb of the byte) should be replicated over all 8 bits
653+
of the target position (sign extend of the byte value); msb=0 means copy the
654+
literal value; msb=1 means replicate the sign.
655+
656+
These 4-bit selection values are pulled from the lower 16-bits of the third
657+
operand, with the least significant selection value corresponding to the least
658+
significant byte of the destination.
659+
660+
661+
'``llvm.nvvm.prmt.*``' Intrinsics
662+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
663+
664+
Syntax:
665+
"""""""
666+
667+
.. code-block:: llvm
668+
669+
declare i32 @llvm.nvvm.prmt.f4e(i32 %a, i32 %b, i32 %c)
670+
declare i32 @llvm.nvvm.prmt.b4e(i32 %a, i32 %b, i32 %c)
671+
672+
declare i32 @llvm.nvvm.prmt.rc8(i32 %a, i32 %c)
673+
declare i32 @llvm.nvvm.prmt.ecl(i32 %a, i32 %c)
674+
declare i32 @llvm.nvvm.prmt.ecr(i32 %a, i32 %c)
675+
declare i32 @llvm.nvvm.prmt.rc16(i32 %a, i32 %c)
676+
677+
Overview:
678+
"""""""""
679+
680+
The '``llvm.nvvm.prmt.*``' family of intrinsics constructs a permutation of the
681+
bytes of the first one or two operands, selecting based on the 2 least
682+
significant bits of the final operand.
683+
684+
Semantics:
685+
""""""""""
686+
687+
As with the generic '``llvm.nvvm.prmt``' intrinsic, the bytes in the first one
688+
or two source operands are numbered. The first source operand (%a) is numbered
689+
{b3, b2, b1, b0}, in the case of the '``f4e``' and '``b4e``' variants, the
690+
second source operand (%b) is numbered {b7, b6, b5, b4}.
691+
692+
Depending on the 2 least significant bits of the final operand, the result of
693+
the permutation is defined as follows:
694+
695+
+------------+---------+--------------+
696+
| Mode | %c[1:0] | Output |
697+
+------------+---------+--------------+
698+
| '``f4e``' | 0 | {3, 2, 1, 0} |
699+
| +---------+--------------+
700+
| | 1 | {4, 3, 2, 1} |
701+
| +---------+--------------+
702+
| | 2 | {5, 4, 3, 2} |
703+
| +---------+--------------+
704+
| | 3 | {6, 5, 4, 3} |
705+
+------------+---------+--------------+
706+
| '``b4e``' | 0 | {5, 6, 7, 0} |
707+
| +---------+--------------+
708+
| | 1 | {6, 7, 0, 1} |
709+
| +---------+--------------+
710+
| | 2 | {7, 0, 1, 2} |
711+
| +---------+--------------+
712+
| | 3 | {0, 1, 2, 3} |
713+
+------------+---------+--------------+
714+
| '``rc8``' | 0 | {0, 0, 0, 0} |
715+
| +---------+--------------+
716+
| | 1 | {1, 1, 1, 1} |
717+
| +---------+--------------+
718+
| | 2 | {2, 2, 2, 2} |
719+
| +---------+--------------+
720+
| | 3 | {3, 3, 3, 3} |
721+
+------------+---------+--------------+
722+
| '``ecl``' | 0 | {3, 2, 1, 0} |
723+
| +---------+--------------+
724+
| | 1 | {3, 2, 1, 1} |
725+
| +---------+--------------+
726+
| | 2 | {3, 2, 2, 2} |
727+
| +---------+--------------+
728+
| | 3 | {3, 3, 3, 3} |
729+
+------------+---------+--------------+
730+
| '``ecr``' | 0 | {0, 0, 0, 0} |
731+
| +---------+--------------+
732+
| | 1 | {1, 1, 1, 0} |
733+
| +---------+--------------+
734+
| | 2 | {2, 2, 1, 0} |
735+
| +---------+--------------+
736+
| | 3 | {3, 2, 1, 0} |
737+
+------------+---------+--------------+
738+
| '``rc16``' | 0 | {1, 0, 1, 0} |
739+
| +---------+--------------+
740+
| | 1 | {3, 2, 3, 2} |
741+
| +---------+--------------+
742+
| | 2 | {1, 0, 1, 0} |
743+
| +---------+--------------+
744+
| | 3 | {3, 2, 3, 2} |
745+
+------------+---------+--------------+
746+
627747
TMA family of Intrinsics
628748
------------------------
629749

llvm/include/llvm/IR/IntrinsicsNVVM.td

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -739,9 +739,23 @@ class NVVMBuiltin :
739739
}
740740

741741
let TargetPrefix = "nvvm" in {
742-
def int_nvvm_prmt : NVVMBuiltin,
743-
DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty],
744-
[IntrNoMem, IntrSpeculatable]>;
742+
743+
// PRMT - permute
744+
745+
let IntrProperties = [IntrNoMem, IntrSpeculatable] in {
746+
def int_nvvm_prmt : NVVMBuiltin,
747+
DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty]>;
748+
749+
foreach mode = ["f4e", "b4e"] in
750+
def int_nvvm_prmt_ # mode :
751+
DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty]>;
752+
753+
// Note: these variants also have 2 source operands but only one will ever
754+
// be used so we eliminate the other operand in the IR.
755+
foreach mode = ["rc8", "ecl", "ecr", "rc16"] in
756+
def int_nvvm_prmt_ # mode :
757+
DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty]>;
758+
}
745759

746760
def int_nvvm_nanosleep : NVVMBuiltin,
747761
DefaultAttrsIntrinsic<[], [llvm_i32_ty],

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1581,24 +1581,6 @@ def Hexu32imm : Operand<i32> {
15811581
let PrintMethod = "printHexu32imm";
15821582
}
15831583

1584-
multiclass PRMT<ValueType T, RegisterClass RC> {
1585-
def rrr
1586-
: NVPTXInst<(outs RC:$d),
1587-
(ins RC:$a, Int32Regs:$b, Int32Regs:$c, PrmtMode:$mode),
1588-
!strconcat("prmt.b32${mode}", " \t$d, $a, $b, $c;"),
1589-
[(set T:$d, (prmt T:$a, T:$b, i32:$c, imm:$mode))]>;
1590-
def rri
1591-
: NVPTXInst<(outs RC:$d),
1592-
(ins RC:$a, Int32Regs:$b, Hexu32imm:$c, PrmtMode:$mode),
1593-
!strconcat("prmt.b32${mode}", " \t$d, $a, $b, $c;"),
1594-
[(set T:$d, (prmt T:$a, T:$b, imm:$c, imm:$mode))]>;
1595-
def rii
1596-
: NVPTXInst<(outs RC:$d),
1597-
(ins RC:$a, i32imm:$b, Hexu32imm:$c, PrmtMode:$mode),
1598-
!strconcat("prmt.b32${mode}", " \t$d, $a, $b, $c;"),
1599-
[(set T:$d, (prmt T:$a, imm:$b, imm:$c, imm:$mode))]>;
1600-
}
1601-
16021584
let hasSideEffects = false in {
16031585
// order is somewhat important here. signed/unsigned variants match
16041586
// the same patterns, so the first one wins. Having unsigned byte extraction
@@ -1612,7 +1594,31 @@ let hasSideEffects = false in {
16121594
defm BFI_B32 : BFI<"bfi.b32", i32, Int32Regs, i32imm>;
16131595
defm BFI_B64 : BFI<"bfi.b64", i64, Int64Regs, i64imm>;
16141596

1615-
defm PRMT_B32 : PRMT<i32, Int32Regs>;
1597+
def PRMT_B32rrr
1598+
: BasicFlagsNVPTXInst<(outs Int32Regs:$d),
1599+
(ins Int32Regs:$a, Int32Regs:$b, Int32Regs:$c),
1600+
(ins PrmtMode:$mode),
1601+
"prmt.b32$mode",
1602+
[(set i32:$d, (prmt i32:$a, i32:$b, i32:$c, imm:$mode))]>;
1603+
def PRMT_B32rri
1604+
: BasicFlagsNVPTXInst<(outs Int32Regs:$d),
1605+
(ins Int32Regs:$a, Int32Regs:$b, Hexu32imm:$c),
1606+
(ins PrmtMode:$mode),
1607+
"prmt.b32$mode",
1608+
[(set i32:$d, (prmt i32:$a, i32:$b, imm:$c, imm:$mode))]>;
1609+
def PRMT_B32rii
1610+
: BasicFlagsNVPTXInst<(outs Int32Regs:$d),
1611+
(ins Int32Regs:$a, i32imm:$b, Hexu32imm:$c),
1612+
(ins PrmtMode:$mode),
1613+
"prmt.b32$mode",
1614+
[(set i32:$d, (prmt i32:$a, imm:$b, imm:$c, imm:$mode))]>;
1615+
def PRMT_B32rir
1616+
: BasicFlagsNVPTXInst<(outs Int32Regs:$d),
1617+
(ins Int32Regs:$a, i32imm:$b, Int32Regs:$c),
1618+
(ins PrmtMode:$mode),
1619+
"prmt.b32$mode",
1620+
[(set i32:$d, (prmt i32:$a, imm:$b, i32:$c, imm:$mode))]>;
1621+
16161622
}
16171623

16181624

@@ -3265,25 +3271,25 @@ include "NVPTXIntrinsics.td"
32653271

32663272
def : Pat <
32673273
(i32 (bswap i32:$a)),
3268-
(INT_NVVM_PRMT $a, (i32 0), (i32 0x0123))>;
3274+
(PRMT_B32rii $a, (i32 0), (i32 0x0123), PrmtNONE)>;
32693275

32703276
def : Pat <
32713277
(v2i16 (bswap v2i16:$a)),
3272-
(INT_NVVM_PRMT $a, (i32 0), (i32 0x2301))>;
3278+
(PRMT_B32rii $a, (i32 0), (i32 0x2301), PrmtNONE)>;
32733279

32743280
def : Pat <
32753281
(i64 (bswap i64:$a)),
32763282
(V2I32toI64
3277-
(INT_NVVM_PRMT (I64toI32H_Sink $a), (i32 0), (i32 0x0123)),
3278-
(INT_NVVM_PRMT (I64toI32L_Sink $a), (i32 0), (i32 0x0123)))>,
3283+
(PRMT_B32rii (I64toI32H_Sink $a), (i32 0), (i32 0x0123), PrmtNONE),
3284+
(PRMT_B32rii (I64toI32L_Sink $a), (i32 0), (i32 0x0123), PrmtNONE))>,
32793285
Requires<[hasPTX<71>]>;
32803286

32813287
// Fall back to the old way if we don't have PTX 7.1.
32823288
def : Pat <
32833289
(i64 (bswap i64:$a)),
32843290
(V2I32toI64
3285-
(INT_NVVM_PRMT (I64toI32H $a), (i32 0), (i32 0x0123)),
3286-
(INT_NVVM_PRMT (I64toI32L $a), (i32 0), (i32 0x0123)))>;
3291+
(PRMT_B32rii (I64toI32H $a), (i32 0), (i32 0x0123), PrmtNONE),
3292+
(PRMT_B32rii (I64toI32L $a), (i32 0), (i32 0x0123), PrmtNONE))>;
32873293

32883294

32893295
////////////////////////////////////////////////////////////////////////////////

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,8 +1025,23 @@ class F_MATH_3<string OpcStr, NVPTXRegClass t_regclass,
10251025
// MISC
10261026
//
10271027

1028-
def INT_NVVM_PRMT : F_MATH_3<"prmt.b32 \t$dst, $src0, $src1, $src2;", Int32Regs,
1029-
Int32Regs, Int32Regs, Int32Regs, int_nvvm_prmt>;
1028+
class PRMT3Pat<Intrinsic prmt_intrinsic, PatLeaf prmt_mode>
1029+
: Pat<(prmt_intrinsic i32:$a, i32:$b, i32:$c),
1030+
(PRMT_B32rrr $a, $b, $c, prmt_mode)>;
1031+
1032+
class PRMT2Pat<Intrinsic prmt_intrinsic, PatLeaf prmt_mode>
1033+
: Pat<(prmt_intrinsic i32:$a, i32:$c),
1034+
(PRMT_B32rir $a, (i32 0), $c, prmt_mode)>;
1035+
1036+
def : PRMT3Pat<int_nvvm_prmt, PrmtNONE>;
1037+
def : PRMT3Pat<int_nvvm_prmt_f4e, PrmtF4E>;
1038+
def : PRMT3Pat<int_nvvm_prmt_b4e, PrmtB4E>;
1039+
1040+
def : PRMT2Pat<int_nvvm_prmt_rc8, PrmtRC8>;
1041+
def : PRMT2Pat<int_nvvm_prmt_ecl, PrmtECL>;
1042+
def : PRMT2Pat<int_nvvm_prmt_ecr, PrmtECR>;
1043+
def : PRMT2Pat<int_nvvm_prmt_rc16, PrmtRC16>;
1044+
10301045

10311046
def INT_NVVM_NANOSLEEP_I : NVPTXInst<(outs), (ins i32imm:$i), "nanosleep.u32 \t$i;",
10321047
[(int_nvvm_nanosleep imm:$i)]>,

llvm/test/CodeGen/NVPTX/bswap.ll

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ define i32 @bswap32(i32 %a) {
3333
; CHECK-EMPTY:
3434
; CHECK-NEXT: // %bb.0:
3535
; CHECK-NEXT: ld.param.b32 %r1, [bswap32_param_0];
36-
; CHECK-NEXT: prmt.b32 %r2, %r1, 0, 291;
36+
; CHECK-NEXT: prmt.b32 %r2, %r1, 0, 0x123U;
3737
; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
3838
; CHECK-NEXT: ret;
3939
%b = tail call i32 @llvm.bswap.i32(i32 %a)
@@ -48,33 +48,43 @@ define <2 x i16> @bswapv2i16(<2 x i16> %a) #0 {
4848
; CHECK-EMPTY:
4949
; CHECK-NEXT: // %bb.0:
5050
; CHECK-NEXT: ld.param.b32 %r1, [bswapv2i16_param_0];
51-
; CHECK-NEXT: prmt.b32 %r2, %r1, 0, 8961;
51+
; CHECK-NEXT: prmt.b32 %r2, %r1, 0, 0x2301U;
5252
; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
5353
; CHECK-NEXT: ret;
5454
%b = tail call <2 x i16> @llvm.bswap.v2i16(<2 x i16> %a)
5555
ret <2 x i16> %b
5656
}
5757

5858
define i64 @bswap64(i64 %a) {
59-
; CHECK-LABEL: bswap64(
60-
; CHECK: {
61-
; CHECK-NEXT: .reg .b32 %r<5>;
62-
; CHECK-NEXT: .reg .b64 %rd<3>;
63-
; CHECK-EMPTY:
64-
; CHECK-NEXT: // %bb.0:
65-
; CHECK-NEXT: ld.param.b64 %rd1, [bswap64_param_0];
59+
; PTX70-LABEL: bswap64(
60+
; PTX70: {
61+
; PTX70-NEXT: .reg .b32 %r<5>;
62+
; PTX70-NEXT: .reg .b64 %rd<3>;
63+
; PTX70-EMPTY:
64+
; PTX70-NEXT: // %bb.0:
65+
; PTX70-NEXT: ld.param.b64 %rd1, [bswap64_param_0];
6666
; PTX70-NEXT: { .reg .b32 tmp; mov.b64 {%r1, tmp}, %rd1; }
67-
; PTX70-NEXT: prmt.b32 %r2, %r1, 0, 291;
67+
; PTX70-NEXT: prmt.b32 %r2, %r1, 0, 0x123U;
6868
; PTX70-NEXT: { .reg .b32 tmp; mov.b64 {tmp, %r3}, %rd1; }
69-
; PTX70-NEXT: prmt.b32 %r4, %r3, 0, 291;
69+
; PTX70-NEXT: prmt.b32 %r4, %r3, 0, 0x123U;
7070
; PTX70-NEXT: mov.b64 %rd2, {%r4, %r2};
71-
; PTX71-NEXT: mov.b64 {%r1, _}, %rd1;
72-
; PTX71-NEXT: prmt.b32 %r2, %r1, 0, 291;
73-
; PTX71-NEXT: mov.b64 {_, %r3}, %rd1;
74-
; PTX71-NEXT: prmt.b32 %r4, %r3, 0, 291;
75-
; PTX71-NEXT: mov.b64 %rd2, {%r4, %r2};
76-
; CHECK-NEXT: st.param.b64 [func_retval0], %rd2;
77-
; CHECK-NEXT: ret;
71+
; PTX70-NEXT: st.param.b64 [func_retval0], %rd2;
72+
; PTX70-NEXT: ret;
73+
;
74+
; PTX71-LABEL: bswap64(
75+
; PTX71: {
76+
; PTX71-NEXT: .reg .b32 %r<5>;
77+
; PTX71-NEXT: .reg .b64 %rd<3>;
78+
; PTX71-EMPTY:
79+
; PTX71-NEXT: // %bb.0:
80+
; PTX71-NEXT: ld.param.b64 %rd1, [bswap64_param_0];
81+
; PTX71-NEXT: mov.b64 {%r1, _}, %rd1;
82+
; PTX71-NEXT: prmt.b32 %r2, %r1, 0, 0x123U;
83+
; PTX71-NEXT: mov.b64 {_, %r3}, %rd1;
84+
; PTX71-NEXT: prmt.b32 %r4, %r3, 0, 0x123U;
85+
; PTX71-NEXT: mov.b64 %rd2, {%r4, %r2};
86+
; PTX71-NEXT: st.param.b64 [func_retval0], %rd2;
87+
; PTX71-NEXT: ret;
7888
%b = tail call i64 @llvm.bswap.i64(i64 %a)
7989
ret i64 %b
8090
}

0 commit comments

Comments
 (0)