@@ -2068,6 +2068,8 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
20682068static SDValue getPRMT (SDValue A, SDValue B, SDValue Selector, SDLoc DL,
20692069 SelectionDAG &DAG,
20702070 unsigned Mode = NVPTX::PTXPrmtMode::NONE) {
2071+ assert (A.getValueType () == MVT::i32 && B.getValueType () == MVT::i32 &&
2072+ Selector.getValueType () == MVT::i32 && " PRMT must have i32 operands" );
20712073 return DAG.getNode (NVPTXISD::PRMT, DL, MVT::i32 ,
20722074 {A, B, Selector, DAG.getConstant (Mode, DL, MVT::i32 )});
20732075}
@@ -5872,6 +5874,8 @@ static SDValue combineADDRSPACECAST(SDNode *N,
58725874// details:
58735875// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt
58745876static APInt getPRMTSelector (const APInt &Selector, unsigned Mode) {
5877+ assert (Selector.getBitWidth () == 32 && " PRMT must have i32 operands" );
5878+
58755879 if (Mode == NVPTX::PTXPrmtMode::NONE)
58765880 return Selector;
58775881
@@ -5903,6 +5907,8 @@ static APInt getPRMTSelector(const APInt &Selector, unsigned Mode) {
59035907}
59045908
59055909static APInt computePRMT (APInt A, APInt B, APInt Selector, unsigned Mode) {
5910+ assert (A.getBitWidth () == 32 && B.getBitWidth () == 32 &&
5911+ Selector.getBitWidth () == 32 && " PRMT must have i32 operands" );
59065912 // {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}
59075913 APInt BitField = B.concat (A);
59085914 APInt SelectorVal = getPRMTSelector (Selector, Mode);
@@ -6537,10 +6543,13 @@ static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known,
65376543 KnownBits BKnown = DAG.computeKnownBits (B, Depth);
65386544
65396545 // {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}
6546+ assert (AKnown.getBitWidth () == 32 && BKnown.getBitWidth () == 32 &&
6547+ " PRMT must have i32 operands" );
6548+ assert (Known.getBitWidth () == 32 && " PRMT must have i32 result" );
65406549 KnownBits BitField = BKnown.concat (AKnown);
65416550
65426551 APInt SelectorVal = getPRMTSelector (Selector->getAPIntValue (), Mode);
6543- for (unsigned I : llvm::seq (std::min ( 4U , Known. getBitWidth () / 8 ) )) {
6552+ for (unsigned I : llvm::seq (4 )) {
65446553 APInt Sel = SelectorVal.extractBits (4 , I * 4 );
65456554 unsigned Idx = Sel.getLoBits (3 ).getZExtValue ();
65466555 unsigned Sign = Sel.getHiBits (1 ).getZExtValue ();
0 commit comments