@@ -582,6 +582,7 @@ namespace {
582582 SDValue reassociateReduction(unsigned RedOpc, unsigned Opc, const SDLoc &DL,
583583 EVT VT, SDValue N0, SDValue N1,
584584 SDNodeFlags Flags = SDNodeFlags());
585+ SDValue foldReductionWithUndefLane(SDNode *N);
585586
586587 SDValue visitShiftByConstant(SDNode *N);
587588
@@ -1349,6 +1350,75 @@ SDValue DAGCombiner::reassociateReduction(unsigned RedOpc, unsigned Opc,
13491350 return SDValue();
13501351}
13511352
1353+ // Convert:
1354+ // (op.x2 (vector_shuffle<i,u> A), B) -> <(op A:i, B:0) undef>
1355+ // ...or...
1356+ // (op.x2 (vector_shuffle<u,i> A), B) -> <undef (op A:i, B:1)>
1357+ // ...where i is a valid index and u is poison.
1358+ SDValue DAGCombiner::foldReductionWithUndefLane(SDNode *N) {
1359+ const EVT VectorVT = N->getValueType(0);
1360+
1361+ // Only support 2-packed vectors for now.
1362+ if (!VectorVT.isVector() || VectorVT.isScalableVector()
1363+ || VectorVT.getVectorNumElements() != 2)
1364+ return SDValue();
1365+
1366+ // If the operation is already unsupported, we don't need to do this
1367+ // operation.
1368+ if (!TLI.isOperationLegal(N->getOpcode(), VectorVT))
1369+ return SDValue();
1370+
1371+ // If vector shuffle is supported on the target, this optimization may
1372+ // increase register pressure.
1373+ if (TLI.isOperationLegalOrCustomOrPromote(ISD::VECTOR_SHUFFLE, VectorVT))
1374+ return SDValue();
1375+
1376+ SDLoc DL(N);
1377+
1378+ SDValue ShufOp = N->getOperand(0);
1379+ SDValue VectOp = N->getOperand(1);
1380+ bool Swapped = false;
1381+
1382+ // canonicalize shuffle op
1383+ if (VectOp.getOpcode() == ISD::VECTOR_SHUFFLE) {
1384+ std::swap(ShufOp, VectOp);
1385+ Swapped = true;
1386+ }
1387+
1388+ if (ShufOp.getOpcode() != ISD::VECTOR_SHUFFLE)
1389+ return SDValue();
1390+
1391+ auto *ShuffleOp = cast<ShuffleVectorSDNode>(ShufOp);
1392+ int LiveLane; // exclusively live lane
1393+ for (LiveLane = 0; LiveLane < 2; ++LiveLane) {
1394+ // check if the current lane is live and the other lane is dead
1395+ if (ShuffleOp->getMaskElt(LiveLane) != PoisonMaskElem &&
1396+ ShuffleOp->getMaskElt(!LiveLane) == PoisonMaskElem)
1397+ break;
1398+ }
1399+ if (LiveLane == 2)
1400+ return SDValue();
1401+
1402+ const int ElementIdx = ShuffleOp->getMaskElt(LiveLane);
1403+ const EVT ScalarVT = VectorVT.getScalarType();
1404+ SDValue Lanes[2] = {};
1405+ for (auto [LaneID, LaneVal] : enumerate(Lanes)) {
1406+ if (LaneID == (unsigned)LiveLane) {
1407+ SDValue Operands[2] = {
1408+ DAG.getExtractVectorElt(DL, ScalarVT, ShufOp.getOperand(0),
1409+ ElementIdx),
1410+ DAG.getExtractVectorElt(DL, ScalarVT, VectOp, LiveLane)};
1411+ // preserve the order of operands
1412+ if (Swapped)
1413+ std::swap(Operands[0], Operands[1]);
1414+ LaneVal = DAG.getNode(N->getOpcode(), DL, ScalarVT, Operands);
1415+ } else {
1416+ LaneVal = DAG.getUNDEF(ScalarVT);
1417+ }
1418+ }
1419+ return DAG.getBuildVector(VectorVT, DL, Lanes);
1420+ }
1421+
13521422SDValue DAGCombiner::CombineTo(SDNode *N, const SDValue *To, unsigned NumTo,
13531423 bool AddTo) {
13541424 assert(N->getNumValues() == NumTo && "Broken CombineTo call!");
@@ -3058,6 +3128,9 @@ SDValue DAGCombiner::visitADD(SDNode *N) {
30583128 return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), SV);
30593129 }
30603130
3131+ if (SDValue R = foldReductionWithUndefLane(N))
3132+ return R;
3133+
30613134 return SDValue();
30623135}
30633136
@@ -6001,6 +6074,9 @@ SDValue DAGCombiner::visitIMINMAX(SDNode *N) {
60016074 SDLoc(N), VT, N0, N1))
60026075 return SD;
60036076
6077+ if (SDValue SD = foldReductionWithUndefLane(N))
6078+ return SD;
6079+
60046080 // Simplify the operands using demanded-bits information.
60056081 if (SimplifyDemandedBits(SDValue(N, 0)))
60066082 return SDValue(N, 0);
@@ -7301,6 +7377,9 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
73017377 }
73027378 }
73037379 }
7380+
7381+ if (SDValue R = foldReductionWithUndefLane(N))
7382+ return R;
73047383 }
73057384
73067385 // fold (and x, -1) -> x
@@ -8260,6 +8339,9 @@ SDValue DAGCombiner::visitOR(SDNode *N) {
82608339 }
82618340 }
82628341 }
8342+
8343+ if (SDValue R = foldReductionWithUndefLane(N))
8344+ return R;
82638345 }
82648346
82658347 // fold (or x, 0) -> x
@@ -9941,6 +10023,9 @@ SDValue DAGCombiner::visitXOR(SDNode *N) {
994110023 if (SDValue Combined = combineCarryDiamond(DAG, TLI, N0, N1, N))
994210024 return Combined;
994310025
10026+ if (SDValue R = foldReductionWithUndefLane(N))
10027+ return R;
10028+
994410029 return SDValue();
994510030}
994610031
@@ -17557,6 +17642,10 @@ SDValue DAGCombiner::visitFADD(SDNode *N) {
1755717642 AddToWorklist(Fused.getNode());
1755817643 return Fused;
1755917644 }
17645+
17646+ if (SDValue R = foldReductionWithUndefLane(N))
17647+ return R;
17648+
1756017649 return SDValue();
1756117650}
1756217651
@@ -17925,6 +18014,9 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) {
1792518014 if (SDValue R = combineFMulOrFDivWithIntPow2(N))
1792618015 return R;
1792718016
18017+ if (SDValue R = foldReductionWithUndefLane(N))
18018+ return R;
18019+
1792818020 return SDValue();
1792918021}
1793018022
@@ -19030,6 +19122,9 @@ SDValue DAGCombiner::visitFMinMax(SDNode *N) {
1903019122 Opc, SDLoc(N), VT, N0, N1, Flags))
1903119123 return SD;
1903219124
19125+ if (SDValue SD = foldReductionWithUndefLane(N))
19126+ return SD;
19127+
1903319128 return SDValue();
1903419129}
1903519130
0 commit comments