Skip to content

Commit 691ccf2

Browse files
authored
[NVPTX] Implement computeKnownBitsForTargetNode for LoadV (llvm#154165)
Remove the AND combines, as they are no longer needed once known bits are computed for LoadV nodes.
1 parent 8d7b50e commit 691ccf2

File tree

5 files changed

+73
-111
lines changed

5 files changed

+73
-111
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 15 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -1150,15 +1150,12 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
11501150
return true;
11511151
}
11521152

1153-
static unsigned getLoadStoreVectorNumElts(SDNode *N) {
1153+
static unsigned getStoreVectorNumElts(SDNode *N) {
11541154
switch (N->getOpcode()) {
1155-
case NVPTXISD::LoadV2:
11561155
case NVPTXISD::StoreV2:
11571156
return 2;
1158-
case NVPTXISD::LoadV4:
11591157
case NVPTXISD::StoreV4:
11601158
return 4;
1161-
case NVPTXISD::LoadV8:
11621159
case NVPTXISD::StoreV8:
11631160
return 8;
11641161
default:
@@ -1171,7 +1168,6 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
11711168
const EVT MemEVT = LD->getMemoryVT();
11721169
if (!MemEVT.isSimple())
11731170
return false;
1174-
const MVT MemVT = MemEVT.getSimpleVT();
11751171

11761172
// Address Space Setting
11771173
const auto CodeAddrSpace = getAddrSpace(LD);
@@ -1191,18 +1187,15 @@ bool NVPTXDAGToDAGISel::tryLoadVector(SDNode *N) {
11911187
// Float : ISD::NON_EXTLOAD or ISD::EXTLOAD and the type is float
11921188
// Read at least 8 bits (predicates are stored as 8-bit values)
11931189
// The last operand holds the original LoadSDNode::getExtensionType() value
1194-
const unsigned TotalWidth = MemVT.getSizeInBits();
11951190
const unsigned ExtensionType =
11961191
N->getConstantOperandVal(N->getNumOperands() - 1);
11971192
const unsigned FromType = (ExtensionType == ISD::SEXTLOAD)
11981193
? NVPTX::PTXLdStInstCode::Signed
11991194
: NVPTX::PTXLdStInstCode::Untyped;
12001195

1201-
const unsigned FromTypeWidth = TotalWidth / getLoadStoreVectorNumElts(N);
1196+
const unsigned FromTypeWidth = getFromTypeWidthForLoad(LD);
12021197

12031198
assert(!(EltVT.isVector() && ExtensionType != ISD::NON_EXTLOAD));
1204-
assert(isPowerOf2_32(FromTypeWidth) && FromTypeWidth >= 8 &&
1205-
FromTypeWidth <= 128 && TotalWidth <= 256 && "Invalid width for load");
12061199

12071200
const auto [Base, Offset] = selectADDR(N->getOperand(1), CurDAG);
12081201
SDValue Ops[] = {getI32Imm(Ordering, DL),
@@ -1247,30 +1240,23 @@ bool NVPTXDAGToDAGISel::tryLDG(MemSDNode *LD) {
12471240
const EVT LoadedEVT = LD->getMemoryVT();
12481241
if (!LoadedEVT.isSimple())
12491242
return false;
1250-
const MVT LoadedVT = LoadedEVT.getSimpleVT();
12511243

12521244
SDLoc DL(LD);
12531245

1254-
const unsigned TotalWidth = LoadedVT.getSizeInBits();
12551246
unsigned ExtensionType;
1256-
unsigned NumElts;
12571247
if (const auto *Load = dyn_cast<LoadSDNode>(LD)) {
12581248
ExtensionType = Load->getExtensionType();
1259-
NumElts = 1;
12601249
} else {
12611250
ExtensionType = LD->getConstantOperandVal(LD->getNumOperands() - 1);
1262-
NumElts = getLoadStoreVectorNumElts(LD);
12631251
}
12641252
const unsigned FromType = (ExtensionType == ISD::SEXTLOAD)
12651253
? NVPTX::PTXLdStInstCode::Signed
12661254
: NVPTX::PTXLdStInstCode::Untyped;
12671255

1268-
const unsigned FromTypeWidth = TotalWidth / NumElts;
1256+
const unsigned FromTypeWidth = getFromTypeWidthForLoad(LD);
12691257

12701258
assert(!(LD->getSimpleValueType(0).isVector() &&
12711259
ExtensionType != ISD::NON_EXTLOAD));
1272-
assert(isPowerOf2_32(FromTypeWidth) && FromTypeWidth >= 8 &&
1273-
FromTypeWidth <= 128 && TotalWidth <= 256 && "Invalid width for load");
12741260

12751261
const auto [Base, Offset] = selectADDR(LD->getOperand(1), CurDAG);
12761262
SDValue Ops[] = {getI32Imm(FromType, DL), getI32Imm(FromTypeWidth, DL), Base,
@@ -1309,26 +1295,21 @@ bool NVPTXDAGToDAGISel::tryLDG(MemSDNode *LD) {
13091295
return true;
13101296
}
13111297

1298+
unsigned NVPTXDAGToDAGISel::getFromTypeWidthForLoad(const MemSDNode *Mem) {
1299+
auto TotalWidth = Mem->getMemoryVT().getSizeInBits();
1300+
auto NumElts = Mem->getNumValues() - 1;
1301+
auto ElementBitWidth = TotalWidth / NumElts;
1302+
assert(isPowerOf2_32(ElementBitWidth) && ElementBitWidth >= 8 &&
1303+
ElementBitWidth <= 128 && TotalWidth <= 256 &&
1304+
"Invalid width for load");
1305+
return ElementBitWidth;
1306+
}
1307+
13121308
bool NVPTXDAGToDAGISel::tryLDU(SDNode *N) {
13131309
auto *LD = cast<MemSDNode>(N);
13141310

1315-
unsigned NumElts;
1316-
switch (N->getOpcode()) {
1317-
default:
1318-
llvm_unreachable("Unexpected opcode");
1319-
case ISD::INTRINSIC_W_CHAIN:
1320-
NumElts = 1;
1321-
break;
1322-
case NVPTXISD::LDUV2:
1323-
NumElts = 2;
1324-
break;
1325-
case NVPTXISD::LDUV4:
1326-
NumElts = 4;
1327-
break;
1328-
}
1329-
13301311
SDLoc DL(N);
1331-
const unsigned FromTypeWidth = LD->getMemoryVT().getSizeInBits() / NumElts;
1312+
const unsigned FromTypeWidth = getFromTypeWidthForLoad(LD);
13321313
const MVT::SimpleValueType TargetVT = LD->getSimpleValueType(0).SimpleTy;
13331314

13341315
// If this is an LDU intrinsic, the address is the third operand. If it's an
@@ -1443,7 +1424,7 @@ bool NVPTXDAGToDAGISel::tryStoreVector(SDNode *N) {
14431424
// - for integer type, always use 'u'
14441425
const unsigned TotalWidth = StoreVT.getSimpleVT().getSizeInBits();
14451426

1446-
const unsigned NumElts = getLoadStoreVectorNumElts(ST);
1427+
const unsigned NumElts = getStoreVectorNumElts(ST);
14471428

14481429
SmallVector<SDValue, 16> Ops;
14491430
for (auto &V : ST->ops().slice(1, NumElts))

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
111111

112112
public:
113113
static NVPTX::AddressSpace getAddrSpace(const MemSDNode *N);
114+
static unsigned getFromTypeWidthForLoad(const MemSDNode *Mem);
114115
};
115116

116117
class NVPTXDAGToDAGISelLegacy : public SelectionDAGISelLegacy {

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 24 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "NVPTXISelLowering.h"
1515
#include "MCTargetDesc/NVPTXBaseInfo.h"
1616
#include "NVPTX.h"
17+
#include "NVPTXISelDAGToDAG.h"
1718
#include "NVPTXSubtarget.h"
1819
#include "NVPTXTargetMachine.h"
1920
#include "NVPTXTargetObjectFile.h"
@@ -5242,76 +5243,6 @@ static SDValue PerformFADDCombine(SDNode *N,
52425243
return PerformFADDCombineWithOperands(N, N1, N0, DCI, OptLevel);
52435244
}
52445245

5245-
static SDValue PerformANDCombine(SDNode *N,
5246-
TargetLowering::DAGCombinerInfo &DCI) {
5247-
// The type legalizer turns a vector load of i8 values into a zextload to i16
5248-
// registers, optionally ANY_EXTENDs it (if target type is integer),
5249-
// and ANDs off the high 8 bits. Since we turn this load into a
5250-
// target-specific DAG node, the DAG combiner fails to eliminate these AND
5251-
// nodes. Do that here.
5252-
SDValue Val = N->getOperand(0);
5253-
SDValue Mask = N->getOperand(1);
5254-
5255-
if (isa<ConstantSDNode>(Val)) {
5256-
std::swap(Val, Mask);
5257-
}
5258-
5259-
SDValue AExt;
5260-
5261-
// Generally, we will see zextload -> IMOV16rr -> ANY_EXTEND -> and
5262-
if (Val.getOpcode() == ISD::ANY_EXTEND) {
5263-
AExt = Val;
5264-
Val = Val->getOperand(0);
5265-
}
5266-
5267-
if (Val->getOpcode() == NVPTXISD::LoadV2 ||
5268-
Val->getOpcode() == NVPTXISD::LoadV4) {
5269-
ConstantSDNode *MaskCnst = dyn_cast<ConstantSDNode>(Mask);
5270-
if (!MaskCnst) {
5271-
// Not an AND with a constant
5272-
return SDValue();
5273-
}
5274-
5275-
uint64_t MaskVal = MaskCnst->getZExtValue();
5276-
if (MaskVal != 0xff) {
5277-
// Not an AND that chops off top 8 bits
5278-
return SDValue();
5279-
}
5280-
5281-
MemSDNode *Mem = dyn_cast<MemSDNode>(Val);
5282-
if (!Mem) {
5283-
// Not a MemSDNode?!?
5284-
return SDValue();
5285-
}
5286-
5287-
EVT MemVT = Mem->getMemoryVT();
5288-
if (MemVT != MVT::v2i8 && MemVT != MVT::v4i8) {
5289-
// We only handle the i8 case
5290-
return SDValue();
5291-
}
5292-
5293-
unsigned ExtType = Val->getConstantOperandVal(Val->getNumOperands() - 1);
5294-
if (ExtType == ISD::SEXTLOAD) {
5295-
// If for some reason the load is a sextload, the and is needed to zero
5296-
// out the high 8 bits
5297-
return SDValue();
5298-
}
5299-
5300-
bool AddTo = false;
5301-
if (AExt.getNode() != nullptr) {
5302-
// Re-insert the ext as a zext.
5303-
Val = DCI.DAG.getNode(ISD::ZERO_EXTEND, SDLoc(N),
5304-
AExt.getValueType(), Val);
5305-
AddTo = true;
5306-
}
5307-
5308-
// If we get here, the AND is unnecessary. Just replace it with the load
5309-
DCI.CombineTo(N, Val, AddTo);
5310-
}
5311-
5312-
return SDValue();
5313-
}
5314-
53155246
static SDValue PerformREMCombine(SDNode *N,
53165247
TargetLowering::DAGCombinerInfo &DCI,
53175248
CodeGenOptLevel OptLevel) {
@@ -5983,8 +5914,6 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
59835914
return PerformADDCombine(N, DCI, OptLevel);
59845915
case ISD::ADDRSPACECAST:
59855916
return combineADDRSPACECAST(N, DCI);
5986-
case ISD::AND:
5987-
return PerformANDCombine(N, DCI);
59885917
case ISD::SIGN_EXTEND:
59895918
case ISD::ZERO_EXTEND:
59905919
return combineMulWide(N, DCI, OptLevel);
@@ -6609,6 +6538,24 @@ static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known,
66096538
}
66106539
}
66116540

6541+
static void computeKnownBitsForLoadV(const SDValue Op, KnownBits &Known) {
6542+
MemSDNode *LD = cast<MemSDNode>(Op);
6543+
6544+
// We can't do anything without knowing the sign bit.
6545+
auto ExtType = LD->getConstantOperandVal(LD->getNumOperands() - 1);
6546+
if (ExtType == ISD::SEXTLOAD)
6547+
return;
6548+
6549+
// ExtLoading to vector types is weird and may not work well with known bits.
6550+
auto DestVT = LD->getValueType(0);
6551+
if (DestVT.isVector())
6552+
return;
6553+
6554+
assert(Known.getBitWidth() == DestVT.getSizeInBits());
6555+
auto ElementBitWidth = NVPTXDAGToDAGISel::getFromTypeWidthForLoad(LD);
6556+
Known.Zero.setHighBits(Known.getBitWidth() - ElementBitWidth);
6557+
}
6558+
66126559
void NVPTXTargetLowering::computeKnownBitsForTargetNode(
66136560
const SDValue Op, KnownBits &Known, const APInt &DemandedElts,
66146561
const SelectionDAG &DAG, unsigned Depth) const {
@@ -6618,6 +6565,11 @@ void NVPTXTargetLowering::computeKnownBitsForTargetNode(
66186565
case NVPTXISD::PRMT:
66196566
computeKnownBitsForPRMT(Op, Known, DAG, Depth);
66206567
break;
6568+
case NVPTXISD::LoadV2:
6569+
case NVPTXISD::LoadV4:
6570+
case NVPTXISD::LoadV8:
6571+
computeKnownBitsForLoadV(Op, Known);
6572+
break;
66216573
default:
66226574
break;
66236575
}

llvm/test/CodeGen/NVPTX/i8x2-instructions.ll

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,5 +103,34 @@ define <2 x i8> @test_call_2xi8(<2 x i8> %a) {
103103
%res = call <2 x i8> @test_call_2xi8(<2 x i8> %a)
104104
ret <2 x i8> %res
105105
}
106+
107+
define <2 x float> @test_uitofp_2xi8(<2 x i8> %a) {
108+
; O0-LABEL: test_uitofp_2xi8(
109+
; O0: {
110+
; O0-NEXT: .reg .b16 %rs<3>;
111+
; O0-NEXT: .reg .b32 %r<4>;
112+
; O0-EMPTY:
113+
; O0-NEXT: // %bb.0:
114+
; O0-NEXT: ld.param.v2.b8 {%rs1, %rs2}, [test_uitofp_2xi8_param_0];
115+
; O0-NEXT: mov.b32 %r1, {%rs1, %rs2};
116+
; O0-NEXT: cvt.rn.f32.u16 %r2, %rs2;
117+
; O0-NEXT: cvt.rn.f32.u16 %r3, %rs1;
118+
; O0-NEXT: st.param.v2.b32 [func_retval0], {%r3, %r2};
119+
; O0-NEXT: ret;
120+
;
121+
; O3-LABEL: test_uitofp_2xi8(
122+
; O3: {
123+
; O3-NEXT: .reg .b16 %rs<3>;
124+
; O3-NEXT: .reg .b32 %r<3>;
125+
; O3-EMPTY:
126+
; O3-NEXT: // %bb.0:
127+
; O3-NEXT: ld.param.v2.b8 {%rs1, %rs2}, [test_uitofp_2xi8_param_0];
128+
; O3-NEXT: cvt.rn.f32.u16 %r1, %rs2;
129+
; O3-NEXT: cvt.rn.f32.u16 %r2, %rs1;
130+
; O3-NEXT: st.param.v2.b32 [func_retval0], {%r2, %r1};
131+
; O3-NEXT: ret;
132+
%1 = uitofp <2 x i8> %a to <2 x float>
133+
ret <2 x float> %1
134+
}
106135
;; NOTE: These prefixes are unused and the list is autogenerated. Do not add tests below this line:
107136
; COMMON: {{.*}}

llvm/test/CodeGen/NVPTX/shift-opt.ll

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,18 +71,17 @@ define <2 x i16> @test_vec(<2 x i16> %x, <2 x i8> %y) {
7171
; CHECK-LABEL: test_vec(
7272
; CHECK: {
7373
; CHECK-NEXT: .reg .b16 %rs<7>;
74-
; CHECK-NEXT: .reg .b32 %r<5>;
74+
; CHECK-NEXT: .reg .b32 %r<4>;
7575
; CHECK-EMPTY:
7676
; CHECK-NEXT: // %bb.0:
7777
; CHECK-NEXT: ld.param.v2.b16 {%rs1, %rs2}, [test_vec_param_0];
7878
; CHECK-NEXT: ld.param.v2.b8 {%rs3, %rs4}, [test_vec_param_1];
7979
; CHECK-NEXT: mov.b32 %r1, {%rs3, %rs4};
80-
; CHECK-NEXT: and.b32 %r2, %r1, 16711935;
8180
; CHECK-NEXT: shr.u16 %rs5, %rs2, 5;
8281
; CHECK-NEXT: shr.u16 %rs6, %rs1, 5;
83-
; CHECK-NEXT: mov.b32 %r3, {%rs6, %rs5};
84-
; CHECK-NEXT: or.b32 %r4, %r3, %r2;
85-
; CHECK-NEXT: st.param.b32 [func_retval0], %r4;
82+
; CHECK-NEXT: mov.b32 %r2, {%rs6, %rs5};
83+
; CHECK-NEXT: or.b32 %r3, %r2, %r1;
84+
; CHECK-NEXT: st.param.b32 [func_retval0], %r3;
8685
; CHECK-NEXT: ret;
8786
%ext = zext <2 x i8> %y to <2 x i16>
8887
%shl = shl <2 x i16> %ext, splat(i16 5)

0 commit comments

Comments
 (0)