14
14
#include " NVPTXISelLowering.h"
15
15
#include " MCTargetDesc/NVPTXBaseInfo.h"
16
16
#include " NVPTX.h"
17
+ #include " NVPTXISelDAGToDAG.h"
17
18
#include " NVPTXSubtarget.h"
18
19
#include " NVPTXTargetMachine.h"
19
20
#include " NVPTXTargetObjectFile.h"
@@ -5242,76 +5243,6 @@ static SDValue PerformFADDCombine(SDNode *N,
5242
5243
return PerformFADDCombineWithOperands (N, N1, N0, DCI, OptLevel);
5243
5244
}
5244
5245
5245
- static SDValue PerformANDCombine (SDNode *N,
5246
- TargetLowering::DAGCombinerInfo &DCI) {
5247
- // The type legalizer turns a vector load of i8 values into a zextload to i16
5248
- // registers, optionally ANY_EXTENDs it (if target type is integer),
5249
- // and ANDs off the high 8 bits. Since we turn this load into a
5250
- // target-specific DAG node, the DAG combiner fails to eliminate these AND
5251
- // nodes. Do that here.
5252
- SDValue Val = N->getOperand (0 );
5253
- SDValue Mask = N->getOperand (1 );
5254
-
5255
- if (isa<ConstantSDNode>(Val)) {
5256
- std::swap (Val, Mask);
5257
- }
5258
-
5259
- SDValue AExt;
5260
-
5261
- // Generally, we will see zextload -> IMOV16rr -> ANY_EXTEND -> and
5262
- if (Val.getOpcode () == ISD::ANY_EXTEND) {
5263
- AExt = Val;
5264
- Val = Val->getOperand (0 );
5265
- }
5266
-
5267
- if (Val->getOpcode () == NVPTXISD::LoadV2 ||
5268
- Val->getOpcode () == NVPTXISD::LoadV4) {
5269
- ConstantSDNode *MaskCnst = dyn_cast<ConstantSDNode>(Mask);
5270
- if (!MaskCnst) {
5271
- // Not an AND with a constant
5272
- return SDValue ();
5273
- }
5274
-
5275
- uint64_t MaskVal = MaskCnst->getZExtValue ();
5276
- if (MaskVal != 0xff ) {
5277
- // Not an AND that chops off top 8 bits
5278
- return SDValue ();
5279
- }
5280
-
5281
- MemSDNode *Mem = dyn_cast<MemSDNode>(Val);
5282
- if (!Mem) {
5283
- // Not a MemSDNode?!?
5284
- return SDValue ();
5285
- }
5286
-
5287
- EVT MemVT = Mem->getMemoryVT ();
5288
- if (MemVT != MVT::v2i8 && MemVT != MVT::v4i8) {
5289
- // We only handle the i8 case
5290
- return SDValue ();
5291
- }
5292
-
5293
- unsigned ExtType = Val->getConstantOperandVal (Val->getNumOperands () - 1 );
5294
- if (ExtType == ISD::SEXTLOAD) {
5295
- // If for some reason the load is a sextload, the and is needed to zero
5296
- // out the high 8 bits
5297
- return SDValue ();
5298
- }
5299
-
5300
- bool AddTo = false ;
5301
- if (AExt.getNode () != nullptr ) {
5302
- // Re-insert the ext as a zext.
5303
- Val = DCI.DAG .getNode (ISD::ZERO_EXTEND, SDLoc (N),
5304
- AExt.getValueType (), Val);
5305
- AddTo = true ;
5306
- }
5307
-
5308
- // If we get here, the AND is unnecessary. Just replace it with the load
5309
- DCI.CombineTo (N, Val, AddTo);
5310
- }
5311
-
5312
- return SDValue ();
5313
- }
5314
-
5315
5246
static SDValue PerformREMCombine (SDNode *N,
5316
5247
TargetLowering::DAGCombinerInfo &DCI,
5317
5248
CodeGenOptLevel OptLevel) {
@@ -5983,8 +5914,6 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
5983
5914
return PerformADDCombine (N, DCI, OptLevel);
5984
5915
case ISD::ADDRSPACECAST:
5985
5916
return combineADDRSPACECAST (N, DCI);
5986
- case ISD::AND:
5987
- return PerformANDCombine (N, DCI);
5988
5917
case ISD::SIGN_EXTEND:
5989
5918
case ISD::ZERO_EXTEND:
5990
5919
return combineMulWide (N, DCI, OptLevel);
@@ -6609,6 +6538,24 @@ static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known,
6609
6538
}
6610
6539
}
6611
6540
6541
+ static void computeKnownBitsForLoadV (const SDValue Op, KnownBits &Known) {
6542
+ MemSDNode *LD = cast<MemSDNode>(Op);
6543
+
6544
+ // We can't do anything without knowing the sign bit.
6545
+ auto ExtType = LD->getConstantOperandVal (LD->getNumOperands () - 1 );
6546
+ if (ExtType == ISD::SEXTLOAD)
6547
+ return ;
6548
+
6549
+ // ExtLoading to vector types is weird and may not work well with known bits.
6550
+ auto DestVT = LD->getValueType (0 );
6551
+ if (DestVT.isVector ())
6552
+ return ;
6553
+
6554
+ assert (Known.getBitWidth () == DestVT.getSizeInBits ());
6555
+ auto ElementBitWidth = NVPTXDAGToDAGISel::getFromTypeWidthForLoad (LD);
6556
+ Known.Zero .setHighBits (Known.getBitWidth () - ElementBitWidth);
6557
+ }
6558
+
6612
6559
void NVPTXTargetLowering::computeKnownBitsForTargetNode (
6613
6560
const SDValue Op, KnownBits &Known, const APInt &DemandedElts,
6614
6561
const SelectionDAG &DAG, unsigned Depth) const {
@@ -6618,6 +6565,11 @@ void NVPTXTargetLowering::computeKnownBitsForTargetNode(
6618
6565
case NVPTXISD::PRMT:
6619
6566
computeKnownBitsForPRMT (Op, Known, DAG, Depth);
6620
6567
break ;
6568
+ case NVPTXISD::LoadV2:
6569
+ case NVPTXISD::LoadV4:
6570
+ case NVPTXISD::LoadV8:
6571
+ computeKnownBitsForLoadV (Op, Known);
6572
+ break ;
6621
6573
default :
6622
6574
break ;
6623
6575
}
0 commit comments