@@ -18456,36 +18456,80 @@ static SDValue stripModuloOnShift(const TargetLowering &TLI, SDNode *N,
1845618456  return SDValue();
1845718457}
1845818458
18459- SDValue PPCTargetLowering::combineVectorSHL (SDNode *N,
18460-                                             DAGCombinerInfo &DCI) const {
18459+ SDValue PPCTargetLowering::combineVectorShift (SDNode *N,
18460+                                                DAGCombinerInfo &DCI) const {
1846118461  EVT VT = N->getValueType(0);
1846218462  assert(VT.isVector() && "Vector type expected.");
1846318463
18464-   SDValue N1 = N->getOperand(1);
18465-   if (!Subtarget.hasP8Altivec() || N1.getOpcode() != ISD::BUILD_VECTOR ||
18466-       !isOperationLegal(ISD::ADD, VT))
18464+   unsigned Opc = N->getOpcode();
18465+   assert((Opc == ISD::SHL || Opc == ISD::SRL || Opc == ISD::SRA) &&
18466+          "Unexpected opcode.");
18467+ 
18468+   if (!isOperationLegal(Opc, VT))
1846718469    return SDValue();
1846818470
18469-   // For 64-bit there is no splat immediate so we want to catch shift by 1 here
18470-   // before the BUILD_VECTOR is replaced by a load.
1847118471  EVT EltTy = VT.getScalarType();
18472-   if (EltTy != MVT::i64)
18472+   unsigned EltBits = EltTy.getSizeInBits();
18473+   if (EltTy != MVT::i64 && EltTy != MVT::i32)
1847318474    return SDValue();
1847418475
18475-   BuildVectorSDNode *BVN = cast<BuildVectorSDNode>(N1);
18476-   APInt APSplatBits, APSplatUndef;
18477-   unsigned SplatBitSize;
18478-   bool HasAnyUndefs;
18479-   bool BVNIsConstantSplat =
18480-       BVN->isConstantSplat(APSplatBits, APSplatUndef, SplatBitSize,
18481-                            HasAnyUndefs, 0, !Subtarget.isLittleEndian());
18482-   if (!BVNIsConstantSplat || SplatBitSize != EltTy.getSizeInBits())
18476+   SDValue N1 = N->getOperand(1);
18477+   uint64_t SplatBits = 0;
18478+   bool AddSplatCase = false;
18479+   unsigned OpcN1 = N1.getOpcode();
18480+   if (OpcN1 == PPCISD::VADD_SPLAT &&
18481+       N1.getConstantOperandVal(1) == VT.getVectorNumElements()) {
18482+     AddSplatCase = true;
18483+     SplatBits = N1.getConstantOperandVal(0);
18484+   }
18485+ 
18486+   if (!AddSplatCase) {
18487+     if (OpcN1 != ISD::BUILD_VECTOR)
18488+       return SDValue();
18489+ 
18490+     unsigned SplatBitSize;
18491+     bool HasAnyUndefs;
18492+     APInt APSplatBits, APSplatUndef;
18493+     BuildVectorSDNode *BVN = cast<BuildVectorSDNode>(N1);
18494+     bool BVNIsConstantSplat =
18495+         BVN->isConstantSplat(APSplatBits, APSplatUndef, SplatBitSize,
18496+                              HasAnyUndefs, 0, !Subtarget.isLittleEndian());
18497+     if (!BVNIsConstantSplat || SplatBitSize != EltBits)
18498+       return SDValue();
18499+     SplatBits = APSplatBits.getZExtValue();
18500+   }
18501+ 
18502+   SDLoc DL(N);
18503+   SDValue N0 = N->getOperand(0);
18504+   // PPC vector shifts by word/double look at only the low 5/6 bits of the
18505+   // shift vector, which means the max value is 31/63. A shift vector of all
18506+   // 1s will be truncated to 31/63, which is useful as vspltiw is limited to
18507+   // -16 to 15 range.
18508+   if (SplatBits == (EltBits - 1)) {
18509+     unsigned NewOpc;
18510+     switch (Opc) {
18511+     case ISD::SHL:
18512+       NewOpc = PPCISD::SHL;
18513+       break;
18514+     case ISD::SRL:
18515+       NewOpc = PPCISD::SRL;
18516+       break;
18517+     case ISD::SRA:
18518+       NewOpc = PPCISD::SRA;
18519+       break;
18520+     }
18521+     SDValue SplatOnes = getCanonicalConstSplat(255, 1, VT, DCI.DAG, DL);
18522+     return DCI.DAG.getNode(NewOpc, DL, VT, N0, SplatOnes);
18523+   }
18524+ 
18525+   if (Opc != ISD::SHL || !isOperationLegal(ISD::ADD, VT))
1848318526    return SDValue();
18484-   uint64_t SplatBits = APSplatBits.getZExtValue();
18485-   if (SplatBits != 1)
18527+ 
18528+   // For 64-bit there is no splat immediate so we want to catch shift by 1 here
18529+   // before the BUILD_VECTOR is replaced by a load.
18530+   if (EltTy != MVT::i64 || SplatBits != 1)
1848618531    return SDValue();
1848718532
18488-   SDValue N0 = N->getOperand(0);
1848918533  return DCI.DAG.getNode(ISD::ADD, SDLoc(N), VT, N0, N0);
1849018534}
1849118535
@@ -18494,7 +18538,7 @@ SDValue PPCTargetLowering::combineSHL(SDNode *N, DAGCombinerInfo &DCI) const {
1849418538    return Value;
1849518539
1849618540  if (N->getValueType(0).isVector())
18497-     return combineVectorSHL (N, DCI);
18541+     return combineVectorShift (N, DCI);
1849818542
1849918543  SDValue N0 = N->getOperand(0);
1850018544  ConstantSDNode *CN1 = dyn_cast<ConstantSDNode>(N->getOperand(1));
@@ -18526,13 +18570,19 @@ SDValue PPCTargetLowering::combineSRA(SDNode *N, DAGCombinerInfo &DCI) const {
1852618570  if (auto Value = stripModuloOnShift(*this, N, DCI.DAG))
1852718571    return Value;
1852818572
18573+   if (N->getValueType(0).isVector())
18574+     return combineVectorShift(N, DCI);
18575+ 
1852918576  return SDValue();
1853018577}
1853118578
1853218579SDValue PPCTargetLowering::combineSRL(SDNode *N, DAGCombinerInfo &DCI) const {
1853318580  if (auto Value = stripModuloOnShift(*this, N, DCI.DAG))
1853418581    return Value;
1853518582
18583+   if (N->getValueType(0).isVector())
18584+     return combineVectorShift(N, DCI);
18585+ 
1853618586  return SDValue();
1853718587}
1853818588
0 commit comments