@@ -18456,36 +18456,80 @@ static SDValue stripModuloOnShift(const TargetLowering &TLI, SDNode *N,
1845618456 return SDValue();
1845718457}
1845818458
18459- SDValue PPCTargetLowering::combineVectorSHL (SDNode *N,
18460- DAGCombinerInfo &DCI) const {
18459+ SDValue PPCTargetLowering::combineVectorShift (SDNode *N,
18460+ DAGCombinerInfo &DCI) const {
1846118461 EVT VT = N->getValueType(0);
1846218462 assert(VT.isVector() && "Vector type expected.");
1846318463
18464- SDValue N1 = N->getOperand(1);
18465- if (!Subtarget.hasP8Altivec() || N1.getOpcode() != ISD::BUILD_VECTOR ||
18466- !isOperationLegal(ISD::ADD, VT))
18464+ unsigned Opc = N->getOpcode();
18465+ assert((Opc == ISD::SHL || Opc == ISD::SRL || Opc == ISD::SRA) &&
18466+ "Unexpected opcode.");
18467+
18468+ if (!isOperationLegal(Opc, VT))
1846718469 return SDValue();
1846818470
18469- // For 64-bit there is no splat immediate so we want to catch shift by 1 here
18470- // before the BUILD_VECTOR is replaced by a load.
1847118471 EVT EltTy = VT.getScalarType();
18472- if (EltTy != MVT::i64)
18472+ unsigned EltBits = EltTy.getSizeInBits();
18473+ if (EltTy != MVT::i64 && EltTy != MVT::i32)
1847318474 return SDValue();
1847418475
18475- BuildVectorSDNode *BVN = cast<BuildVectorSDNode>(N1);
18476- APInt APSplatBits, APSplatUndef;
18477- unsigned SplatBitSize;
18478- bool HasAnyUndefs;
18479- bool BVNIsConstantSplat =
18480- BVN->isConstantSplat(APSplatBits, APSplatUndef, SplatBitSize,
18481- HasAnyUndefs, 0, !Subtarget.isLittleEndian());
18482- if (!BVNIsConstantSplat || SplatBitSize != EltTy.getSizeInBits())
18476+ SDValue N1 = N->getOperand(1);
18477+ uint64_t SplatBits = 0;
18478+ bool AddSplatCase = false;
18479+ unsigned OpcN1 = N1.getOpcode();
18480+ if (OpcN1 == PPCISD::VADD_SPLAT &&
18481+ N1.getConstantOperandVal(1) == VT.getVectorNumElements()) {
18482+ AddSplatCase = true;
18483+ SplatBits = N1.getConstantOperandVal(0);
18484+ }
18485+
18486+ if (!AddSplatCase) {
18487+ if (OpcN1 != ISD::BUILD_VECTOR)
18488+ return SDValue();
18489+
18490+ unsigned SplatBitSize;
18491+ bool HasAnyUndefs;
18492+ APInt APSplatBits, APSplatUndef;
18493+ BuildVectorSDNode *BVN = cast<BuildVectorSDNode>(N1);
18494+ bool BVNIsConstantSplat =
18495+ BVN->isConstantSplat(APSplatBits, APSplatUndef, SplatBitSize,
18496+ HasAnyUndefs, 0, !Subtarget.isLittleEndian());
18497+ if (!BVNIsConstantSplat || SplatBitSize != EltBits)
18498+ return SDValue();
18499+ SplatBits = APSplatBits.getZExtValue();
18500+ }
18501+
18502+ SDLoc DL(N);
18503+ SDValue N0 = N->getOperand(0);
18504+ // PPC vector shifts by word/double look at only the low 5/6 bits of the
18505+ // shift vector, which means the max value is 31/63. A shift vector of all
18506+ // 1s will be truncated to 31/63, which is useful as vspltiw is limited to
18507+ // -16 to 15 range.
18508+ if (SplatBits == (EltBits - 1)) {
18509+ unsigned NewOpc;
18510+ switch (Opc) {
18511+ case ISD::SHL:
18512+ NewOpc = PPCISD::SHL;
18513+ break;
18514+ case ISD::SRL:
18515+ NewOpc = PPCISD::SRL;
18516+ break;
18517+ case ISD::SRA:
18518+ NewOpc = PPCISD::SRA;
18519+ break;
18520+ }
18521+ SDValue SplatOnes = getCanonicalConstSplat(255, 1, VT, DCI.DAG, DL);
18522+ return DCI.DAG.getNode(NewOpc, DL, VT, N0, SplatOnes);
18523+ }
18524+
18525+ if (Opc != ISD::SHL || !isOperationLegal(ISD::ADD, VT))
1848318526 return SDValue();
18484- uint64_t SplatBits = APSplatBits.getZExtValue();
18485- if (SplatBits != 1)
18527+
18528+ // For 64-bit there is no splat immediate so we want to catch shift by 1 here
18529+ // before the BUILD_VECTOR is replaced by a load.
18530+ if (EltTy != MVT::i64 || SplatBits != 1)
1848618531 return SDValue();
1848718532
18488- SDValue N0 = N->getOperand(0);
1848918533 return DCI.DAG.getNode(ISD::ADD, SDLoc(N), VT, N0, N0);
1849018534}
1849118535
@@ -18494,7 +18538,7 @@ SDValue PPCTargetLowering::combineSHL(SDNode *N, DAGCombinerInfo &DCI) const {
1849418538 return Value;
1849518539
1849618540 if (N->getValueType(0).isVector())
18497- return combineVectorSHL (N, DCI);
18541+ return combineVectorShift (N, DCI);
1849818542
1849918543 SDValue N0 = N->getOperand(0);
1850018544 ConstantSDNode *CN1 = dyn_cast<ConstantSDNode>(N->getOperand(1));
@@ -18526,13 +18570,19 @@ SDValue PPCTargetLowering::combineSRA(SDNode *N, DAGCombinerInfo &DCI) const {
1852618570 if (auto Value = stripModuloOnShift(*this, N, DCI.DAG))
1852718571 return Value;
1852818572
18573+ if (N->getValueType(0).isVector())
18574+ return combineVectorShift(N, DCI);
18575+
1852918576 return SDValue();
1853018577}
1853118578
1853218579SDValue PPCTargetLowering::combineSRL(SDNode *N, DAGCombinerInfo &DCI) const {
1853318580 if (auto Value = stripModuloOnShift(*this, N, DCI.DAG))
1853418581 return Value;
1853518582
18583+ if (N->getValueType(0).isVector())
18584+ return combineVectorShift(N, DCI);
18585+
1853618586 return SDValue();
1853718587}
1853818588
0 commit comments