@@ -2068,6 +2068,8 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
2068
2068
/// Build an NVPTXISD::PRMT node from the two source operands \p A and \p B,
/// the byte \p Selector, and the permute \p Mode.
///
/// \param A        First source operand; must be of type i32.
/// \param B        Second source operand; must be of type i32.
/// \param Selector Selector operand; must be of type i32.
/// \param DL       Debug location attached to the created nodes.
/// \param DAG      SelectionDAG used to create the nodes.
/// \param Mode     Permute mode, emitted as an i32 constant operand. Defaults
///                 to NVPTX::PTXPrmtMode::NONE.
/// \returns The newly created PRMT node, which has result type i32.
static SDValue getPRMT(SDValue A, SDValue B, SDValue Selector, SDLoc DL,
                       SelectionDAG &DAG,
                       unsigned Mode = NVPTX::PTXPrmtMode::NONE) {
  // The node is always created with an i32 result, so reject operands of any
  // other type up front instead of building a malformed node.
  assert(A.getValueType() == MVT::i32 && B.getValueType() == MVT::i32 &&
         Selector.getValueType() == MVT::i32 && " PRMT must have i32 operands");
  return DAG.getNode(NVPTXISD::PRMT, DL, MVT::i32,
                     {A, B, Selector, DAG.getConstant(Mode, DL, MVT::i32)});
}
@@ -5872,6 +5874,8 @@ static SDValue combineADDRSPACECAST(SDNode *N,
5872
5874
// details:
5873
5875
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt
5874
5876
static APInt getPRMTSelector (const APInt &Selector, unsigned Mode) {
5877
+ assert (Selector.getBitWidth () == 32 && " PRMT must have i32 operands" );
5878
+
5875
5879
if (Mode == NVPTX::PTXPrmtMode::NONE)
5876
5880
return Selector;
5877
5881
@@ -5903,6 +5907,8 @@ static APInt getPRMTSelector(const APInt &Selector, unsigned Mode) {
5903
5907
}
5904
5908
5905
5909
static APInt computePRMT (APInt A, APInt B, APInt Selector, unsigned Mode) {
5910
+ assert (A.getBitWidth () == 32 && B.getBitWidth () == 32 &&
5911
+ Selector.getBitWidth () == 32 && " PRMT must have i32 operands" );
5906
5912
// {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}
5907
5913
APInt BitField = B.concat (A);
5908
5914
APInt SelectorVal = getPRMTSelector (Selector, Mode);
@@ -6537,10 +6543,13 @@ static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known,
6537
6543
KnownBits BKnown = DAG.computeKnownBits (B, Depth);
6538
6544
6539
6545
// {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}
6546
+ assert (AKnown.getBitWidth () == 32 && BKnown.getBitWidth () == 32 &&
6547
+ " PRMT must have i32 operands" );
6548
+ assert (Known.getBitWidth () == 32 && " PRMT must have i32 result" );
6540
6549
KnownBits BitField = BKnown.concat (AKnown);
6541
6550
6542
6551
APInt SelectorVal = getPRMTSelector (Selector->getAPIntValue (), Mode);
6543
- for (unsigned I : llvm::seq (std::min ( 4U , Known. getBitWidth () / 8 ) )) {
6552
+ for (unsigned I : llvm::seq (4 )) {
6544
6553
APInt Sel = SelectorVal.extractBits (4 , I * 4 );
6545
6554
unsigned Idx = Sel.getLoBits (3 ).getZExtValue ();
6546
6555
unsigned Sign = Sel.getHiBits (1 ).getZExtValue ();
0 commit comments