Skip to content

Commit 7d8f15e

Browse files
committed
address comments
1 parent 2cdc799 commit 7d8f15e

File tree

1 file changed

+17
-15
lines changed

1 file changed

+17
-15
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6538,11 +6538,10 @@ void NVPTXTargetLowering::computeKnownBitsForTargetNode(
65386538
}
65396539
}
65406540

6541-
static void getPRMTDemandedBits(const APInt &SelectorVal,
6542-
const APInt &DemandedBits, APInt &DemandedLHS,
6543-
APInt &DemandedRHS) {
6544-
DemandedLHS = APInt(32, 0);
6545-
DemandedRHS = APInt(32, 0);
6541+
static std::pair<APInt, APInt> getPRMTDemandedBits(const APInt &SelectorVal,
6542+
const APInt &DemandedBits) {
6543+
APInt DemandedLHS = APInt(32, 0);
6544+
APInt DemandedRHS = APInt(32, 0);
65466545

65476546
for (unsigned I : llvm::seq(4)) {
65486547
if (DemandedBits.extractBits(8, I * 8).isZero())
@@ -6559,6 +6558,8 @@ static void getPRMTDemandedBits(const APInt &SelectorVal,
65596558
else
65606559
Src.setBits(ByteStart, ByteStart + 8);
65616560
}
6561+
6562+
return {DemandedLHS, DemandedRHS};
65626563
}
65636564

65646565
// Replace undef with 0 as this is easier for other optimizations such as
@@ -6576,26 +6577,26 @@ static SDValue simplifyDemandedBitsForPRMT(SDValue PRMT,
65766577
SelectionDAG &DAG,
65776578
const TargetLowering &TLI,
65786579
unsigned Depth) {
6580+
assert(PRMT.getOpcode() == NVPTXISD::PRMT);
65796581
SDValue Op0 = PRMT.getOperand(0);
65806582
SDValue Op1 = PRMT.getOperand(1);
6581-
ConstantSDNode *Selector = dyn_cast<ConstantSDNode>(PRMT.getOperand(2));
6582-
unsigned Mode = PRMT.getConstantOperandVal(3);
6583-
if (!Selector)
6583+
auto *SelectorConst = dyn_cast<ConstantSDNode>(PRMT.getOperand(2));
6584+
if (!SelectorConst)
65846585
return SDValue();
65856586

6586-
const APInt SelectorVal = getPRMTSelector(Selector->getAPIntValue(), Mode);
6587+
unsigned Mode = PRMT.getConstantOperandVal(3);
6588+
const APInt Selector = getPRMTSelector(SelectorConst->getAPIntValue(), Mode);
65876589

65886590
// Try to simplify the PRMT to one of the inputs if the used bytes are all
65896591
// from the same input in the correct order.
65906592
const unsigned LeadingBytes = DemandedBits.countLeadingZeros() / 8;
65916593
const unsigned SelBits = (4 - LeadingBytes) * 4;
6592-
if (SelectorVal.getLoBits(SelBits) == APInt(32, 0x3210).getLoBits(SelBits))
6594+
if (Selector.getLoBits(SelBits) == APInt(32, 0x3210).getLoBits(SelBits))
65936595
return Op0;
6594-
if (SelectorVal.getLoBits(SelBits) == APInt(32, 0x7654).getLoBits(SelBits))
6596+
if (Selector.getLoBits(SelBits) == APInt(32, 0x7654).getLoBits(SelBits))
65956597
return Op1;
65966598

6597-
APInt DemandedLHS, DemandedRHS;
6598-
getPRMTDemandedBits(SelectorVal, DemandedBits, DemandedLHS, DemandedRHS);
6599+
auto [DemandedLHS, DemandedRHS] = getPRMTDemandedBits(Selector, DemandedBits);
65996600

66006601
// Attempt to avoid multi-use ops if we don't need anything from them.
66016602
SDValue DemandedOp0 =
@@ -6605,10 +6606,11 @@ static SDValue simplifyDemandedBitsForPRMT(SDValue PRMT,
66056606

66066607
DemandedOp0 = canonicalizePRMTInput(DemandedOp0, DAG);
66076608
DemandedOp1 = canonicalizePRMTInput(DemandedOp1, DAG);
6608-
if (DemandedOp0 != Op0 || DemandedOp1 != Op1) {
6609+
if ((DemandedOp0 && DemandedOp0 != Op0) ||
6610+
(DemandedOp1 && DemandedOp1 != Op1)) {
66096611
Op0 = DemandedOp0 ? DemandedOp0 : Op0;
66106612
Op1 = DemandedOp1 ? DemandedOp1 : Op1;
6611-
return getPRMT(Op0, Op1, SelectorVal.getZExtValue(), SDLoc(PRMT), DAG);
6613+
return getPRMT(Op0, Op1, Selector.getZExtValue(), SDLoc(PRMT), DAG);
66126614
}
66136615

66146616
return SDValue();

0 commit comments

Comments
 (0)