diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 42ef817e01456..b58dbe5e6e7f0 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -56522,6 +56522,7 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG, SDValue Base = GorS->getBasePtr(); SDValue Scale = GorS->getScale(); EVT IndexVT = Index.getValueType(); + EVT IndexSVT = IndexVT.getVectorElementType(); const TargetLowering &TLI = DAG.getTargetLoweringInfo(); if (DCI.isBeforeLegalize()) { @@ -56558,41 +56559,51 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG, } EVT PtrVT = TLI.getPointerTy(DAG.getDataLayout()); - // Try to move splat constant adders from the index operand to the base + + // Try to move splat adders from the index operand to the base // pointer operand. Taking care to multiply by the scale. We can only do // this when index element type is the same as the pointer type. // Otherwise we need to be sure the math doesn't wrap before the scale. - if (Index.getOpcode() == ISD::ADD && - IndexVT.getVectorElementType() == PtrVT && isa(Scale)) { + if (Index.getOpcode() == ISD::ADD && IndexSVT == PtrVT && + isa(Scale)) { uint64_t ScaleAmt = Scale->getAsZExtVal(); - if (auto *BV = dyn_cast(Index.getOperand(1))) { - BitVector UndefElts; - if (ConstantSDNode *C = BV->getConstantSplatNode(&UndefElts)) { - // FIXME: Allow non-constant? - if (UndefElts.none()) { - // Apply the scale. - APInt Adder = C->getAPIntValue() * ScaleAmt; - // Add it to the existing base. - Base = DAG.getNode(ISD::ADD, DL, PtrVT, Base, - DAG.getConstant(Adder, DL, PtrVT)); - Index = Index.getOperand(0); - return rebuildGatherScatter(GorS, Index, Base, Scale, DAG); - } - } - // It's also possible base is just a constant. In that case, just - // replace it with 0 and move the displacement into the index. - if (BV->isConstant() && isa(Base) && - isOneConstant(Scale)) { - SDValue Splat = DAG.getSplatBuildVector(IndexVT, DL, Base); - // Combine the constant build_vector and the constant base. - Splat = DAG.getNode(ISD::ADD, DL, IndexVT, Index.getOperand(1), Splat); - // Add to the LHS of the original Index add. - Index = DAG.getNode(ISD::ADD, DL, IndexVT, Index.getOperand(0), Splat); - Base = DAG.getConstant(0, DL, Base.getValueType()); - return rebuildGatherScatter(GorS, Index, Base, Scale, DAG); + for (unsigned I = 0; I != 2; ++I) + if (auto *BV = dyn_cast(Index.getOperand(I))) { + BitVector UndefElts; + if (SDValue Splat = BV->getSplatValue(&UndefElts)) { + if (UndefElts.none()) { + // If the splat value is constant we can add the scaled splat value + // to the existing base. + if (auto *C = dyn_cast(Splat)) { + APInt Adder = C->getAPIntValue() * ScaleAmt; + SDValue NewBase = DAG.getNode(ISD::ADD, DL, PtrVT, Base, + DAG.getConstant(Adder, DL, PtrVT)); + SDValue NewIndex = Index.getOperand(1 - I); + return rebuildGatherScatter(GorS, NewIndex, NewBase, Scale, DAG); + } + // For non-constant cases, limit this to non-scaled cases. + if (ScaleAmt == 1) { + SDValue NewBase = DAG.getNode(ISD::ADD, DL, PtrVT, Base, Splat); + SDValue NewIndex = Index.getOperand(1 - I); + return rebuildGatherScatter(GorS, NewIndex, NewBase, Scale, DAG); + } + } + } + // It's also possible base is just a constant. In that case, just + // replace it with 0 and move the displacement into the index. + if (ScaleAmt == 1 && BV->isConstant() && isa(Base)) { + SDValue Splat = DAG.getSplatBuildVector(IndexVT, DL, Base); + // Combine the constant build_vector and the constant base. + Splat = + DAG.getNode(ISD::ADD, DL, IndexVT, Index.getOperand(I), Splat); + // Add to the other half of the original Index add. + SDValue NewIndex = DAG.getNode(ISD::ADD, DL, IndexVT, + Index.getOperand(1 - I), Splat); + SDValue NewBase = DAG.getConstant(0, DL, PtrVT); + return rebuildGatherScatter(GorS, NewIndex, NewBase, Scale, DAG); + } } - } } if (DCI.isBeforeLegalizeOps()) { diff --git a/llvm/test/CodeGen/X86/masked_gather_scatter.ll b/llvm/test/CodeGen/X86/masked_gather_scatter.ll index 5effb18fb6aa6..46e589b7b1be9 100644 --- a/llvm/test/CodeGen/X86/masked_gather_scatter.ll +++ b/llvm/test/CodeGen/X86/masked_gather_scatter.ll @@ -5028,12 +5028,10 @@ define {<16 x float>, <16 x float>} @test_gather_16f32_mask_index_pair(ptr %x, p ; X86-KNL-NEXT: movl {{[0-9]+}}(%esp), %eax ; X86-KNL-NEXT: movl {{[0-9]+}}(%esp), %ecx ; X86-KNL-NEXT: vpslld $4, (%ecx), %zmm2 -; X86-KNL-NEXT: vpbroadcastd %eax, %zmm0 -; X86-KNL-NEXT: vpaddd %zmm2, %zmm0, %zmm3 ; X86-KNL-NEXT: kmovw %k1, %k2 ; X86-KNL-NEXT: vmovaps %zmm1, %zmm0 ; X86-KNL-NEXT: vgatherdps (%eax,%zmm2), %zmm0 {%k2} -; X86-KNL-NEXT: vgatherdps 4(,%zmm3), %zmm1 {%k1} +; X86-KNL-NEXT: vgatherdps 4(%eax,%zmm2), %zmm1 {%k1} ; X86-KNL-NEXT: retl ; ; X64-SKX-SMALL-LABEL: test_gather_16f32_mask_index_pair: @@ -5097,12 +5095,10 @@ define {<16 x float>, <16 x float>} @test_gather_16f32_mask_index_pair(ptr %x, p ; X86-SKX-NEXT: movl {{[0-9]+}}(%esp), %eax ; X86-SKX-NEXT: movl {{[0-9]+}}(%esp), %ecx ; X86-SKX-NEXT: vpslld $4, (%ecx), %zmm2 -; X86-SKX-NEXT: vpbroadcastd %eax, %zmm0 -; X86-SKX-NEXT: vpaddd %zmm2, %zmm0, %zmm3 ; X86-SKX-NEXT: kmovw %k1, %k2 ; X86-SKX-NEXT: vmovaps %zmm1, %zmm0 ; X86-SKX-NEXT: vgatherdps (%eax,%zmm2), %zmm0 {%k2} -; X86-SKX-NEXT: vgatherdps 4(,%zmm3), %zmm1 {%k1} +; X86-SKX-NEXT: vgatherdps 4(%eax,%zmm2), %zmm1 {%k1} ; X86-SKX-NEXT: retl %wide.load = load <16 x i32>, ptr %arr, align 4 %and = and <16 x i32> %wide.load,