@@ -22743,68 +22743,49 @@ SDValue DAGCombiner::scalarizeExtractedVectorLoad(SDNode *EVE, EVT InVecVT,
2274322743
2274422744/// Transform a vector binary operation into a scalar binary operation by moving
2274522745/// the math/logic after an extract element of a vector.
22746- static bool scalarizeExtractedBinOpCommon(SDNode *ExtElt, SelectionDAG &DAG,
22747- const SDLoc &DL, bool IsSetCC,
22748- SDValue &ScalarOp1,
22749- SDValue &ScalarOp2) {
22746+ static SDValue scalarizeExtractedBinOp(SDNode *ExtElt, SelectionDAG &DAG,
22747+ const SDLoc &DL) {
22748+ const TargetLowering &TLI = DAG.getTargetLoweringInfo();
2275022749 SDValue Vec = ExtElt->getOperand(0);
2275122750 SDValue Index = ExtElt->getOperand(1);
2275222751 auto *IndexC = dyn_cast<ConstantSDNode>(Index);
22753- if (!IndexC || !Vec.hasOneUse() || Vec->getNumValues() != 1)
22754- return false;
22752+ if (!IndexC ||
22753+ (!TLI.isBinOp(Vec.getOpcode()) && Vec.getOpcode() != ISD::SETCC) ||
22754+ !Vec.hasOneUse() || Vec->getNumValues() != 1)
22755+ return SDValue();
22756+
22757+ EVT ResVT = ExtElt->getValueType(0);
22758+ if (Vec.getOpcode() == ISD::SETCC &&
22759+ ResVT != Vec.getValueType().getVectorElementType())
22760+ return SDValue();
22761+
22762+ // Targets may want to avoid this to prevent an expensive register transfer.
22763+ if (!TLI.shouldScalarizeBinop(Vec))
22764+ return SDValue();
2275522765
2275622766 // Extracting an element of a vector constant is constant-folded, so this
2275722767 // transform is just replacing a vector op with a scalar op while moving the
2275822768 // extract.
2275922769 SDValue Op0 = Vec.getOperand(0);
2276022770 SDValue Op1 = Vec.getOperand(1);
2276122771 APInt SplatVal;
22762- if (isAnyConstantBuildVector(Op0, true) ||
22763- ISD::isConstantSplatVector(Op0.getNode(), SplatVal) ||
22764- isAnyConstantBuildVector(Op1, true) ||
22765- ISD::isConstantSplatVector(Op1.getNode(), SplatVal)) {
22766- // extractelt (binop X, C), IndexC --> binop (extractelt X, IndexC), C'
22767- // extractelt (binop C, X), IndexC --> binop C', (extractelt X, IndexC)
22768- // extractelt (setcc X, C, op), IndexC -> setcc (extractelt X, IndexC)), C
22769- // extractelt (setcc C, X, op), IndexC -> setcc (extractelt IndexC, X)), C
22770- EVT VT = Op0->getValueType(0).getVectorElementType();
22771- ScalarOp1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op0, Index);
22772- ScalarOp2 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, VT, Op1, Index);
22773- return true;
22774- }
22775-
22776- return false;
22777- }
22778-
22779- static SDValue scalarizeExtractedBinOp(SDNode *ExtElt, SelectionDAG &DAG,
22780- const SDLoc &DL) {
22781- SDValue Op1, Op2;
22782- SDValue Vec = ExtElt->getOperand(0);
22783- const TargetLowering &TLI = DAG.getTargetLoweringInfo();
22784- if (!TLI.isBinOp(Vec.getOpcode()) || !TLI.shouldScalarizeBinop(Vec))
22785- return SDValue();
22786-
22787- if (!scalarizeExtractedBinOpCommon(ExtElt, DAG, DL, false, Op1, Op2))
22772+ if (!isAnyConstantBuildVector(Op0, true) &&
22773+ !ISD::isConstantSplatVector(Op0.getNode(), SplatVal) &&
22774+ !isAnyConstantBuildVector(Op1, true) &&
22775+ !ISD::isConstantSplatVector(Op1.getNode(), SplatVal))
2278822776 return SDValue();
2278922777
22790- EVT VT = ExtElt->getValueType(0);
22791- return DAG.getNode(Vec.getOpcode(), DL, VT, Op1, Op2);
22792- }
22778+ // extractelt (op X, C), IndexC --> op (extractelt X, IndexC), C'
22779+ // extractelt (op C, X), IndexC --> op C', (extractelt X, IndexC)
22780+ EVT OpVT = Op0->getValueType(0).getVectorElementType();
22781+ Op0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, OpVT, Op0, Index);
22782+ Op1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, OpVT, Op1, Index);
2279322783
22794- static SDValue scalarizeExtractedSetCC(SDNode *ExtElt, SelectionDAG &DAG,
22795- const SDLoc &DL) {
22796- SDValue Op1, Op2;
22797- SDValue Vec = ExtElt->getOperand(0);
22798- const TargetLowering &TLI = DAG.getTargetLoweringInfo();
22799- if (Vec.getOpcode() != ISD::SETCC || !TLI.shouldScalarizeSetCC(Vec))
22800- return SDValue();
22801-
22802- if (!scalarizeExtractedBinOpCommon(ExtElt, DAG, DL, true, Op1, Op2))
22803- return SDValue();
22804-
22805- EVT VT = ExtElt->getValueType(0);
22806- return DAG.getSetCC(DL, VT, Op1, Op2,
22807- cast<CondCodeSDNode>(Vec->getOperand(2))->get());
22784+ if (Vec.getOpcode() == ISD::SETCC)
22785+ return DAG.getSetCC(DL, ResVT, Op0, Op1,
22786+ cast<CondCodeSDNode>(Vec->getOperand(2))->get());
22787+ else
22788+ return DAG.getNode(Vec.getOpcode(), DL, ResVT, Op0, Op1);
2280822789}
2280922790
2281022791// Given a ISD::EXTRACT_VECTOR_ELT, which is a glorified bit sequence extract,
@@ -23040,11 +23021,6 @@ SDValue DAGCombiner::visitEXTRACT_VECTOR_ELT(SDNode *N) {
2304023021 if (SDValue BO = scalarizeExtractedBinOp(N, DAG, DL))
2304123022 return BO;
2304223023
23043- // extract (setcc x, splat(y)), i -> setcc (extract x, i)), y
23044- if (ScalarVT == VecVT.getVectorElementType())
23045- if (SDValue SetCC = scalarizeExtractedSetCC(N, DAG, DL))
23046- return SetCC;
23047-
2304823024 if (VecVT.isScalableVector())
2304923025 return SDValue();
2305023026
0 commit comments