@@ -580,6 +580,7 @@ namespace {
580580 SDValue reassociateReduction(unsigned RedOpc, unsigned Opc, const SDLoc &DL,
581581 EVT VT, SDValue N0, SDValue N1,
582582 SDNodeFlags Flags = SDNodeFlags());
583+ SDValue foldReductionWithUndefLane(SDNode *N);
583584
584585 SDValue visitShiftByConstant(SDNode *N);
585586
@@ -1347,6 +1348,75 @@ SDValue DAGCombiner::reassociateReduction(unsigned RedOpc, unsigned Opc,
13471348 return SDValue();
13481349}
13491350
1351+ // Convert:
1352+ // (op.x2 (vector_shuffle<i,u> A), B) -> <(op A:i, B:0) undef>
1353+ // ...or...
1354+ // (op.x2 (vector_shuffle<u,i> A), B) -> <undef (op A:i, B:1)>
1355+ // ...where i is a valid index and u is poison.
1356+ SDValue DAGCombiner::foldReductionWithUndefLane(SDNode *N) {
1357+ const EVT VectorVT = N->getValueType(0);
1358+
1359+ // Only support 2-packed vectors for now.
1360+ if (!VectorVT.isVector() || VectorVT.isScalableVector()
1361+ || VectorVT.getVectorNumElements() != 2)
1362+ return SDValue();
1363+
1364+ // If the operation is already unsupported, we don't need to do this
1365+ // operation.
1366+ if (!TLI.isOperationLegal(N->getOpcode(), VectorVT))
1367+ return SDValue();
1368+
1369+ // If vector shuffle is supported on the target, this optimization may
1370+ // increase register pressure.
1371+ if (TLI.isOperationLegalOrCustomOrPromote(ISD::VECTOR_SHUFFLE, VectorVT))
1372+ return SDValue();
1373+
1374+ SDLoc DL(N);
1375+
1376+ SDValue ShufOp = N->getOperand(0);
1377+ SDValue VectOp = N->getOperand(1);
1378+ bool Swapped = false;
1379+
1380+ // canonicalize shuffle op
1381+ if (VectOp.getOpcode() == ISD::VECTOR_SHUFFLE) {
1382+ std::swap(ShufOp, VectOp);
1383+ Swapped = true;
1384+ }
1385+
1386+ if (ShufOp.getOpcode() != ISD::VECTOR_SHUFFLE)
1387+ return SDValue();
1388+
1389+ auto *ShuffleOp = cast<ShuffleVectorSDNode>(ShufOp);
1390+ int LiveLane; // exclusively live lane
1391+ for (LiveLane = 0; LiveLane < 2; ++LiveLane) {
1392+ // check if the current lane is live and the other lane is dead
1393+ if (ShuffleOp->getMaskElt(LiveLane) != PoisonMaskElem &&
1394+ ShuffleOp->getMaskElt(!LiveLane) == PoisonMaskElem)
1395+ break;
1396+ }
1397+ if (LiveLane == 2)
1398+ return SDValue();
1399+
1400+ const int ElementIdx = ShuffleOp->getMaskElt(LiveLane);
1401+ const EVT ScalarVT = VectorVT.getScalarType();
1402+ SDValue Lanes[2] = {};
1403+ for (auto [LaneID, LaneVal] : enumerate(Lanes)) {
1404+ if (LaneID == (unsigned)LiveLane) {
1405+ SDValue Operands[2] = {
1406+ DAG.getExtractVectorElt(DL, ScalarVT, ShufOp.getOperand(0),
1407+ ElementIdx),
1408+ DAG.getExtractVectorElt(DL, ScalarVT, VectOp, LiveLane)};
1409+ // preserve the order of operands
1410+ if (Swapped)
1411+ std::swap(Operands[0], Operands[1]);
1412+ LaneVal = DAG.getNode(N->getOpcode(), DL, ScalarVT, Operands);
1413+ } else {
1414+ LaneVal = DAG.getUNDEF(ScalarVT);
1415+ }
1416+ }
1417+ return DAG.getBuildVector(VectorVT, DL, Lanes);
1418+ }
1419+
13501420SDValue DAGCombiner::CombineTo(SDNode *N, const SDValue *To, unsigned NumTo,
13511421 bool AddTo) {
13521422 assert(N->getNumValues() == NumTo && "Broken CombineTo call!");
@@ -3056,6 +3126,9 @@ SDValue DAGCombiner::visitADD(SDNode *N) {
30563126 return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), SV);
30573127 }
30583128
3129+ if (SDValue R = foldReductionWithUndefLane(N))
3130+ return R;
3131+
30593132 return SDValue();
30603133}
30613134
@@ -5999,6 +6072,9 @@ SDValue DAGCombiner::visitIMINMAX(SDNode *N) {
59996072 SDLoc(N), VT, N0, N1))
60006073 return SD;
60016074
6075+ if (SDValue SD = foldReductionWithUndefLane(N))
6076+ return SD;
6077+
60026078 // Simplify the operands using demanded-bits information.
60036079 if (SimplifyDemandedBits(SDValue(N, 0)))
60046080 return SDValue(N, 0);
@@ -7267,6 +7343,9 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
72677343 }
72687344 }
72697345 }
7346+
7347+ if (SDValue R = foldReductionWithUndefLane(N))
7348+ return R;
72707349 }
72717350
72727351 // fold (and x, -1) -> x
@@ -8242,6 +8321,9 @@ SDValue DAGCombiner::visitOR(SDNode *N) {
82428321 }
82438322 }
82448323 }
8324+
8325+ if (SDValue R = foldReductionWithUndefLane(N))
8326+ return R;
82458327 }
82468328
82478329 // fold (or x, 0) -> x
@@ -9923,6 +10005,9 @@ SDValue DAGCombiner::visitXOR(SDNode *N) {
992310005 if (SDValue Combined = combineCarryDiamond(DAG, TLI, N0, N1, N))
992410006 return Combined;
992510007
10008+ if (SDValue R = foldReductionWithUndefLane(N))
10009+ return R;
10010+
992610011 return SDValue();
992710012}
992810013
@@ -17529,6 +17614,10 @@ SDValue DAGCombiner::visitFADD(SDNode *N) {
1752917614 AddToWorklist(Fused.getNode());
1753017615 return Fused;
1753117616 }
17617+
17618+ if (SDValue R = foldReductionWithUndefLane(N))
17619+ return R;
17620+
1753217621 return SDValue();
1753317622}
1753417623
@@ -17897,6 +17986,9 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) {
1789717986 if (SDValue R = combineFMulOrFDivWithIntPow2(N))
1789817987 return R;
1789917988
17989+ if (SDValue R = foldReductionWithUndefLane(N))
17990+ return R;
17991+
1790017992 return SDValue();
1790117993}
1790217994
@@ -19002,6 +19094,9 @@ SDValue DAGCombiner::visitFMinMax(SDNode *N) {
1900219094 Opc, SDLoc(N), VT, N0, N1, Flags))
1900319095 return SD;
1900419096
19097+ if (SDValue SD = foldReductionWithUndefLane(N))
19098+ return SD;
19099+
1900519100 return SDValue();
1900619101}
1900719102
0 commit comments