[CGP] Simplify gather/scatter splat pointer matching #145931
Conversation
The primary goal of this change is to simplify code, but it also ends up being slightly more powerful. Rather than repeating the gather/scatter-of-splat logic in both CGP and SDAG, generalize the SDAG copy slightly and delete the CGP version. The X86 codegen diffs are improvements - we were scaling a zero value by 4, whereas now we're not scaling it. This codegen can likely be further improved, but that'll be in upcoming patches. Note that we have instcombine rules to convert unmasked gather/scatter to scalar load/store respectively, so this really only matters codegen-quality-wise for the masked flavors.
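For reference, the splat-pointer gather pattern this patch targets looks like the following IR (a standalone sketch adapted from the splat_ptr_gather test updated in this PR; only the comments are new):

declare <4 x i32> @llvm.masked.gather.v4i32.v4p0(<4 x ptr>, i32 immarg, <4 x i1>, <4 x i32>)

define <4 x i32> @splat_ptr_gather(ptr %ptr, <4 x i1> %mask, <4 x i32> %passthru) {
  ; Broadcast the scalar pointer to every lane.
  %bc = insertelement <4 x ptr> poison, ptr %ptr, i32 0
  %splat = shufflevector <4 x ptr> %bc, <4 x ptr> poison, <4 x i32> zeroinitializer
  ; Every active lane loads from the same address, so the backend can use
  ; %ptr as a uniform base with an all-zero index.
  %res = call <4 x i32> @llvm.masked.gather.v4i32.v4p0(<4 x ptr> %splat, i32 4, <4 x i1> %mask, <4 x i32> %passthru)
  ret <4 x i32> %res
}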
@llvm/pr-subscribers-llvm-transforms @llvm/pr-subscribers-backend-x86 @llvm/pr-subscribers-llvm-selectiondag

Author: Philip Reames (preames)

Changes

The primary goal of this change is to simplify code, but it also ends up being slightly more powerful. Rather than repeating the gather/scatter-of-splat logic in both CGP and SDAG, generalize the SDAG copy slightly and delete the CGP version. The X86 codegen diffs are improvements - we were scaling a zero value by 4, whereas now we're not scaling it. This codegen can likely be further improved, but that'll be in upcoming patches.

Note that we have instcombine rules to convert unmasked gather/scatter to scalar load/store respectively, so this really only matters codegen-quality-wise for the masked flavors.

Full diff: https://github.com/llvm/llvm-project/pull/145931.diff

7 Files Affected:
diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index 43574a54c37dd..8a0e26bde23ff 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -6230,7 +6230,7 @@ bool CodeGenPrepare::optimizeMemoryInst(Instruction *MemoryInst, Value *Addr,
/// Rewrite GEP input to gather/scatter to enable SelectionDAGBuilder to find
/// a uniform base to use for ISD::MGATHER/MSCATTER. SelectionDAGBuilder can
-/// only handle a 2 operand GEP in the same basic block or a splat constant
+/// only handle a 2 operand GEP in the same basic block or a canonical splat
/// vector. The 2 operands to the GEP must have a scalar pointer and a vector
/// index.
///
@@ -6247,124 +6247,98 @@ bool CodeGenPrepare::optimizeMemoryInst(Instruction *MemoryInst, Value *Addr,
/// zero index.
bool CodeGenPrepare::optimizeGatherScatterInst(Instruction *MemoryInst,
Value *Ptr) {
- Value *NewAddr;
+ const auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
+ if (!GEP)
+ return false;
- if (const auto *GEP = dyn_cast<GetElementPtrInst>(Ptr)) {
- // Don't optimize GEPs that don't have indices.
- if (!GEP->hasIndices())
- return false;
+ // Don't optimize GEPs that don't have indices.
+ if (!GEP->hasIndices())
+ return false;
- // If the GEP and the gather/scatter aren't in the same BB, don't optimize.
- // FIXME: We should support this by sinking the GEP.
- if (MemoryInst->getParent() != GEP->getParent())
- return false;
+ // If the GEP and the gather/scatter aren't in the same BB, don't optimize.
+ // FIXME: We should support this by sinking the GEP.
+ if (MemoryInst->getParent() != GEP->getParent())
+ return false;
- SmallVector<Value *, 2> Ops(GEP->operands());
+ SmallVector<Value *, 2> Ops(GEP->operands());
- bool RewriteGEP = false;
+ bool RewriteGEP = false;
- if (Ops[0]->getType()->isVectorTy()) {
- Ops[0] = getSplatValue(Ops[0]);
- if (!Ops[0])
- return false;
- RewriteGEP = true;
- }
+ if (Ops[0]->getType()->isVectorTy()) {
+ Ops[0] = getSplatValue(Ops[0]);
+ if (!Ops[0])
+ return false;
+ RewriteGEP = true;
+ }
- unsigned FinalIndex = Ops.size() - 1;
+ unsigned FinalIndex = Ops.size() - 1;
- // Ensure all but the last index is 0.
- // FIXME: This isn't strictly required. All that's required is that they are
- // all scalars or splats.
- for (unsigned i = 1; i < FinalIndex; ++i) {
- auto *C = dyn_cast<Constant>(Ops[i]);
- if (!C)
- return false;
- if (isa<VectorType>(C->getType()))
- C = C->getSplatValue();
- auto *CI = dyn_cast_or_null<ConstantInt>(C);
- if (!CI || !CI->isZero())
- return false;
- // Scalarize the index if needed.
- Ops[i] = CI;
- }
-
- // Try to scalarize the final index.
- if (Ops[FinalIndex]->getType()->isVectorTy()) {
- if (Value *V = getSplatValue(Ops[FinalIndex])) {
- auto *C = dyn_cast<ConstantInt>(V);
- // Don't scalarize all zeros vector.
- if (!C || !C->isZero()) {
- Ops[FinalIndex] = V;
- RewriteGEP = true;
- }
+ // Ensure all but the last index is 0.
+ // FIXME: This isn't strictly required. All that's required is that they are
+ // all scalars or splats.
+ for (unsigned i = 1; i < FinalIndex; ++i) {
+ auto *C = dyn_cast<Constant>(Ops[i]);
+ if (!C)
+ return false;
+ if (isa<VectorType>(C->getType()))
+ C = C->getSplatValue();
+ auto *CI = dyn_cast_or_null<ConstantInt>(C);
+ if (!CI || !CI->isZero())
+ return false;
+ // Scalarize the index if needed.
+ Ops[i] = CI;
+ }
+
+ // Try to scalarize the final index.
+ if (Ops[FinalIndex]->getType()->isVectorTy()) {
+ if (Value *V = getSplatValue(Ops[FinalIndex])) {
+ auto *C = dyn_cast<ConstantInt>(V);
+ // Don't scalarize all zeros vector.
+ if (!C || !C->isZero()) {
+ Ops[FinalIndex] = V;
+ RewriteGEP = true;
}
}
+ }
- // If we made any changes or the we have extra operands, we need to generate
- // new instructions.
- if (!RewriteGEP && Ops.size() == 2)
- return false;
+ // If we made any changes or the we have extra operands, we need to generate
+ // new instructions.
+ if (!RewriteGEP && Ops.size() == 2)
+ return false;
- auto NumElts = cast<VectorType>(Ptr->getType())->getElementCount();
+ auto NumElts = cast<VectorType>(Ptr->getType())->getElementCount();
- IRBuilder<> Builder(MemoryInst);
+ IRBuilder<> Builder(MemoryInst);
- Type *SourceTy = GEP->getSourceElementType();
- Type *ScalarIndexTy = DL->getIndexType(Ops[0]->getType()->getScalarType());
+ Type *SourceTy = GEP->getSourceElementType();
+ Type *ScalarIndexTy = DL->getIndexType(Ops[0]->getType()->getScalarType());
- // If the final index isn't a vector, emit a scalar GEP containing all ops
- // and a vector GEP with all zeroes final index.
- if (!Ops[FinalIndex]->getType()->isVectorTy()) {
- NewAddr = Builder.CreateGEP(SourceTy, Ops[0], ArrayRef(Ops).drop_front());
- auto *IndexTy = VectorType::get(ScalarIndexTy, NumElts);
- auto *SecondTy = GetElementPtrInst::getIndexedType(
- SourceTy, ArrayRef(Ops).drop_front());
- NewAddr =
- Builder.CreateGEP(SecondTy, NewAddr, Constant::getNullValue(IndexTy));
- } else {
- Value *Base = Ops[0];
- Value *Index = Ops[FinalIndex];
-
- // Create a scalar GEP if there are more than 2 operands.
- if (Ops.size() != 2) {
- // Replace the last index with 0.
- Ops[FinalIndex] =
- Constant::getNullValue(Ops[FinalIndex]->getType()->getScalarType());
- Base = Builder.CreateGEP(SourceTy, Base, ArrayRef(Ops).drop_front());
- SourceTy = GetElementPtrInst::getIndexedType(
- SourceTy, ArrayRef(Ops).drop_front());
- }
+ // If the final index isn't a vector, emit a scalar GEP containing all ops
+ // and a vector GEP with all zeroes final index.
+ Value *NewAddr;
+ if (!Ops[FinalIndex]->getType()->isVectorTy()) {
+ NewAddr = Builder.CreateGEP(SourceTy, Ops[0], ArrayRef(Ops).drop_front());
+ auto *IndexTy = VectorType::get(ScalarIndexTy, NumElts);
+ auto *SecondTy =
+ GetElementPtrInst::getIndexedType(SourceTy, ArrayRef(Ops).drop_front());
+ NewAddr =
+ Builder.CreateGEP(SecondTy, NewAddr, Constant::getNullValue(IndexTy));
+ } else {
+ Value *Base = Ops[0];
+ Value *Index = Ops[FinalIndex];
- // Now create the GEP with scalar pointer and vector index.
- NewAddr = Builder.CreateGEP(SourceTy, Base, Index);
+ // Create a scalar GEP if there are more than 2 operands.
+ if (Ops.size() != 2) {
+ // Replace the last index with 0.
+ Ops[FinalIndex] =
+ Constant::getNullValue(Ops[FinalIndex]->getType()->getScalarType());
+ Base = Builder.CreateGEP(SourceTy, Base, ArrayRef(Ops).drop_front());
+ SourceTy = GetElementPtrInst::getIndexedType(SourceTy,
+ ArrayRef(Ops).drop_front());
}
- } else if (!isa<Constant>(Ptr)) {
- // Not a GEP, maybe its a splat and we can create a GEP to enable
- // SelectionDAGBuilder to use it as a uniform base.
- Value *V = getSplatValue(Ptr);
- if (!V)
- return false;
-
- auto NumElts = cast<VectorType>(Ptr->getType())->getElementCount();
- IRBuilder<> Builder(MemoryInst);
-
- // Emit a vector GEP with a scalar pointer and all 0s vector index.
- Type *ScalarIndexTy = DL->getIndexType(V->getType()->getScalarType());
- auto *IndexTy = VectorType::get(ScalarIndexTy, NumElts);
- Type *ScalarTy;
- if (cast<IntrinsicInst>(MemoryInst)->getIntrinsicID() ==
- Intrinsic::masked_gather) {
- ScalarTy = MemoryInst->getType()->getScalarType();
- } else {
- assert(cast<IntrinsicInst>(MemoryInst)->getIntrinsicID() ==
- Intrinsic::masked_scatter);
- ScalarTy = MemoryInst->getOperand(0)->getType()->getScalarType();
- }
- NewAddr = Builder.CreateGEP(ScalarTy, V, Constant::getNullValue(IndexTy));
- } else {
- // Constant, SelectionDAGBuilder knows to check if its a splat.
- return false;
+ // Now create the GEP with scalar pointer and vector index.
+ NewAddr = Builder.CreateGEP(SourceTy, Base, Index);
}
MemoryInst->replaceUsesOfWith(Ptr, NewAddr);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 04d6fd5f48cc3..ffe5ddec806a9 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -4888,14 +4888,9 @@ static bool getUniformBase(const Value *Ptr, SDValue &Base, SDValue &Index,
assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");
- // Handle splat constant pointer.
- if (auto *C = dyn_cast<Constant>(Ptr)) {
- C = C->getSplatValue();
- if (!C)
- return false;
-
- Base = SDB->getValue(C);
-
+ // Handle splat (possibly constant) pointer.
+ if (Value *ScalarV = getSplatValue(Ptr)) {
+ Base = SDB->getValue(ScalarV);
ElementCount NumElts = cast<VectorType>(Ptr->getType())->getElementCount();
EVT VT = EVT::getVectorVT(*DAG.getContext(), TLI.getPointerTy(DL), NumElts);
Index = DAG.getConstant(0, SDB->getCurSDLoc(), VT);
diff --git a/llvm/test/CodeGen/X86/masked_gather_scatter.ll b/llvm/test/CodeGen/X86/masked_gather_scatter.ll
index 4e6f666fa05de..8dd9039fbce55 100644
--- a/llvm/test/CodeGen/X86/masked_gather_scatter.ll
+++ b/llvm/test/CodeGen/X86/masked_gather_scatter.ll
@@ -4475,7 +4475,7 @@ define <4 x i32> @splat_ptr_gather(ptr %ptr, <4 x i1> %mask, <4 x i32> %passthru
; X64-SKX-NEXT: vpslld $31, %xmm0, %xmm0
; X64-SKX-NEXT: vpmovd2m %xmm0, %k1
; X64-SKX-NEXT: vpxor %xmm0, %xmm0, %xmm0
-; X64-SKX-NEXT: vpgatherdd (%rdi,%xmm0,4), %xmm1 {%k1}
+; X64-SKX-NEXT: vpgatherdd (%rdi,%xmm0), %xmm1 {%k1}
; X64-SKX-NEXT: vmovdqa %xmm1, %xmm0
; X64-SKX-NEXT: retq
;
@@ -4485,7 +4485,7 @@ define <4 x i32> @splat_ptr_gather(ptr %ptr, <4 x i1> %mask, <4 x i32> %passthru
; X86-SKX-NEXT: vpmovd2m %xmm0, %k1
; X86-SKX-NEXT: movl {{[0-9]+}}(%esp), %eax
; X86-SKX-NEXT: vpxor %xmm0, %xmm0, %xmm0
-; X86-SKX-NEXT: vpgatherdd (%eax,%xmm0,4), %xmm1 {%k1}
+; X86-SKX-NEXT: vpgatherdd (%eax,%xmm0), %xmm1 {%k1}
; X86-SKX-NEXT: vmovdqa %xmm1, %xmm0
; X86-SKX-NEXT: retl
%1 = insertelement <4 x ptr> undef, ptr %ptr, i32 0
@@ -4581,7 +4581,7 @@ define void @splat_ptr_scatter(ptr %ptr, <4 x i1> %mask, <4 x i32> %val) {
; X64-SKX-NEXT: vpslld $31, %xmm0, %xmm0
; X64-SKX-NEXT: vpmovd2m %xmm0, %k1
; X64-SKX-NEXT: vpxor %xmm0, %xmm0, %xmm0
-; X64-SKX-NEXT: vpscatterdd %xmm1, (%rdi,%xmm0,4) {%k1}
+; X64-SKX-NEXT: vpscatterdd %xmm1, (%rdi,%xmm0) {%k1}
; X64-SKX-NEXT: retq
;
; X86-SKX-LABEL: splat_ptr_scatter:
@@ -4590,7 +4590,7 @@ define void @splat_ptr_scatter(ptr %ptr, <4 x i1> %mask, <4 x i32> %val) {
; X86-SKX-NEXT: vpmovd2m %xmm0, %k1
; X86-SKX-NEXT: movl {{[0-9]+}}(%esp), %eax
; X86-SKX-NEXT: vpxor %xmm0, %xmm0, %xmm0
-; X86-SKX-NEXT: vpscatterdd %xmm1, (%eax,%xmm0,4) {%k1}
+; X86-SKX-NEXT: vpscatterdd %xmm1, (%eax,%xmm0) {%k1}
; X86-SKX-NEXT: retl
%1 = insertelement <4 x ptr> undef, ptr %ptr, i32 0
%2 = shufflevector <4 x ptr> %1, <4 x ptr> undef, <4 x i32> zeroinitializer
diff --git a/llvm/test/Transforms/CodeGenPrepare/AArch64/gather-scatter-opt-inseltpoison.ll b/llvm/test/Transforms/CodeGenPrepare/AArch64/gather-scatter-opt-inseltpoison.ll
index 3c5c07f3516c9..6fd4d4a4c6be4 100644
--- a/llvm/test/Transforms/CodeGenPrepare/AArch64/gather-scatter-opt-inseltpoison.ll
+++ b/llvm/test/Transforms/CodeGenPrepare/AArch64/gather-scatter-opt-inseltpoison.ll
@@ -85,7 +85,8 @@ define <vscale x 4 x i32> @global_struct_splat(<vscale x 4 x i1> %mask) #0 {
define <vscale x 4 x i32> @splat_ptr_gather(ptr %ptr, <vscale x 4 x i1> %mask, <vscale x 4 x i32> %passthru) #0 {
; CHECK-LABEL: @splat_ptr_gather(
-; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i32, ptr [[PTR:%.*]], <vscale x 4 x i64> zeroinitializer
+; CHECK-NEXT: [[TMP3:%.*]] = insertelement <vscale x 4 x ptr> poison, ptr [[PTR:%.*]], i32 0
+; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <vscale x 4 x ptr> [[TMP3]], <vscale x 4 x ptr> poison, <vscale x 4 x i32> zeroinitializer
; CHECK-NEXT: [[TMP2:%.*]] = call <vscale x 4 x i32> @llvm.masked.gather.nxv4i32.nxv4p0(<vscale x 4 x ptr> [[TMP1]], i32 4, <vscale x 4 x i1> [[MASK:%.*]], <vscale x 4 x i32> [[PASSTHRU:%.*]])
; CHECK-NEXT: ret <vscale x 4 x i32> [[TMP2]]
;
@@ -97,7 +98,8 @@ define <vscale x 4 x i32> @splat_ptr_gather(ptr %ptr, <vscale x 4 x i1> %mask, <
define void @splat_ptr_scatter(ptr %ptr, <vscale x 4 x i1> %mask, <vscale x 4 x i32> %val) #0 {
; CHECK-LABEL: @splat_ptr_scatter(
-; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i32, ptr [[PTR:%.*]], <vscale x 4 x i64> zeroinitializer
+; CHECK-NEXT: [[TMP2:%.*]] = insertelement <vscale x 4 x ptr> poison, ptr [[PTR:%.*]], i32 0
+; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <vscale x 4 x ptr> [[TMP2]], <vscale x 4 x ptr> poison, <vscale x 4 x i32> zeroinitializer
; CHECK-NEXT: call void @llvm.masked.scatter.nxv4i32.nxv4p0(<vscale x 4 x i32> [[VAL:%.*]], <vscale x 4 x ptr> [[TMP1]], i32 4, <vscale x 4 x i1> [[MASK:%.*]])
; CHECK-NEXT: ret void
;
diff --git a/llvm/test/Transforms/CodeGenPrepare/AArch64/gather-scatter-opt.ll b/llvm/test/Transforms/CodeGenPrepare/AArch64/gather-scatter-opt.ll
index 36cd69ed01ed9..d1843dcd23863 100644
--- a/llvm/test/Transforms/CodeGenPrepare/AArch64/gather-scatter-opt.ll
+++ b/llvm/test/Transforms/CodeGenPrepare/AArch64/gather-scatter-opt.ll
@@ -85,7 +85,8 @@ define <vscale x 4 x i32> @global_struct_splat(<vscale x 4 x i1> %mask) #0 {
define <vscale x 4 x i32> @splat_ptr_gather(ptr %ptr, <vscale x 4 x i1> %mask, <vscale x 4 x i32> %passthru) #0 {
; CHECK-LABEL: @splat_ptr_gather(
-; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i32, ptr [[PTR:%.*]], <vscale x 4 x i64> zeroinitializer
+; CHECK-NEXT: [[TMP3:%.*]] = insertelement <vscale x 4 x ptr> undef, ptr [[PTR:%.*]], i32 0
+; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <vscale x 4 x ptr> [[TMP3]], <vscale x 4 x ptr> undef, <vscale x 4 x i32> zeroinitializer
; CHECK-NEXT: [[TMP2:%.*]] = call <vscale x 4 x i32> @llvm.masked.gather.nxv4i32.nxv4p0(<vscale x 4 x ptr> [[TMP1]], i32 4, <vscale x 4 x i1> [[MASK:%.*]], <vscale x 4 x i32> [[PASSTHRU:%.*]])
; CHECK-NEXT: ret <vscale x 4 x i32> [[TMP2]]
;
@@ -97,7 +98,8 @@ define <vscale x 4 x i32> @splat_ptr_gather(ptr %ptr, <vscale x 4 x i1> %mask, <
define void @splat_ptr_scatter(ptr %ptr, <vscale x 4 x i1> %mask, <vscale x 4 x i32> %val) #0 {
; CHECK-LABEL: @splat_ptr_scatter(
-; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i32, ptr [[PTR:%.*]], <vscale x 4 x i64> zeroinitializer
+; CHECK-NEXT: [[TMP2:%.*]] = insertelement <vscale x 4 x ptr> undef, ptr [[PTR:%.*]], i32 0
+; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <vscale x 4 x ptr> [[TMP2]], <vscale x 4 x ptr> undef, <vscale x 4 x i32> zeroinitializer
; CHECK-NEXT: call void @llvm.masked.scatter.nxv4i32.nxv4p0(<vscale x 4 x i32> [[VAL:%.*]], <vscale x 4 x ptr> [[TMP1]], i32 4, <vscale x 4 x i1> [[MASK:%.*]])
; CHECK-NEXT: ret void
;
diff --git a/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt-inseltpoison.ll b/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt-inseltpoison.ll
index 6ef3400812fc8..810a9d7b228c0 100644
--- a/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt-inseltpoison.ll
+++ b/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt-inseltpoison.ll
@@ -88,7 +88,8 @@ define <4 x i32> @global_struct_splat() {
define <4 x i32> @splat_ptr_gather(ptr %ptr, <4 x i1> %mask, <4 x i32> %passthru) {
; CHECK-LABEL: @splat_ptr_gather(
-; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i32, ptr [[PTR:%.*]], <4 x i64> zeroinitializer
+; CHECK-NEXT: [[TMP3:%.*]] = insertelement <4 x ptr> poison, ptr [[PTR:%.*]], i32 0
+; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <4 x ptr> [[TMP3]], <4 x ptr> poison, <4 x i32> zeroinitializer
; CHECK-NEXT: [[TMP2:%.*]] = call <4 x i32> @llvm.masked.gather.v4i32.v4p0(<4 x ptr> [[TMP1]], i32 4, <4 x i1> [[MASK:%.*]], <4 x i32> [[PASSTHRU:%.*]])
; CHECK-NEXT: ret <4 x i32> [[TMP2]]
;
@@ -100,7 +101,8 @@ define <4 x i32> @splat_ptr_gather(ptr %ptr, <4 x i1> %mask, <4 x i32> %passthru
define void @splat_ptr_scatter(ptr %ptr, <4 x i1> %mask, <4 x i32> %val) {
; CHECK-LABEL: @splat_ptr_scatter(
-; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i32, ptr [[PTR:%.*]], <4 x i64> zeroinitializer
+; CHECK-NEXT: [[TMP2:%.*]] = insertelement <4 x ptr> poison, ptr [[PTR:%.*]], i32 0
+; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <4 x ptr> [[TMP2]], <4 x ptr> poison, <4 x i32> zeroinitializer
; CHECK-NEXT: call void @llvm.masked.scatter.v4i32.v4p0(<4 x i32> [[VAL:%.*]], <4 x ptr> [[TMP1]], i32 4, <4 x i1> [[MASK:%.*]])
; CHECK-NEXT: ret void
;
diff --git a/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt.ll b/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt.ll
index 8328708393029..cada231707171 100644
--- a/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt.ll
+++ b/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt.ll
@@ -87,7 +87,8 @@ define <4 x i32> @global_struct_splat() {
define <4 x i32> @splat_ptr_gather(ptr %ptr, <4 x i1> %mask, <4 x i32> %passthru) {
; CHECK-LABEL: @splat_ptr_gather(
-; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i32, ptr [[PTR:%.*]], <4 x i64> zeroinitializer
+; CHECK-NEXT: [[TMP3:%.*]] = insertelement <4 x ptr> undef, ptr [[PTR:%.*]], i32 0
+; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <4 x ptr> [[TMP3]], <4 x ptr> undef, <4 x i32> zeroinitializer
; CHECK-NEXT: [[TMP2:%.*]] = call <4 x i32> @llvm.masked.gather.v4i32.v4p0(<4 x ptr> [[TMP1]], i32 4, <4 x i1> [[MASK:%.*]], <4 x i32> [[PASSTHRU:%.*]])
; CHECK-NEXT: ret <4 x i32> [[TMP2]]
;
@@ -99,7 +100,8 @@ define <4 x i32> @splat_ptr_gather(ptr %ptr, <4 x i1> %mask, <4 x i32> %passthru
define void @splat_ptr_scatter(ptr %ptr, <4 x i1> %mask, <4 x i32> %val) {
; CHECK-LABEL: @splat_ptr_scatter(
-; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i32, ptr [[PTR:%.*]], <4 x i64> zeroinitializer
+; CHECK-NEXT: [[TMP2:%.*]] = insertelement <4 x ptr> undef, ptr [[PTR:%.*]], i32 0
+; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <4 x ptr> [[TMP2]], <4 x ptr> undef, <4 x i32> zeroinitializer
; CHECK-NEXT: call void @llvm.masked.scatter.v4i32.v4p0(<4 x i32> [[VAL:%.*]], <4 x ptr> [[TMP1]], i32 4, <4 x i1> [[MASK:%.*]])
; CHECK-NEXT: ret void
;
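The CHECK-line churn in the CodeGenPrepare tests above boils down to the following before/after, distilled from those CHECK lines rather than taken from any single test:

; Before: CGP rewrote the splat pointer into a vector GEP with a zero index.
;   %addr = getelementptr i32, ptr %ptr, <4 x i64> zeroinitializer
; After: the canonical splat is left alone and reaches instruction selection.
%bc = insertelement <4 x ptr> poison, ptr %ptr, i32 0
%addr = shufflevector <4 x ptr> %bc, <4 x ptr> poison, <4 x i32> zeroinitializer

getUniformBase now recognizes the splat directly, taking %ptr as the scalar base with an all-zero index, which is why the ",4" scale disappears from the vpgatherdd/vpscatterdd addressing modes in the X86 asm diffs.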
You can test this locally with the following command:

git diff -U0 --pickaxe-regex -S '([^a-zA-Z0-9#_-]undef[^a-zA-Z0-9_-]|UndefValue::get)' 'HEAD~1' HEAD llvm/lib/CodeGen/CodeGenPrepare.cpp llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp llvm/test/CodeGen/X86/masked_gather_scatter.ll llvm/test/Transforms/CodeGenPrepare/AArch64/gather-scatter-opt-inseltpoison.ll llvm/test/Transforms/CodeGenPrepare/AArch64/gather-scatter-opt.ll llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt-inseltpoison.ll llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt.ll

The following files introduce new uses of undef:
Undef is now deprecated and should only be used in the rare cases where no replacement is possible. For example, a load of uninitialized memory yields undef. You should use poison values for placeholders instead.

In tests, avoid using undef and having tests that trigger undefined behavior. If you need an operand with some unimportant value, you can add a new argument to the function and use that instead.

For example, this is considered a bad practice:

define void @fn() {
  ...
  br i1 undef, ...
}

Please use the following instead:

define void @fn(i1 %cond) {
  ...
  br i1 %cond, ...
}

Please refer to the Undefined Behavior Manual for more information.
-  Base = SDB->getValue(C);
+  // Handle splat (possibly constant) pointer.
+  if (Value *ScalarV = getSplatValue(Ptr)) {
Does this work if the splat instruction isn't in the same basic block? You have to be careful looking through instructions in SelectionDAGBuilder because the Value won't be exported from the producing basic block.
There used to be code to find the export that I removed in 944cc5e
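Concretely, the worry is a shape like the following (a hypothetical sketch, not a test from this PR), where the scalar feeding the splat is defined in a different block than the gather:

declare <4 x i32> @llvm.masked.gather.v4i32.v4p0(<4 x ptr>, i32 immarg, <4 x i1>, <4 x i32>)

define <4 x i32> @cross_block(ptr %pp, i1 %c, <4 x i1> %mask, <4 x i32> %pt) {
entry:
  ; %p is only used within entry (via the splat), so SelectionDAGBuilder
  ; never exports it to a vreg that other blocks can read.
  %p = load ptr, ptr %pp
  %bc = insertelement <4 x ptr> poison, ptr %p, i32 0
  %splat = shufflevector <4 x ptr> %bc, <4 x ptr> poison, <4 x i32> zeroinitializer
  br i1 %c, label %use, label %exit

use:
  ; getSplatValue(%splat) looks through to %p, but getValue(%p) here
  ; crosses a block boundary without the removed export handling.
  %res = call <4 x i32> @llvm.masked.gather.v4i32.v4p0(<4 x ptr> %splat, i32 4, <4 x i1> %mask, <4 x i32> %pt)
  ret <4 x i32> %res

exit:
  ret <4 x i32> %pt
}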
Oh yuck, I should have checked the history here more closely. No, it doesn't.
I'm going to close this and revisit from scratch. I can probably find an analogous approach, but TBD.