Skip to content

Commit 10812eb

Browse files
authored
[NVPTX] Assert PRMT operands are of correct type (NFC) (llvm#150104)
1 parent b13bca7 commit 10812eb

File tree

1 file changed

+10
-1
lines changed

1 file changed

+10
-1
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2068,6 +2068,8 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
20682068
static SDValue getPRMT(SDValue A, SDValue B, SDValue Selector, SDLoc DL,
20692069
SelectionDAG &DAG,
20702070
unsigned Mode = NVPTX::PTXPrmtMode::NONE) {
2071+
assert(A.getValueType() == MVT::i32 && B.getValueType() == MVT::i32 &&
2072+
Selector.getValueType() == MVT::i32 && "PRMT must have i32 operands");
20712073
return DAG.getNode(NVPTXISD::PRMT, DL, MVT::i32,
20722074
{A, B, Selector, DAG.getConstant(Mode, DL, MVT::i32)});
20732075
}
@@ -5872,6 +5874,8 @@ static SDValue combineADDRSPACECAST(SDNode *N,
58725874
// details:
58735875
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt
58745876
static APInt getPRMTSelector(const APInt &Selector, unsigned Mode) {
5877+
assert(Selector.getBitWidth() == 32 && "PRMT must have i32 operands");
5878+
58755879
if (Mode == NVPTX::PTXPrmtMode::NONE)
58765880
return Selector;
58775881

@@ -5903,6 +5907,8 @@ static APInt getPRMTSelector(const APInt &Selector, unsigned Mode) {
59035907
}
59045908

59055909
static APInt computePRMT(APInt A, APInt B, APInt Selector, unsigned Mode) {
5910+
assert(A.getBitWidth() == 32 && B.getBitWidth() == 32 &&
5911+
Selector.getBitWidth() == 32 && "PRMT must have i32 operands");
59065912
// {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}
59075913
APInt BitField = B.concat(A);
59085914
APInt SelectorVal = getPRMTSelector(Selector, Mode);
@@ -6537,10 +6543,13 @@ static void computeKnownBitsForPRMT(const SDValue Op, KnownBits &Known,
65376543
KnownBits BKnown = DAG.computeKnownBits(B, Depth);
65386544

65396545
// {b, a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}
6546+
assert(AKnown.getBitWidth() == 32 && BKnown.getBitWidth() == 32 &&
6547+
"PRMT must have i32 operands");
6548+
assert(Known.getBitWidth() == 32 && "PRMT must have i32 result");
65406549
KnownBits BitField = BKnown.concat(AKnown);
65416550

65426551
APInt SelectorVal = getPRMTSelector(Selector->getAPIntValue(), Mode);
6543-
for (unsigned I : llvm::seq(std::min(4U, Known.getBitWidth() / 8))) {
6552+
for (unsigned I : llvm::seq(4)) {
65446553
APInt Sel = SelectorVal.extractBits(4, I * 4);
65456554
unsigned Idx = Sel.getLoBits(3).getZExtValue();
65466555
unsigned Sign = Sel.getHiBits(1).getZExtValue();

0 commit comments

Comments
 (0)