Skip to content

Commit 81e91ea

Browse files
authored
[NVPTX] Use PRMT instruction to lower i16 bswap (#168968)
Previously, i16 `bswap` was lowered using multiple shift and OR operations. This patch lowers i16 `bswap` directly to a single `PRMT` (permute) instruction, which is more efficient. Additionally, the lowering of `bswap` is moved out of the instruction-selection patterns and into operation legalization, which allows the DAGCombiner to optimize the lowered code.
1 parent 8947ba0 commit 81e91ea

File tree

3 files changed

+77
-76
lines changed

3 files changed

+77
-76
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -713,8 +713,6 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
713713
Custom);
714714
}
715715

716-
setOperationAction(ISD::BSWAP, MVT::i16, Expand);
717-
718716
setOperationAction(ISD::BR_JT, MVT::Other, Custom);
719717
setOperationAction(ISD::BRIND, MVT::Other, Expand);
720718

@@ -1106,6 +1104,10 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
11061104
// * MVT::Other - internal.addrspace.wrap
11071105
setOperationAction(ISD::INTRINSIC_WO_CHAIN,
11081106
{MVT::i32, MVT::i128, MVT::v4f32, MVT::Other}, Custom);
1107+
1108+
// Custom lowering for bswap
1109+
setOperationAction(ISD::BSWAP, {MVT::i16, MVT::i32, MVT::i64, MVT::v2i16},
1110+
Custom);
11091111
}
11101112

11111113
TargetLoweringBase::LegalizeTypeAction
@@ -2570,6 +2572,44 @@ static SDValue lowerTcgen05St(SDValue Op, SelectionDAG &DAG) {
25702572
return Tcgen05StNode;
25712573
}
25722574

2575+
/// Custom-lowers an ISD::BSWAP node into NVPTX byte-permute (PRMT) nodes.
///
/// Supported result types are i16, i32, i64 and v2i16; any other type reaches
/// llvm_unreachable. In every case the source is viewed as one or two 32-bit
/// words, and each word is passed through getPRMT together with a constant
/// zero second operand and a selector that encodes the new byte order.
///
/// \param Op  The BSWAP node to lower; operand 0 is the value to byte-swap.
/// \param DAG The selection DAG in which replacement nodes are created.
/// \returns The value that replaces \p Op.
static SDValue lowerBSWAP(SDValue Op, SelectionDAG &DAG) {
  SDLoc DL(Op);
  SDValue Input = Op.getOperand(0);

  // Every case below permutes a 32-bit word against a zero second source, so
  // funnel all PRMT construction through a single helper.
  auto PermuteBytes = [&](SDValue Word, uint64_t Selector) {
    return getPRMT(Word, DAG.getConstant(0, DL, MVT::i32), Selector, DL, DAG);
  };

  switch (Op.getValueType().getSimpleVT().SimpleTy) {
  case MVT::i32:
    // 0x0123 reverses all four bytes of the word.
    return PermuteBytes(Input, 0x0123);

  case MVT::i16: {
    // Widen to 32 bits, exchange the two low bytes (selector digits 0 and 1),
    // then narrow back. The upper selector digits (7, 7) pick bytes from the
    // zero operand; those result bytes are discarded by the truncate anyway.
    SDValue Wide = DAG.getNode(ISD::ANY_EXTEND, DL, MVT::i32, Input);
    return DAG.getNode(ISD::TRUNCATE, DL, MVT::i16,
                       PermuteBytes(Wide, 0x7701));
  }

  case MVT::v2i16: {
    // Treat the vector as one 32-bit word and swap the bytes inside each
    // 16-bit lane (0x2301), then reinterpret the result as v2i16 again.
    SDValue AsWord = DAG.getBitcast(MVT::i32, Input);
    return DAG.getNode(ISD::BITCAST, DL, MVT::v2i16,
                       PermuteBytes(AsWord, 0x2301));
  }

  case MVT::i64: {
    // Split into two 32-bit halves, byte-reverse each half, and rebuild the
    // 64-bit value with the halves exchanged.
    SDValue Halves =
        DAG.getNode(NVPTXISD::UNPACK_VECTOR, DL, {MVT::i32, MVT::i32}, Input);
    SDValue ReversedLo = PermuteBytes(Halves.getValue(0), 0x0123);
    SDValue ReversedHi = PermuteBytes(Halves.getValue(1), 0x0123);
    return DAG.getNode(NVPTXISD::BUILD_VECTOR, DL, MVT::i64,
                       {ReversedHi, ReversedLo});
  }

  default:
    llvm_unreachable("unsupported type for bswap");
  }
}
2612+
25732613
static unsigned getTcgen05MMADisableOutputLane(unsigned IID) {
25742614
switch (IID) {
25752615
case Intrinsic::nvvm_tcgen05_mma_shared_disable_output_lane_cg1:
@@ -3193,7 +3233,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
31933233
return lowerCTLZCTPOP(Op, DAG);
31943234
case ISD::FREM:
31953235
return lowerFREM(Op, DAG);
3196-
3236+
case ISD::BSWAP:
3237+
return lowerBSWAP(Op, DAG);
31973238
default:
31983239
llvm_unreachable("Custom lowering not defined for operation");
31993240
}

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2468,38 +2468,6 @@ let Predicates = [hasPTX<73>, hasSM<52>] in {
24682468

24692469
include "NVPTXIntrinsics.td"
24702470

2471-
//-----------------------------------
2472-
// Notes
2473-
//-----------------------------------
2474-
// BSWAP is currently expanded. The following is a more efficient
2475-
// - for < sm_20, use vector scalar mov, as tesla support native 16-bit register
2476-
// - for sm_20, use pmpt (use vector scalar mov to get the pack and
2477-
// unpack). sm_20 supports native 32-bit register, but not native 16-bit
2478-
// register.
2479-
2480-
def : Pat <
2481-
(i32 (bswap i32:$a)),
2482-
(PRMT_B32rii $a, (i32 0), (i32 0x0123), PrmtNONE)>;
2483-
2484-
def : Pat <
2485-
(v2i16 (bswap v2i16:$a)),
2486-
(PRMT_B32rii $a, (i32 0), (i32 0x2301), PrmtNONE)>;
2487-
2488-
def : Pat <
2489-
(i64 (bswap i64:$a)),
2490-
(V2I32toI64
2491-
(PRMT_B32rii (I64toI32H_Sink $a), (i32 0), (i32 0x0123), PrmtNONE),
2492-
(PRMT_B32rii (I64toI32L_Sink $a), (i32 0), (i32 0x0123), PrmtNONE))>,
2493-
Requires<[hasPTX<71>]>;
2494-
2495-
// Fall back to the old way if we don't have PTX 7.1.
2496-
def : Pat <
2497-
(i64 (bswap i64:$a)),
2498-
(V2I32toI64
2499-
(PRMT_B32rii (I64toI32H $a), (i32 0), (i32 0x0123), PrmtNONE),
2500-
(PRMT_B32rii (I64toI32L $a), (i32 0), (i32 0x0123), PrmtNONE))>;
2501-
2502-
25032471
////////////////////////////////////////////////////////////////////////////////
25042472
// PTX Fence instructions
25052473
////////////////////////////////////////////////////////////////////////////////

llvm/test/CodeGen/NVPTX/bswap.ll

Lines changed: 33 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,18 @@
11
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 4
2-
; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_20 -mattr=+ptx70 | FileCheck -check-prefixes CHECK,PTX70 %s
2+
; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_20 | FileCheck %s
33
; RUN: %if ptxas %{ llc < %s -mtriple=nvptx64 -mcpu=sm_20 | %ptxas-verify %}
4-
; RUN: %if ptxas-isa-7.0 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_20 -mattr=+ptx70 | %ptxas-verify %}
5-
; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_20 -mattr=+ptx71 | FileCheck -check-prefixes CHECK,PTX71 %s
6-
; RUN: %if ptxas-isa-7.1 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_20 -mattr=+ptx71 | %ptxas-verify %}
74

85
target triple = "nvptx64-nvidia-cuda"
96

107
define i16 @bswap16(i16 %a) {
118
; CHECK-LABEL: bswap16(
129
; CHECK: {
13-
; CHECK-NEXT: .reg .b16 %rs<5>;
14-
; CHECK-NEXT: .reg .b32 %r<2>;
10+
; CHECK-NEXT: .reg .b32 %r<3>;
1511
; CHECK-EMPTY:
1612
; CHECK-NEXT: // %bb.0:
17-
; CHECK-NEXT: ld.param.b16 %rs1, [bswap16_param_0];
18-
; CHECK-NEXT: shr.u16 %rs2, %rs1, 8;
19-
; CHECK-NEXT: shl.b16 %rs3, %rs1, 8;
20-
; CHECK-NEXT: or.b16 %rs4, %rs3, %rs2;
21-
; CHECK-NEXT: cvt.u32.u16 %r1, %rs4;
22-
; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
13+
; CHECK-NEXT: ld.param.b16 %r1, [bswap16_param_0];
14+
; CHECK-NEXT: prmt.b32 %r2, %r1, 0, 0x7701U;
15+
; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
2316
; CHECK-NEXT: ret;
2417
%b = tail call i16 @llvm.bswap.i16(i16 %a)
2518
ret i16 %b
@@ -56,40 +49,39 @@ define <2 x i16> @bswapv2i16(<2 x i16> %a) #0 {
5649
}
5750

5851
define i64 @bswap64(i64 %a) {
59-
; PTX70-LABEL: bswap64(
60-
; PTX70: {
61-
; PTX70-NEXT: .reg .b32 %r<5>;
62-
; PTX70-NEXT: .reg .b64 %rd<3>;
63-
; PTX70-EMPTY:
64-
; PTX70-NEXT: // %bb.0:
65-
; PTX70-NEXT: ld.param.b64 %rd1, [bswap64_param_0];
66-
; PTX70-NEXT: { .reg .b32 tmp; mov.b64 {%r1, tmp}, %rd1; }
67-
; PTX70-NEXT: prmt.b32 %r2, %r1, 0, 0x123U;
68-
; PTX70-NEXT: { .reg .b32 tmp; mov.b64 {tmp, %r3}, %rd1; }
69-
; PTX70-NEXT: prmt.b32 %r4, %r3, 0, 0x123U;
70-
; PTX70-NEXT: mov.b64 %rd2, {%r4, %r2};
71-
; PTX70-NEXT: st.param.b64 [func_retval0], %rd2;
72-
; PTX70-NEXT: ret;
73-
;
74-
; PTX71-LABEL: bswap64(
75-
; PTX71: {
76-
; PTX71-NEXT: .reg .b32 %r<5>;
77-
; PTX71-NEXT: .reg .b64 %rd<3>;
78-
; PTX71-EMPTY:
79-
; PTX71-NEXT: // %bb.0:
80-
; PTX71-NEXT: ld.param.b64 %rd1, [bswap64_param_0];
81-
; PTX71-NEXT: mov.b64 {%r1, _}, %rd1;
82-
; PTX71-NEXT: prmt.b32 %r2, %r1, 0, 0x123U;
83-
; PTX71-NEXT: mov.b64 {_, %r3}, %rd1;
84-
; PTX71-NEXT: prmt.b32 %r4, %r3, 0, 0x123U;
85-
; PTX71-NEXT: mov.b64 %rd2, {%r4, %r2};
86-
; PTX71-NEXT: st.param.b64 [func_retval0], %rd2;
87-
; PTX71-NEXT: ret;
52+
; CHECK-LABEL: bswap64(
53+
; CHECK: {
54+
; CHECK-NEXT: .reg .b32 %r<5>;
55+
; CHECK-NEXT: .reg .b64 %rd<3>;
56+
; CHECK-EMPTY:
57+
; CHECK-NEXT: // %bb.0:
58+
; CHECK-NEXT: ld.param.b64 %rd1, [bswap64_param_0];
59+
; CHECK-NEXT: mov.b64 {%r1, %r2}, %rd1;
60+
; CHECK-NEXT: prmt.b32 %r3, %r1, 0, 0x123U;
61+
; CHECK-NEXT: prmt.b32 %r4, %r2, 0, 0x123U;
62+
; CHECK-NEXT: mov.b64 %rd2, {%r4, %r3};
63+
; CHECK-NEXT: st.param.b64 [func_retval0], %rd2;
64+
; CHECK-NEXT: ret;
8865
%b = tail call i64 @llvm.bswap.i64(i64 %a)
8966
ret i64 %b
9067
}
9168

69+
define <2 x i32> @bswapv2i32(<2 x i32> %a) {
70+
; CHECK-LABEL: bswapv2i32(
71+
; CHECK: {
72+
; CHECK-NEXT: .reg .b32 %r<5>;
73+
; CHECK-EMPTY:
74+
; CHECK-NEXT: // %bb.0:
75+
; CHECK-NEXT: ld.param.v2.b32 {%r1, %r2}, [bswapv2i32_param_0];
76+
; CHECK-NEXT: prmt.b32 %r3, %r2, 0, 0x123U;
77+
; CHECK-NEXT: prmt.b32 %r4, %r1, 0, 0x123U;
78+
; CHECK-NEXT: st.param.v2.b32 [func_retval0], {%r4, %r3};
79+
; CHECK-NEXT: ret;
80+
%b = tail call <2 x i32> @llvm.bswap.v2i32(<2 x i32> %a)
81+
ret <2 x i32> %b
82+
}
9283
declare i16 @llvm.bswap.i16(i16)
9384
declare i32 @llvm.bswap.i32(i32)
9485
declare <2 x i16> @llvm.bswap.v2i16(<2 x i16>)
9586
declare i64 @llvm.bswap.i64(i64)
87+
declare <2 x i32> @llvm.bswap.v2i32(<2 x i32>)

0 commit comments

Comments
 (0)