@@ -2068,6 +2068,8 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
20682068static SDValue getPRMT (SDValue A, SDValue B, SDValue Selector, SDLoc DL,
20692069 SelectionDAG &DAG,
20702070 unsigned Mode = NVPTX::PTXPrmtMode::NONE) {
2071+ assert (A.getValueType () == MVT::i32 && B.getValueType () == MVT::i32 &&
2072+ Selector.getValueType () == MVT::i32 && " PRMT must have i32 operands" );
20712073 return DAG.getNode (NVPTXISD::PRMT, DL, MVT::i32 ,
20722074 {A, B, Selector, DAG.getConstant (Mode, DL, MVT::i32 )});
20732075}
@@ -5845,6 +5847,8 @@ static SDValue combineADDRSPACECAST(SDNode *N,
58455847// details:
58465848// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt
58475849static APInt getPRMTSelector (const APInt &Selector, unsigned Mode) {
5850+ assert (Selector.getBitWidth () == 32 && " PRMT must have i32 operands" );
5851+
58485852 if (Mode == NVPTX::PTXPrmtMode::NONE)
58495853 return Selector;
58505854
@@ -5876,6 +5880,8 @@ static APInt getPRMTSelector(const APInt &Selector, unsigned Mode) {
58765880}
58775881
58785882static APInt computePRMT (APInt A, APInt B, APInt Selector, unsigned Mode) {
5883+ assert (A.getBitWidth () == 32 && B.getBitWidth () == 32 &&
5884+ Selector.getBitWidth () == 32 && " PRMT must have i32 operands" );
58795885 // {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}
58805886 APInt BitField = B.concat (A);
58815887 APInt SelectorVal = getPRMTSelector (Selector, Mode);
@@ -6510,10 +6516,13 @@ static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known,
65106516 KnownBits BKnown = DAG.computeKnownBits (B, Depth);
65116517
65126518 // {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}
6519+ assert (AKnown.getBitWidth () == 32 && BKnown.getBitWidth () == 32 &&
6520+ " PRMT must have i32 operands" );
6521+ assert (Known.getBitWidth () == 32 && " PRMT must have i32 result" );
65136522 KnownBits BitField = BKnown.concat (AKnown);
65146523
65156524 APInt SelectorVal = getPRMTSelector (Selector->getAPIntValue (), Mode);
6516- for (unsigned I : llvm::seq (std::min ( 4U , Known. getBitWidth () / 8 ) )) {
6525+ for (unsigned I : llvm::seq (4 )) {
65176526 APInt Sel = SelectorVal.extractBits (4 , I * 4 );
65186527 unsigned Idx = Sel.getLoBits (3 ).getZExtValue ();
65196528 unsigned Sign = Sel.getHiBits (1 ).getZExtValue ();
0 commit comments