@@ -6538,11 +6538,10 @@ void NVPTXTargetLowering::computeKnownBitsForTargetNode(
65386538 }
65396539}
65406540
6541- static void getPRMTDemandedBits (const APInt &SelectorVal,
6542- const APInt &DemandedBits, APInt &DemandedLHS,
6543- APInt &DemandedRHS) {
6544- DemandedLHS = APInt (32 , 0 );
6545- DemandedRHS = APInt (32 , 0 );
6541+ static std::pair<APInt, APInt> getPRMTDemandedBits (const APInt &SelectorVal,
6542+ const APInt &DemandedBits) {
6543+ APInt DemandedLHS = APInt (32 , 0 );
6544+ APInt DemandedRHS = APInt (32 , 0 );
65466545
65476546 for (unsigned I : llvm::seq (4 )) {
65486547 if (DemandedBits.extractBits (8 , I * 8 ).isZero ())
@@ -6559,6 +6558,8 @@ static void getPRMTDemandedBits(const APInt &SelectorVal,
65596558 else
65606559 Src.setBits (ByteStart, ByteStart + 8 );
65616560 }
6561+
6562+ return {DemandedLHS, DemandedRHS};
65626563}
65636564
65646565// Replace undef with 0 as this is easier for other optimizations such as
@@ -6576,26 +6577,26 @@ static SDValue simplifyDemandedBitsForPRMT(SDValue PRMT,
65766577 SelectionDAG &DAG,
65776578 const TargetLowering &TLI,
65786579 unsigned Depth) {
6580+ assert (PRMT.getOpcode () == NVPTXISD::PRMT);
65796581 SDValue Op0 = PRMT.getOperand (0 );
65806582 SDValue Op1 = PRMT.getOperand (1 );
6581- ConstantSDNode *Selector = dyn_cast<ConstantSDNode>(PRMT.getOperand (2 ));
6582- unsigned Mode = PRMT.getConstantOperandVal (3 );
6583- if (!Selector)
6583+ auto *SelectorConst = dyn_cast<ConstantSDNode>(PRMT.getOperand (2 ));
6584+ if (!SelectorConst)
65846585 return SDValue ();
65856586
6586- const APInt SelectorVal = getPRMTSelector (Selector->getAPIntValue (), Mode);
6587+ unsigned Mode = PRMT.getConstantOperandVal (3 );
6588+ const APInt Selector = getPRMTSelector (SelectorConst->getAPIntValue (), Mode);
65876589
65886590 // Try to simplify the PRMT to one of the inputs if the used bytes are all
65896591 // from the same input in the correct order.
65906592 const unsigned LeadingBytes = DemandedBits.countLeadingZeros () / 8 ;
65916593 const unsigned SelBits = (4 - LeadingBytes) * 4 ;
6592- if (SelectorVal .getLoBits (SelBits) == APInt (32 , 0x3210 ).getLoBits (SelBits))
6594+ if (Selector .getLoBits (SelBits) == APInt (32 , 0x3210 ).getLoBits (SelBits))
65936595 return Op0;
6594- if (SelectorVal .getLoBits (SelBits) == APInt (32 , 0x7654 ).getLoBits (SelBits))
6596+ if (Selector .getLoBits (SelBits) == APInt (32 , 0x7654 ).getLoBits (SelBits))
65956597 return Op1;
65966598
6597- APInt DemandedLHS, DemandedRHS;
6598- getPRMTDemandedBits (SelectorVal, DemandedBits, DemandedLHS, DemandedRHS);
6599+ auto [DemandedLHS, DemandedRHS] = getPRMTDemandedBits (Selector, DemandedBits);
65996600
66006601 // Attempt to avoid multi-use ops if we don't need anything from them.
66016602 SDValue DemandedOp0 =
@@ -6605,10 +6606,11 @@ static SDValue simplifyDemandedBitsForPRMT(SDValue PRMT,
66056606
66066607 DemandedOp0 = canonicalizePRMTInput (DemandedOp0, DAG);
66076608 DemandedOp1 = canonicalizePRMTInput (DemandedOp1, DAG);
6608- if (DemandedOp0 != Op0 || DemandedOp1 != Op1) {
6609+ if ((DemandedOp0 && DemandedOp0 != Op0) ||
6610+ (DemandedOp1 && DemandedOp1 != Op1)) {
66096611 Op0 = DemandedOp0 ? DemandedOp0 : Op0;
66106612 Op1 = DemandedOp1 ? DemandedOp1 : Op1;
6611- return getPRMT (Op0, Op1, SelectorVal .getZExtValue (), SDLoc (PRMT), DAG);
6613+ return getPRMT (Op0, Op1, Selector .getZExtValue (), SDLoc (PRMT), DAG);
66126614 }
66136615
66146616 return SDValue ();
0 commit comments