Skip to content

Commit 1d2bbcf

Browse files
[NVPTX] Eliminate prmts that result from BUILD_VECTOR of LoadV2
1 parent 22c519a commit 1d2bbcf

File tree

2 files changed

+93
-41
lines changed

2 files changed

+93
-41
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5772,7 +5772,8 @@ static SDValue PerformVSELECTCombine(SDNode *N,
57725772
}
57735773

57745774
static SDValue
5775-
PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
5775+
PerformBUILD_VECTOROfV2i16Combine(SDNode *N,
5776+
TargetLowering::DAGCombinerInfo &DCI) {
57765777
auto VT = N->getValueType(0);
57775778
if (!DCI.isAfterLegalizeDAG() ||
57785779
// only process v2*16 types
@@ -5833,6 +5834,80 @@ PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
58335834
return DAG.getBitcast(VT, PRMT);
58345835
}
58355836

5837+
/// Fold a v4i8 BUILD_VECTOR whose two low lanes are the two results of one
/// NVPTXISD::LoadV2 / NVPTXISD::LDUV2 of <2 x i8>, and whose two high lanes
/// are each zero or undef, into a single scalar i16 load that is zero-extended
/// to i32 and bitcast back to v4i8. This removes the prmt sequence otherwise
/// needed to pack the two bytes.
///
/// \param N   The BUILD_VECTOR node being combined.
/// \param DCI Combiner state; only its SelectionDAG is used here.
/// \returns The replacement v4i8 value, or an empty SDValue if the pattern
///          does not match.
static SDValue
PerformBUILD_VECTOROfTargetLoadCombine(SDNode *N,
                                       TargetLowering::DAGCombinerInfo &DCI) {
  if (N->getValueType(0) != MVT::v4i8)
    return SDValue();

  // Operands: [0]=lo byte, [1]=hi byte. Both must be produced by the same
  // vector load node.
  SDValue Op0 = N->getOperand(0);
  SDValue Op1 = N->getOperand(1);
  if (Op0.getNode() != Op1.getNode())
    return SDValue();
  // The lanes must also appear in load order. Without this check,
  // BUILD_VECTOR(ld:1, ld:0, ...) would be folded as if the bytes were in
  // memory order, silently swapping them.
  if (Op0.getResNo() != 0 || Op1.getResNo() != 1)
    return SDValue();

  const unsigned LoadOpc = Op0.getOpcode();
  if (LoadOpc != NVPTXISD::LoadV2 && LoadOpc != NVPTXISD::LDUV2)
    return SDValue();
  if (Op0.getValueType() != MVT::i16)
    return SDValue();

  // Only a two-byte access may be narrowed to a scalar i16 load. A load of
  // wider elements that also produces i16 results would be truncated by the
  // BUILD_VECTOR and must not be rewritten.
  auto *Load = cast<MemSDNode>(Op0.getNode());
  if (Load->getMemoryVT() != MVT::v2i8)
    return SDValue();

  // The old load is abandoned, so no one else may be using its values.
  if (!Op0.hasOneUse() || !Op1.hasOneUse())
    return SDValue();

  // Operands [2] and [3] must each be zero or undef. Undef lanes may be
  // refined to zero, so the two lanes are checked independently.
  auto IsZeroOrUndef = [](SDValue V) {
    if (V.isUndef())
      return true;
    auto *C = dyn_cast<ConstantSDNode>(V);
    return C && C->isZero();
  };
  if (!IsZeroOrUndef(N->getOperand(2)) || !IsZeroOrUndef(N->getOperand(3)))
    return SDValue();

  // Replace with: zext(load i16) -> i32, then bitcast to v4i8.
  auto &DAG = DCI.DAG;
  SDLoc DL(Load);
  SDValue LoadI16;
  if (LoadOpc == NVPTXISD::LoadV2) {
    // Rebuild the access as a plain i16 load at the same address.
    LoadI16 = DAG.getLoad(MVT::i16, DL, Load->getChain(), Load->getBasePtr(),
                          Load->getPointerInfo(), Load->getAlign(),
                          Load->getMemOperand()->getFlags());
  } else {
    assert(LoadOpc == NVPTXISD::LDUV2 && "unexpected load opcode");
    // Re-emit the ldu as an i16 intrinsic call: chain, intrinsic id, then the
    // remaining operands of the original node (everything after the chain).
    const TargetLowering &TLI = DAG.getTargetLoweringInfo();
    SmallVector<SDValue, 4> Ops;
    Ops.push_back(Load->getChain());
    Ops.push_back(DAG.getConstant(Intrinsic::nvvm_ldu_global_i, DL,
                                  TLI.getPointerTy(DAG.getDataLayout())));
    for (unsigned I = 1, E = Load->getNumOperands(); I != E; ++I)
      Ops.push_back(Load->getOperand(I));
    SDVTList NodeVTList = DAG.getVTList(MVT::i16, MVT::Other);
    LoadI16 = DAG.getMemIntrinsicNode(ISD::INTRINSIC_W_CHAIN, DL, NodeVTList,
                                      Ops, MVT::i16, Load->getPointerInfo(),
                                      Load->getAlign());
  }

  // Result #2 of the two-element load is its output chain; reroute its users
  // to the chain of the new scalar load.
  DAG.ReplaceAllUsesOfValueWith(SDValue(Load, 2), LoadI16.getValue(1));
  SDValue Zext = DAG.getZExtOrTrunc(LoadI16, DL, MVT::i32);
  return DAG.getBitcast(MVT::v4i8, Zext);
}
5901+
5902+
/// Entry point for BUILD_VECTOR combines: tries the v2x16 combine first, then
/// the target-load combine, and returns the first non-empty result (or an
/// empty SDValue if neither applies).
static SDValue
PerformBUILD_VECTORCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
  SDValue Combined = PerformBUILD_VECTOROfV2i16Combine(N, DCI);
  if (!Combined)
    Combined = PerformBUILD_VECTOROfTargetLoadCombine(N, DCI);
  return Combined;
}
5910+
58365911
static SDValue combineADDRSPACECAST(SDNode *N,
58375912
TargetLowering::DAGCombinerInfo &DCI) {
58385913
auto *ASCN1 = cast<AddrSpaceCastSDNode>(N);

llvm/test/CodeGen/NVPTX/build-vector-combine.ll

Lines changed: 17 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,13 @@ target triple = "nvptx64-nvidia-cuda"
88
define void @t1() {
99
; CHECK-LABEL: t1(
1010
; CHECK: {
11-
; CHECK-NEXT: .reg .b16 %rs<3>;
12-
; CHECK-NEXT: .reg .b32 %r<5>;
11+
; CHECK-NEXT: .reg .b32 %r<2>;
1312
; CHECK-NEXT: .reg .b64 %rd<2>;
1413
; CHECK-EMPTY:
1514
; CHECK-NEXT: // %bb.0: // %entry
1615
; CHECK-NEXT: mov.b64 %rd1, 0;
17-
; CHECK-NEXT: ld.global.v2.b8 {%rs1, %rs2}, [%rd1];
18-
; CHECK-NEXT: cvt.u32.u16 %r1, %rs2;
19-
; CHECK-NEXT: cvt.u32.u16 %r2, %rs1;
20-
; CHECK-NEXT: prmt.b32 %r3, %r2, %r1, 0x3340U;
21-
; CHECK-NEXT: prmt.b32 %r4, %r3, 0, 0x5410U;
22-
; CHECK-NEXT: st.global.v4.b32 [%rd1], {%r4, 0, 0, 0};
16+
; CHECK-NEXT: ld.global.b16 %r1, [%rd1];
17+
; CHECK-NEXT: st.global.v4.b32 [%rd1], {%r1, 0, 0, 0};
2318
; CHECK-NEXT: ret;
2419
entry:
2520
%0 = load <2 x i8>, ptr addrspace(1) null, align 4
@@ -33,18 +28,13 @@ entry:
3328
define void @t2() {
3429
; CHECK-LABEL: t2(
3530
; CHECK: {
36-
; CHECK-NEXT: .reg .b16 %rs<3>;
37-
; CHECK-NEXT: .reg .b32 %r<5>;
31+
; CHECK-NEXT: .reg .b32 %r<2>;
3832
; CHECK-NEXT: .reg .b64 %rd<2>;
3933
; CHECK-EMPTY:
4034
; CHECK-NEXT: // %bb.0: // %entry
4135
; CHECK-NEXT: mov.b64 %rd1, 0;
42-
; CHECK-NEXT: ld.global.v2.b8 {%rs1, %rs2}, [%rd1];
43-
; CHECK-NEXT: cvt.u32.u16 %r1, %rs2;
44-
; CHECK-NEXT: cvt.u32.u16 %r2, %rs1;
45-
; CHECK-NEXT: prmt.b32 %r3, %r2, %r1, 0x3340U;
46-
; CHECK-NEXT: prmt.b32 %r4, %r3, 0, 0x5410U;
47-
; CHECK-NEXT: st.local.b32 [%rd1], %r4;
36+
; CHECK-NEXT: ld.global.b16 %r1, [%rd1];
37+
; CHECK-NEXT: st.local.b32 [%rd1], %r1;
4838
; CHECK-NEXT: ret;
4939
entry:
5040
%0 = load <2 x i8>, ptr addrspace(1) null, align 8
@@ -58,19 +48,14 @@ declare <2 x i8> @llvm.nvvm.ldg.global.i.v2i8.p1(ptr addrspace(1) %ptr, i32 %ali
5848
define void @ldg(ptr addrspace(1) %ptr) {
5949
; CHECK-LABEL: ldg(
6050
; CHECK: {
61-
; CHECK-NEXT: .reg .b16 %rs<3>;
62-
; CHECK-NEXT: .reg .b32 %r<5>;
51+
; CHECK-NEXT: .reg .b32 %r<2>;
6352
; CHECK-NEXT: .reg .b64 %rd<3>;
6453
; CHECK-EMPTY:
6554
; CHECK-NEXT: // %bb.0: // %entry
6655
; CHECK-NEXT: ld.param.b64 %rd1, [ldg_param_0];
67-
; CHECK-NEXT: ld.global.v2.b8 {%rs1, %rs2}, [%rd1];
68-
; CHECK-NEXT: cvt.u32.u16 %r1, %rs2;
69-
; CHECK-NEXT: cvt.u32.u16 %r2, %rs1;
70-
; CHECK-NEXT: prmt.b32 %r3, %r2, %r1, 0x3340U;
71-
; CHECK-NEXT: prmt.b32 %r4, %r3, 0, 0x5410U;
56+
; CHECK-NEXT: ld.global.b16 %r1, [%rd1];
7257
; CHECK-NEXT: mov.b64 %rd2, 0;
73-
; CHECK-NEXT: st.local.b32 [%rd2], %r4;
58+
; CHECK-NEXT: st.local.b32 [%rd2], %r1;
7459
; CHECK-NEXT: ret;
7560
entry:
7661
%0 = tail call <2 x i8> @llvm.nvvm.ldg.global.i.v2i8.p1(ptr addrspace(1) %ptr, i32 2)
@@ -84,19 +69,16 @@ declare <2 x i8> @llvm.nvvm.ldu.global.f.v2i8.p1(ptr addrspace(1) %ptr, i32 %ali
8469
define void @ldu(ptr addrspace(1) %ptr) {
8570
; CHECK-LABEL: ldu(
8671
; CHECK: {
87-
; CHECK-NEXT: .reg .b16 %rs<3>;
88-
; CHECK-NEXT: .reg .b32 %r<5>;
72+
; CHECK-NEXT: .reg .b16 %rs<2>;
73+
; CHECK-NEXT: .reg .b32 %r<2>;
8974
; CHECK-NEXT: .reg .b64 %rd<3>;
9075
; CHECK-EMPTY:
9176
; CHECK-NEXT: // %bb.0: // %entry
9277
; CHECK-NEXT: ld.param.b64 %rd1, [ldu_param_0];
93-
; CHECK-NEXT: ldu.global.v2.b8 {%rs1, %rs2}, [%rd1];
94-
; CHECK-NEXT: cvt.u32.u16 %r1, %rs2;
95-
; CHECK-NEXT: cvt.u32.u16 %r2, %rs1;
96-
; CHECK-NEXT: prmt.b32 %r3, %r2, %r1, 0x3340U;
97-
; CHECK-NEXT: prmt.b32 %r4, %r3, 0, 0x5410U;
78+
; CHECK-NEXT: ldu.global.b16 %rs1, [%rd1];
79+
; CHECK-NEXT: cvt.u32.u16 %r1, %rs1;
9880
; CHECK-NEXT: mov.b64 %rd2, 0;
99-
; CHECK-NEXT: st.local.b32 [%rd2], %r4;
81+
; CHECK-NEXT: st.local.b32 [%rd2], %r1;
10082
; CHECK-NEXT: ret;
10183
entry:
10284
%0 = tail call <2 x i8> @llvm.nvvm.ldu.global.i.v2i8.p1(ptr addrspace(1) %ptr, i32 2)
@@ -108,18 +90,13 @@ entry:
10890
define void @t3() {
10991
; CHECK-LABEL: t3(
11092
; CHECK: {
111-
; CHECK-NEXT: .reg .b16 %rs<3>;
112-
; CHECK-NEXT: .reg .b32 %r<5>;
93+
; CHECK-NEXT: .reg .b32 %r<2>;
11394
; CHECK-NEXT: .reg .b64 %rd<2>;
11495
; CHECK-EMPTY:
11596
; CHECK-NEXT: // %bb.0:
11697
; CHECK-NEXT: mov.b64 %rd1, 0;
117-
; CHECK-NEXT: ld.global.v2.b8 {%rs1, %rs2}, [%rd1];
118-
; CHECK-NEXT: cvt.u32.u16 %r1, %rs2;
119-
; CHECK-NEXT: cvt.u32.u16 %r2, %rs1;
120-
; CHECK-NEXT: prmt.b32 %r3, %r2, %r1, 0x3340U;
121-
; CHECK-NEXT: prmt.b32 %r4, %r3, 0, 0x5410U;
122-
; CHECK-NEXT: st.global.v2.b32 [%rd1], {%r4, 0};
98+
; CHECK-NEXT: ld.global.b16 %r1, [%rd1];
99+
; CHECK-NEXT: st.global.v2.b32 [%rd1], {%r1, 0};
123100
; CHECK-NEXT: ret;
124101
%1 = load <2 x i8>, ptr addrspace(1) null, align 2
125102
%insval2 = bitcast <2 x i8> %1 to i16

0 commit comments

Comments
 (0)