From 37125e36534b0d2b0786da7b881c8f3969b18a2c Mon Sep 17 00:00:00 2001 From: Philip Reames Date: Thu, 26 Jun 2025 08:16:49 -0700 Subject: [PATCH] [CGP] Simplify gather/scatter splat pointer matching The primary goal of this change is to simplify code, but it also ends up being slightly more powerful. Rather than repeating the gather/scatter of splat logic in both CGP and SDAG, generalize the SDAG copy slightly and delete the CGP version. The X86 codegen diffs are improvements - we were scaling a zero value by 4, whereas now we're not scaling it. This codegen can likely be further improved, but that'll be in upcoming patches. --- llvm/lib/CodeGen/CodeGenPrepare.cpp | 178 ++++++++---------- .../SelectionDAG/SelectionDAGBuilder.cpp | 11 +- .../test/CodeGen/X86/masked_gather_scatter.ll | 8 +- .../gather-scatter-opt-inseltpoison.ll | 6 +- .../AArch64/gather-scatter-opt.ll | 6 +- .../X86/gather-scatter-opt-inseltpoison.ll | 6 +- .../CodeGenPrepare/X86/gather-scatter-opt.ll | 6 +- 7 files changed, 99 insertions(+), 122 deletions(-) diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp index 43574a54c37dd..8a0e26bde23ff 100644 --- a/llvm/lib/CodeGen/CodeGenPrepare.cpp +++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp @@ -6230,7 +6230,7 @@ bool CodeGenPrepare::optimizeMemoryInst(Instruction *MemoryInst, Value *Addr, /// Rewrite GEP input to gather/scatter to enable SelectionDAGBuilder to find /// a uniform base to use for ISD::MGATHER/MSCATTER. SelectionDAGBuilder can -/// only handle a 2 operand GEP in the same basic block or a splat constant +/// only handle a 2 operand GEP in the same basic block or a canonical splat /// vector. The 2 operands to the GEP must have a scalar pointer and a vector /// index. /// @@ -6247,124 +6247,98 @@ bool CodeGenPrepare::optimizeMemoryInst(Instruction *MemoryInst, Value *Addr, /// zero index. 
bool CodeGenPrepare::optimizeGatherScatterInst(Instruction *MemoryInst, Value *Ptr) { - Value *NewAddr; + const auto *GEP = dyn_cast(Ptr); + if (!GEP) + return false; - if (const auto *GEP = dyn_cast(Ptr)) { - // Don't optimize GEPs that don't have indices. - if (!GEP->hasIndices()) - return false; + // Don't optimize GEPs that don't have indices. + if (!GEP->hasIndices()) + return false; - // If the GEP and the gather/scatter aren't in the same BB, don't optimize. - // FIXME: We should support this by sinking the GEP. - if (MemoryInst->getParent() != GEP->getParent()) - return false; + // If the GEP and the gather/scatter aren't in the same BB, don't optimize. + // FIXME: We should support this by sinking the GEP. + if (MemoryInst->getParent() != GEP->getParent()) + return false; - SmallVector Ops(GEP->operands()); + SmallVector Ops(GEP->operands()); - bool RewriteGEP = false; + bool RewriteGEP = false; - if (Ops[0]->getType()->isVectorTy()) { - Ops[0] = getSplatValue(Ops[0]); - if (!Ops[0]) - return false; - RewriteGEP = true; - } + if (Ops[0]->getType()->isVectorTy()) { + Ops[0] = getSplatValue(Ops[0]); + if (!Ops[0]) + return false; + RewriteGEP = true; + } - unsigned FinalIndex = Ops.size() - 1; + unsigned FinalIndex = Ops.size() - 1; - // Ensure all but the last index is 0. - // FIXME: This isn't strictly required. All that's required is that they are - // all scalars or splats. - for (unsigned i = 1; i < FinalIndex; ++i) { - auto *C = dyn_cast(Ops[i]); - if (!C) - return false; - if (isa(C->getType())) - C = C->getSplatValue(); - auto *CI = dyn_cast_or_null(C); - if (!CI || !CI->isZero()) - return false; - // Scalarize the index if needed. - Ops[i] = CI; - } - - // Try to scalarize the final index. - if (Ops[FinalIndex]->getType()->isVectorTy()) { - if (Value *V = getSplatValue(Ops[FinalIndex])) { - auto *C = dyn_cast(V); - // Don't scalarize all zeros vector. 
- if (!C || !C->isZero()) { - Ops[FinalIndex] = V; - RewriteGEP = true; - } + // Ensure all but the last index is 0. + // FIXME: This isn't strictly required. All that's required is that they are + // all scalars or splats. + for (unsigned i = 1; i < FinalIndex; ++i) { + auto *C = dyn_cast(Ops[i]); + if (!C) + return false; + if (isa(C->getType())) + C = C->getSplatValue(); + auto *CI = dyn_cast_or_null(C); + if (!CI || !CI->isZero()) + return false; + // Scalarize the index if needed. + Ops[i] = CI; + } + + // Try to scalarize the final index. + if (Ops[FinalIndex]->getType()->isVectorTy()) { + if (Value *V = getSplatValue(Ops[FinalIndex])) { + auto *C = dyn_cast(V); + // Don't scalarize all zeros vector. + if (!C || !C->isZero()) { + Ops[FinalIndex] = V; + RewriteGEP = true; } } + } - // If we made any changes or the we have extra operands, we need to generate - // new instructions. - if (!RewriteGEP && Ops.size() == 2) - return false; + // If we made any changes or the we have extra operands, we need to generate + // new instructions. + if (!RewriteGEP && Ops.size() == 2) + return false; - auto NumElts = cast(Ptr->getType())->getElementCount(); + auto NumElts = cast(Ptr->getType())->getElementCount(); - IRBuilder<> Builder(MemoryInst); + IRBuilder<> Builder(MemoryInst); - Type *SourceTy = GEP->getSourceElementType(); - Type *ScalarIndexTy = DL->getIndexType(Ops[0]->getType()->getScalarType()); + Type *SourceTy = GEP->getSourceElementType(); + Type *ScalarIndexTy = DL->getIndexType(Ops[0]->getType()->getScalarType()); - // If the final index isn't a vector, emit a scalar GEP containing all ops - // and a vector GEP with all zeroes final index. 
- if (!Ops[FinalIndex]->getType()->isVectorTy()) { - NewAddr = Builder.CreateGEP(SourceTy, Ops[0], ArrayRef(Ops).drop_front()); - auto *IndexTy = VectorType::get(ScalarIndexTy, NumElts); - auto *SecondTy = GetElementPtrInst::getIndexedType( - SourceTy, ArrayRef(Ops).drop_front()); - NewAddr = - Builder.CreateGEP(SecondTy, NewAddr, Constant::getNullValue(IndexTy)); - } else { - Value *Base = Ops[0]; - Value *Index = Ops[FinalIndex]; - - // Create a scalar GEP if there are more than 2 operands. - if (Ops.size() != 2) { - // Replace the last index with 0. - Ops[FinalIndex] = - Constant::getNullValue(Ops[FinalIndex]->getType()->getScalarType()); - Base = Builder.CreateGEP(SourceTy, Base, ArrayRef(Ops).drop_front()); - SourceTy = GetElementPtrInst::getIndexedType( - SourceTy, ArrayRef(Ops).drop_front()); - } + // If the final index isn't a vector, emit a scalar GEP containing all ops + // and a vector GEP with all zeroes final index. + Value *NewAddr; + if (!Ops[FinalIndex]->getType()->isVectorTy()) { + NewAddr = Builder.CreateGEP(SourceTy, Ops[0], ArrayRef(Ops).drop_front()); + auto *IndexTy = VectorType::get(ScalarIndexTy, NumElts); + auto *SecondTy = + GetElementPtrInst::getIndexedType(SourceTy, ArrayRef(Ops).drop_front()); + NewAddr = + Builder.CreateGEP(SecondTy, NewAddr, Constant::getNullValue(IndexTy)); + } else { + Value *Base = Ops[0]; + Value *Index = Ops[FinalIndex]; - // Now create the GEP with scalar pointer and vector index. - NewAddr = Builder.CreateGEP(SourceTy, Base, Index); + // Create a scalar GEP if there are more than 2 operands. + if (Ops.size() != 2) { + // Replace the last index with 0. 
+ Ops[FinalIndex] = + Constant::getNullValue(Ops[FinalIndex]->getType()->getScalarType()); + Base = Builder.CreateGEP(SourceTy, Base, ArrayRef(Ops).drop_front()); + SourceTy = GetElementPtrInst::getIndexedType(SourceTy, + ArrayRef(Ops).drop_front()); } - } else if (!isa(Ptr)) { - // Not a GEP, maybe its a splat and we can create a GEP to enable - // SelectionDAGBuilder to use it as a uniform base. - Value *V = getSplatValue(Ptr); - if (!V) - return false; - - auto NumElts = cast(Ptr->getType())->getElementCount(); - IRBuilder<> Builder(MemoryInst); - - // Emit a vector GEP with a scalar pointer and all 0s vector index. - Type *ScalarIndexTy = DL->getIndexType(V->getType()->getScalarType()); - auto *IndexTy = VectorType::get(ScalarIndexTy, NumElts); - Type *ScalarTy; - if (cast(MemoryInst)->getIntrinsicID() == - Intrinsic::masked_gather) { - ScalarTy = MemoryInst->getType()->getScalarType(); - } else { - assert(cast(MemoryInst)->getIntrinsicID() == - Intrinsic::masked_scatter); - ScalarTy = MemoryInst->getOperand(0)->getType()->getScalarType(); - } - NewAddr = Builder.CreateGEP(ScalarTy, V, Constant::getNullValue(IndexTy)); - } else { - // Constant, SelectionDAGBuilder knows to check if its a splat. - return false; + // Now create the GEP with scalar pointer and vector index. + NewAddr = Builder.CreateGEP(SourceTy, Base, Index); } MemoryInst->replaceUsesOfWith(Ptr, NewAddr); diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp index 04d6fd5f48cc3..ffe5ddec806a9 100644 --- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp @@ -4888,14 +4888,9 @@ static bool getUniformBase(const Value *Ptr, SDValue &Base, SDValue &Index, assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type"); - // Handle splat constant pointer. 
- if (auto *C = dyn_cast(Ptr)) { - C = C->getSplatValue(); - if (!C) - return false; - - Base = SDB->getValue(C); - + // Handle splat (possibly constant) pointer. + if (Value *ScalarV = getSplatValue(Ptr)) { + Base = SDB->getValue(ScalarV); ElementCount NumElts = cast(Ptr->getType())->getElementCount(); EVT VT = EVT::getVectorVT(*DAG.getContext(), TLI.getPointerTy(DL), NumElts); Index = DAG.getConstant(0, SDB->getCurSDLoc(), VT); diff --git a/llvm/test/CodeGen/X86/masked_gather_scatter.ll b/llvm/test/CodeGen/X86/masked_gather_scatter.ll index 4e6f666fa05de..8dd9039fbce55 100644 --- a/llvm/test/CodeGen/X86/masked_gather_scatter.ll +++ b/llvm/test/CodeGen/X86/masked_gather_scatter.ll @@ -4475,7 +4475,7 @@ define <4 x i32> @splat_ptr_gather(ptr %ptr, <4 x i1> %mask, <4 x i32> %passthru ; X64-SKX-NEXT: vpslld $31, %xmm0, %xmm0 ; X64-SKX-NEXT: vpmovd2m %xmm0, %k1 ; X64-SKX-NEXT: vpxor %xmm0, %xmm0, %xmm0 -; X64-SKX-NEXT: vpgatherdd (%rdi,%xmm0,4), %xmm1 {%k1} +; X64-SKX-NEXT: vpgatherdd (%rdi,%xmm0), %xmm1 {%k1} ; X64-SKX-NEXT: vmovdqa %xmm1, %xmm0 ; X64-SKX-NEXT: retq ; @@ -4485,7 +4485,7 @@ define <4 x i32> @splat_ptr_gather(ptr %ptr, <4 x i1> %mask, <4 x i32> %passthru ; X86-SKX-NEXT: vpmovd2m %xmm0, %k1 ; X86-SKX-NEXT: movl {{[0-9]+}}(%esp), %eax ; X86-SKX-NEXT: vpxor %xmm0, %xmm0, %xmm0 -; X86-SKX-NEXT: vpgatherdd (%eax,%xmm0,4), %xmm1 {%k1} +; X86-SKX-NEXT: vpgatherdd (%eax,%xmm0), %xmm1 {%k1} ; X86-SKX-NEXT: vmovdqa %xmm1, %xmm0 ; X86-SKX-NEXT: retl %1 = insertelement <4 x ptr> undef, ptr %ptr, i32 0 @@ -4581,7 +4581,7 @@ define void @splat_ptr_scatter(ptr %ptr, <4 x i1> %mask, <4 x i32> %val) { ; X64-SKX-NEXT: vpslld $31, %xmm0, %xmm0 ; X64-SKX-NEXT: vpmovd2m %xmm0, %k1 ; X64-SKX-NEXT: vpxor %xmm0, %xmm0, %xmm0 -; X64-SKX-NEXT: vpscatterdd %xmm1, (%rdi,%xmm0,4) {%k1} +; X64-SKX-NEXT: vpscatterdd %xmm1, (%rdi,%xmm0) {%k1} ; X64-SKX-NEXT: retq ; ; X86-SKX-LABEL: splat_ptr_scatter: @@ -4590,7 +4590,7 @@ define void @splat_ptr_scatter(ptr %ptr, <4 x i1> %mask, <4 x 
i32> %val) { ; X86-SKX-NEXT: vpmovd2m %xmm0, %k1 ; X86-SKX-NEXT: movl {{[0-9]+}}(%esp), %eax ; X86-SKX-NEXT: vpxor %xmm0, %xmm0, %xmm0 -; X86-SKX-NEXT: vpscatterdd %xmm1, (%eax,%xmm0,4) {%k1} +; X86-SKX-NEXT: vpscatterdd %xmm1, (%eax,%xmm0) {%k1} ; X86-SKX-NEXT: retl %1 = insertelement <4 x ptr> undef, ptr %ptr, i32 0 %2 = shufflevector <4 x ptr> %1, <4 x ptr> undef, <4 x i32> zeroinitializer diff --git a/llvm/test/Transforms/CodeGenPrepare/AArch64/gather-scatter-opt-inseltpoison.ll b/llvm/test/Transforms/CodeGenPrepare/AArch64/gather-scatter-opt-inseltpoison.ll index 3c5c07f3516c9..6fd4d4a4c6be4 100644 --- a/llvm/test/Transforms/CodeGenPrepare/AArch64/gather-scatter-opt-inseltpoison.ll +++ b/llvm/test/Transforms/CodeGenPrepare/AArch64/gather-scatter-opt-inseltpoison.ll @@ -85,7 +85,8 @@ define @global_struct_splat( %mask) #0 { define @splat_ptr_gather(ptr %ptr, %mask, %passthru) #0 { ; CHECK-LABEL: @splat_ptr_gather( -; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i32, ptr [[PTR:%.*]], zeroinitializer +; CHECK-NEXT: [[TMP3:%.*]] = insertelement poison, ptr [[PTR:%.*]], i32 0 +; CHECK-NEXT: [[TMP1:%.*]] = shufflevector [[TMP3]], poison, zeroinitializer ; CHECK-NEXT: [[TMP2:%.*]] = call @llvm.masked.gather.nxv4i32.nxv4p0( [[TMP1]], i32 4, [[MASK:%.*]], [[PASSTHRU:%.*]]) ; CHECK-NEXT: ret [[TMP2]] ; @@ -97,7 +98,8 @@ define @splat_ptr_gather(ptr %ptr, %mask, < define void @splat_ptr_scatter(ptr %ptr, %mask, %val) #0 { ; CHECK-LABEL: @splat_ptr_scatter( -; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i32, ptr [[PTR:%.*]], zeroinitializer +; CHECK-NEXT: [[TMP2:%.*]] = insertelement poison, ptr [[PTR:%.*]], i32 0 +; CHECK-NEXT: [[TMP1:%.*]] = shufflevector [[TMP2]], poison, zeroinitializer ; CHECK-NEXT: call void @llvm.masked.scatter.nxv4i32.nxv4p0( [[VAL:%.*]], [[TMP1]], i32 4, [[MASK:%.*]]) ; CHECK-NEXT: ret void ; diff --git a/llvm/test/Transforms/CodeGenPrepare/AArch64/gather-scatter-opt.ll b/llvm/test/Transforms/CodeGenPrepare/AArch64/gather-scatter-opt.ll index 
36cd69ed01ed9..d1843dcd23863 100644 --- a/llvm/test/Transforms/CodeGenPrepare/AArch64/gather-scatter-opt.ll +++ b/llvm/test/Transforms/CodeGenPrepare/AArch64/gather-scatter-opt.ll @@ -85,7 +85,8 @@ define @global_struct_splat( %mask) #0 { define @splat_ptr_gather(ptr %ptr, %mask, %passthru) #0 { ; CHECK-LABEL: @splat_ptr_gather( -; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i32, ptr [[PTR:%.*]], zeroinitializer +; CHECK-NEXT: [[TMP3:%.*]] = insertelement undef, ptr [[PTR:%.*]], i32 0 +; CHECK-NEXT: [[TMP1:%.*]] = shufflevector [[TMP3]], undef, zeroinitializer ; CHECK-NEXT: [[TMP2:%.*]] = call @llvm.masked.gather.nxv4i32.nxv4p0( [[TMP1]], i32 4, [[MASK:%.*]], [[PASSTHRU:%.*]]) ; CHECK-NEXT: ret [[TMP2]] ; @@ -97,7 +98,8 @@ define @splat_ptr_gather(ptr %ptr, %mask, < define void @splat_ptr_scatter(ptr %ptr, %mask, %val) #0 { ; CHECK-LABEL: @splat_ptr_scatter( -; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i32, ptr [[PTR:%.*]], zeroinitializer +; CHECK-NEXT: [[TMP2:%.*]] = insertelement undef, ptr [[PTR:%.*]], i32 0 +; CHECK-NEXT: [[TMP1:%.*]] = shufflevector [[TMP2]], undef, zeroinitializer ; CHECK-NEXT: call void @llvm.masked.scatter.nxv4i32.nxv4p0( [[VAL:%.*]], [[TMP1]], i32 4, [[MASK:%.*]]) ; CHECK-NEXT: ret void ; diff --git a/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt-inseltpoison.ll b/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt-inseltpoison.ll index 6ef3400812fc8..810a9d7b228c0 100644 --- a/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt-inseltpoison.ll +++ b/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt-inseltpoison.ll @@ -88,7 +88,8 @@ define <4 x i32> @global_struct_splat() { define <4 x i32> @splat_ptr_gather(ptr %ptr, <4 x i1> %mask, <4 x i32> %passthru) { ; CHECK-LABEL: @splat_ptr_gather( -; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i32, ptr [[PTR:%.*]], <4 x i64> zeroinitializer +; CHECK-NEXT: [[TMP3:%.*]] = insertelement <4 x ptr> poison, ptr [[PTR:%.*]], i32 0 +; CHECK-NEXT: [[TMP1:%.*]] = shufflevector 
<4 x ptr> [[TMP3]], <4 x ptr> poison, <4 x i32> zeroinitializer ; CHECK-NEXT: [[TMP2:%.*]] = call <4 x i32> @llvm.masked.gather.v4i32.v4p0(<4 x ptr> [[TMP1]], i32 4, <4 x i1> [[MASK:%.*]], <4 x i32> [[PASSTHRU:%.*]]) ; CHECK-NEXT: ret <4 x i32> [[TMP2]] ; @@ -100,7 +101,8 @@ define <4 x i32> @splat_ptr_gather(ptr %ptr, <4 x i1> %mask, <4 x i32> %passthru define void @splat_ptr_scatter(ptr %ptr, <4 x i1> %mask, <4 x i32> %val) { ; CHECK-LABEL: @splat_ptr_scatter( -; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i32, ptr [[PTR:%.*]], <4 x i64> zeroinitializer +; CHECK-NEXT: [[TMP2:%.*]] = insertelement <4 x ptr> poison, ptr [[PTR:%.*]], i32 0 +; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <4 x ptr> [[TMP2]], <4 x ptr> poison, <4 x i32> zeroinitializer ; CHECK-NEXT: call void @llvm.masked.scatter.v4i32.v4p0(<4 x i32> [[VAL:%.*]], <4 x ptr> [[TMP1]], i32 4, <4 x i1> [[MASK:%.*]]) ; CHECK-NEXT: ret void ; diff --git a/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt.ll b/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt.ll index 8328708393029..cada231707171 100644 --- a/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt.ll +++ b/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt.ll @@ -87,7 +87,8 @@ define <4 x i32> @global_struct_splat() { define <4 x i32> @splat_ptr_gather(ptr %ptr, <4 x i1> %mask, <4 x i32> %passthru) { ; CHECK-LABEL: @splat_ptr_gather( -; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i32, ptr [[PTR:%.*]], <4 x i64> zeroinitializer +; CHECK-NEXT: [[TMP3:%.*]] = insertelement <4 x ptr> undef, ptr [[PTR:%.*]], i32 0 +; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <4 x ptr> [[TMP3]], <4 x ptr> undef, <4 x i32> zeroinitializer ; CHECK-NEXT: [[TMP2:%.*]] = call <4 x i32> @llvm.masked.gather.v4i32.v4p0(<4 x ptr> [[TMP1]], i32 4, <4 x i1> [[MASK:%.*]], <4 x i32> [[PASSTHRU:%.*]]) ; CHECK-NEXT: ret <4 x i32> [[TMP2]] ; @@ -99,7 +100,8 @@ define <4 x i32> @splat_ptr_gather(ptr %ptr, <4 x i1> %mask, <4 x i32> %passthru define void 
@splat_ptr_scatter(ptr %ptr, <4 x i1> %mask, <4 x i32> %val) { ; CHECK-LABEL: @splat_ptr_scatter( -; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i32, ptr [[PTR:%.*]], <4 x i64> zeroinitializer +; CHECK-NEXT: [[TMP2:%.*]] = insertelement <4 x ptr> undef, ptr [[PTR:%.*]], i32 0 +; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <4 x ptr> [[TMP2]], <4 x ptr> undef, <4 x i32> zeroinitializer ; CHECK-NEXT: call void @llvm.masked.scatter.v4i32.v4p0(<4 x i32> [[VAL:%.*]], <4 x ptr> [[TMP1]], i32 4, <4 x i1> [[MASK:%.*]]) ; CHECK-NEXT: ret void ;