
Conversation

@preames (Collaborator) commented Jun 26, 2025

The primary goal of this change is to simplify code, but it also ends up being slightly more powerful. Rather than repeating the gather/scatter-of-splat logic in both CGP and SDAG, generalize the SDAG copy slightly and delete the CGP version.

The X86 codegen diffs are improvements: we were scaling a zero index value by 4, whereas now we're not scaling it at all. This codegen can likely be improved further, but that will come in upcoming patches.

Note that we have instcombine rules to convert unmasked gathers/scatters to scalar loads/stores respectively, so codegen-quality-wise this really only matters for the masked flavors.
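
For reference, the masked splat-pointer pattern this affects (adapted from the splat_ptr_gather test updated in this PR) looks like the following IR; an unmasked equivalent would already have been folded to a scalar load by instcombine:

```llvm
define <4 x i32> @splat_ptr_gather(ptr %ptr, <4 x i1> %mask, <4 x i32> %passthru) {
  ; Splat the scalar pointer across all lanes (canonical splat form).
  %head = insertelement <4 x ptr> poison, ptr %ptr, i32 0
  %splat = shufflevector <4 x ptr> %head, <4 x ptr> poison, <4 x i32> zeroinitializer
  ; Masked gather through the splat pointer; only this masked flavor
  ; still reaches codegen after instcombine has run.
  %res = call <4 x i32> @llvm.masked.gather.v4i32.v4p0(<4 x ptr> %splat, i32 4, <4 x i1> %mask, <4 x i32> %passthru)
  ret <4 x i32> %res
}
```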

@llvmbot (Member) commented Jun 26, 2025

@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-backend-x86

Author: Philip Reames (preames)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/145931.diff

7 Files Affected:

  • (modified) llvm/lib/CodeGen/CodeGenPrepare.cpp (+76-102)
  • (modified) llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp (+3-8)
  • (modified) llvm/test/CodeGen/X86/masked_gather_scatter.ll (+4-4)
  • (modified) llvm/test/Transforms/CodeGenPrepare/AArch64/gather-scatter-opt-inseltpoison.ll (+4-2)
  • (modified) llvm/test/Transforms/CodeGenPrepare/AArch64/gather-scatter-opt.ll (+4-2)
  • (modified) llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt-inseltpoison.ll (+4-2)
  • (modified) llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt.ll (+4-2)
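
The gather-scatter-opt test diffs below capture the functional change: CodeGenPrepare no longer rewrites a splat pointer into a zero-index vector GEP; instead, the canonical insertelement/shufflevector splat survives to SelectionDAGBuilder, whose getUniformBase now recognizes it via getSplatValue. Roughly, for the scalable-vector tests:

```llvm
; Old CGP output for a splat pointer (the removed CHECK line):
;   %addr = getelementptr i32, ptr %ptr, <vscale x 4 x i64> zeroinitializer
; New: the canonical splat reaches ISel unchanged (the new CHECK lines):
%head = insertelement <vscale x 4 x ptr> poison, ptr %ptr, i32 0
%addr = shufflevector <vscale x 4 x ptr> %head, <vscale x 4 x ptr> poison, <vscale x 4 x i32> zeroinitializer
```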
diff --git a/llvm/lib/CodeGen/CodeGenPrepare.cpp b/llvm/lib/CodeGen/CodeGenPrepare.cpp
index 43574a54c37dd..8a0e26bde23ff 100644
--- a/llvm/lib/CodeGen/CodeGenPrepare.cpp
+++ b/llvm/lib/CodeGen/CodeGenPrepare.cpp
@@ -6230,7 +6230,7 @@ bool CodeGenPrepare::optimizeMemoryInst(Instruction *MemoryInst, Value *Addr,
 
 /// Rewrite GEP input to gather/scatter to enable SelectionDAGBuilder to find
 /// a uniform base to use for ISD::MGATHER/MSCATTER. SelectionDAGBuilder can
-/// only handle a 2 operand GEP in the same basic block or a splat constant
+/// only handle a 2 operand GEP in the same basic block or a canonical splat
 /// vector. The 2 operands to the GEP must have a scalar pointer and a vector
 /// index.
 ///
@@ -6247,124 +6247,98 @@ bool CodeGenPrepare::optimizeMemoryInst(Instruction *MemoryInst, Value *Addr,
 /// zero index.
 bool CodeGenPrepare::optimizeGatherScatterInst(Instruction *MemoryInst,
                                                Value *Ptr) {
-  Value *NewAddr;
+  const auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
+  if (!GEP)
+    return false;
 
-  if (const auto *GEP = dyn_cast<GetElementPtrInst>(Ptr)) {
-    // Don't optimize GEPs that don't have indices.
-    if (!GEP->hasIndices())
-      return false;
+  // Don't optimize GEPs that don't have indices.
+  if (!GEP->hasIndices())
+    return false;
 
-    // If the GEP and the gather/scatter aren't in the same BB, don't optimize.
-    // FIXME: We should support this by sinking the GEP.
-    if (MemoryInst->getParent() != GEP->getParent())
-      return false;
+  // If the GEP and the gather/scatter aren't in the same BB, don't optimize.
+  // FIXME: We should support this by sinking the GEP.
+  if (MemoryInst->getParent() != GEP->getParent())
+    return false;
 
-    SmallVector<Value *, 2> Ops(GEP->operands());
+  SmallVector<Value *, 2> Ops(GEP->operands());
 
-    bool RewriteGEP = false;
+  bool RewriteGEP = false;
 
-    if (Ops[0]->getType()->isVectorTy()) {
-      Ops[0] = getSplatValue(Ops[0]);
-      if (!Ops[0])
-        return false;
-      RewriteGEP = true;
-    }
+  if (Ops[0]->getType()->isVectorTy()) {
+    Ops[0] = getSplatValue(Ops[0]);
+    if (!Ops[0])
+      return false;
+    RewriteGEP = true;
+  }
 
-    unsigned FinalIndex = Ops.size() - 1;
+  unsigned FinalIndex = Ops.size() - 1;
 
-    // Ensure all but the last index is 0.
-    // FIXME: This isn't strictly required. All that's required is that they are
-    // all scalars or splats.
-    for (unsigned i = 1; i < FinalIndex; ++i) {
-      auto *C = dyn_cast<Constant>(Ops[i]);
-      if (!C)
-        return false;
-      if (isa<VectorType>(C->getType()))
-        C = C->getSplatValue();
-      auto *CI = dyn_cast_or_null<ConstantInt>(C);
-      if (!CI || !CI->isZero())
-        return false;
-      // Scalarize the index if needed.
-      Ops[i] = CI;
-    }
-
-    // Try to scalarize the final index.
-    if (Ops[FinalIndex]->getType()->isVectorTy()) {
-      if (Value *V = getSplatValue(Ops[FinalIndex])) {
-        auto *C = dyn_cast<ConstantInt>(V);
-        // Don't scalarize all zeros vector.
-        if (!C || !C->isZero()) {
-          Ops[FinalIndex] = V;
-          RewriteGEP = true;
-        }
+  // Ensure all but the last index is 0.
+  // FIXME: This isn't strictly required. All that's required is that they are
+  // all scalars or splats.
+  for (unsigned i = 1; i < FinalIndex; ++i) {
+    auto *C = dyn_cast<Constant>(Ops[i]);
+    if (!C)
+      return false;
+    if (isa<VectorType>(C->getType()))
+      C = C->getSplatValue();
+    auto *CI = dyn_cast_or_null<ConstantInt>(C);
+    if (!CI || !CI->isZero())
+      return false;
+    // Scalarize the index if needed.
+    Ops[i] = CI;
+  }
+
+  // Try to scalarize the final index.
+  if (Ops[FinalIndex]->getType()->isVectorTy()) {
+    if (Value *V = getSplatValue(Ops[FinalIndex])) {
+      auto *C = dyn_cast<ConstantInt>(V);
+      // Don't scalarize all zeros vector.
+      if (!C || !C->isZero()) {
+        Ops[FinalIndex] = V;
+        RewriteGEP = true;
       }
     }
+  }
 
-    // If we made any changes or the we have extra operands, we need to generate
-    // new instructions.
-    if (!RewriteGEP && Ops.size() == 2)
-      return false;
+  // If we made any changes or the we have extra operands, we need to generate
+  // new instructions.
+  if (!RewriteGEP && Ops.size() == 2)
+    return false;
 
-    auto NumElts = cast<VectorType>(Ptr->getType())->getElementCount();
+  auto NumElts = cast<VectorType>(Ptr->getType())->getElementCount();
 
-    IRBuilder<> Builder(MemoryInst);
+  IRBuilder<> Builder(MemoryInst);
 
-    Type *SourceTy = GEP->getSourceElementType();
-    Type *ScalarIndexTy = DL->getIndexType(Ops[0]->getType()->getScalarType());
+  Type *SourceTy = GEP->getSourceElementType();
+  Type *ScalarIndexTy = DL->getIndexType(Ops[0]->getType()->getScalarType());
 
-    // If the final index isn't a vector, emit a scalar GEP containing all ops
-    // and a vector GEP with all zeroes final index.
-    if (!Ops[FinalIndex]->getType()->isVectorTy()) {
-      NewAddr = Builder.CreateGEP(SourceTy, Ops[0], ArrayRef(Ops).drop_front());
-      auto *IndexTy = VectorType::get(ScalarIndexTy, NumElts);
-      auto *SecondTy = GetElementPtrInst::getIndexedType(
-          SourceTy, ArrayRef(Ops).drop_front());
-      NewAddr =
-          Builder.CreateGEP(SecondTy, NewAddr, Constant::getNullValue(IndexTy));
-    } else {
-      Value *Base = Ops[0];
-      Value *Index = Ops[FinalIndex];
-
-      // Create a scalar GEP if there are more than 2 operands.
-      if (Ops.size() != 2) {
-        // Replace the last index with 0.
-        Ops[FinalIndex] =
-            Constant::getNullValue(Ops[FinalIndex]->getType()->getScalarType());
-        Base = Builder.CreateGEP(SourceTy, Base, ArrayRef(Ops).drop_front());
-        SourceTy = GetElementPtrInst::getIndexedType(
-            SourceTy, ArrayRef(Ops).drop_front());
-      }
+  // If the final index isn't a vector, emit a scalar GEP containing all ops
+  // and a vector GEP with all zeroes final index.
+  Value *NewAddr;
+  if (!Ops[FinalIndex]->getType()->isVectorTy()) {
+    NewAddr = Builder.CreateGEP(SourceTy, Ops[0], ArrayRef(Ops).drop_front());
+    auto *IndexTy = VectorType::get(ScalarIndexTy, NumElts);
+    auto *SecondTy =
+        GetElementPtrInst::getIndexedType(SourceTy, ArrayRef(Ops).drop_front());
+    NewAddr =
+        Builder.CreateGEP(SecondTy, NewAddr, Constant::getNullValue(IndexTy));
+  } else {
+    Value *Base = Ops[0];
+    Value *Index = Ops[FinalIndex];
 
-      // Now create the GEP with scalar pointer and vector index.
-      NewAddr = Builder.CreateGEP(SourceTy, Base, Index);
+    // Create a scalar GEP if there are more than 2 operands.
+    if (Ops.size() != 2) {
+      // Replace the last index with 0.
+      Ops[FinalIndex] =
+          Constant::getNullValue(Ops[FinalIndex]->getType()->getScalarType());
+      Base = Builder.CreateGEP(SourceTy, Base, ArrayRef(Ops).drop_front());
+      SourceTy = GetElementPtrInst::getIndexedType(SourceTy,
+                                                   ArrayRef(Ops).drop_front());
     }
-  } else if (!isa<Constant>(Ptr)) {
-    // Not a GEP, maybe its a splat and we can create a GEP to enable
-    // SelectionDAGBuilder to use it as a uniform base.
-    Value *V = getSplatValue(Ptr);
-    if (!V)
-      return false;
-
-    auto NumElts = cast<VectorType>(Ptr->getType())->getElementCount();
 
-    IRBuilder<> Builder(MemoryInst);
-
-    // Emit a vector GEP with a scalar pointer and all 0s vector index.
-    Type *ScalarIndexTy = DL->getIndexType(V->getType()->getScalarType());
-    auto *IndexTy = VectorType::get(ScalarIndexTy, NumElts);
-    Type *ScalarTy;
-    if (cast<IntrinsicInst>(MemoryInst)->getIntrinsicID() ==
-        Intrinsic::masked_gather) {
-      ScalarTy = MemoryInst->getType()->getScalarType();
-    } else {
-      assert(cast<IntrinsicInst>(MemoryInst)->getIntrinsicID() ==
-             Intrinsic::masked_scatter);
-      ScalarTy = MemoryInst->getOperand(0)->getType()->getScalarType();
-    }
-    NewAddr = Builder.CreateGEP(ScalarTy, V, Constant::getNullValue(IndexTy));
-  } else {
-    // Constant, SelectionDAGBuilder knows to check if its a splat.
-    return false;
+    // Now create the GEP with scalar pointer and vector index.
+    NewAddr = Builder.CreateGEP(SourceTy, Base, Index);
   }
 
   MemoryInst->replaceUsesOfWith(Ptr, NewAddr);
diff --git a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
index 04d6fd5f48cc3..ffe5ddec806a9 100644
--- a/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
@@ -4888,14 +4888,9 @@ static bool getUniformBase(const Value *Ptr, SDValue &Base, SDValue &Index,
 
   assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");
 
-  // Handle splat constant pointer.
-  if (auto *C = dyn_cast<Constant>(Ptr)) {
-    C = C->getSplatValue();
-    if (!C)
-      return false;
-
-    Base = SDB->getValue(C);
-
+  // Handle splat (possibly constant) pointer.
+  if (Value *ScalarV = getSplatValue(Ptr)) {
+    Base = SDB->getValue(ScalarV);
     ElementCount NumElts = cast<VectorType>(Ptr->getType())->getElementCount();
     EVT VT = EVT::getVectorVT(*DAG.getContext(), TLI.getPointerTy(DL), NumElts);
     Index = DAG.getConstant(0, SDB->getCurSDLoc(), VT);
diff --git a/llvm/test/CodeGen/X86/masked_gather_scatter.ll b/llvm/test/CodeGen/X86/masked_gather_scatter.ll
index 4e6f666fa05de..8dd9039fbce55 100644
--- a/llvm/test/CodeGen/X86/masked_gather_scatter.ll
+++ b/llvm/test/CodeGen/X86/masked_gather_scatter.ll
@@ -4475,7 +4475,7 @@ define <4 x i32> @splat_ptr_gather(ptr %ptr, <4 x i1> %mask, <4 x i32> %passthru
 ; X64-SKX-NEXT:    vpslld $31, %xmm0, %xmm0
 ; X64-SKX-NEXT:    vpmovd2m %xmm0, %k1
 ; X64-SKX-NEXT:    vpxor %xmm0, %xmm0, %xmm0
-; X64-SKX-NEXT:    vpgatherdd (%rdi,%xmm0,4), %xmm1 {%k1}
+; X64-SKX-NEXT:    vpgatherdd (%rdi,%xmm0), %xmm1 {%k1}
 ; X64-SKX-NEXT:    vmovdqa %xmm1, %xmm0
 ; X64-SKX-NEXT:    retq
 ;
@@ -4485,7 +4485,7 @@ define <4 x i32> @splat_ptr_gather(ptr %ptr, <4 x i1> %mask, <4 x i32> %passthru
 ; X86-SKX-NEXT:    vpmovd2m %xmm0, %k1
 ; X86-SKX-NEXT:    movl {{[0-9]+}}(%esp), %eax
 ; X86-SKX-NEXT:    vpxor %xmm0, %xmm0, %xmm0
-; X86-SKX-NEXT:    vpgatherdd (%eax,%xmm0,4), %xmm1 {%k1}
+; X86-SKX-NEXT:    vpgatherdd (%eax,%xmm0), %xmm1 {%k1}
 ; X86-SKX-NEXT:    vmovdqa %xmm1, %xmm0
 ; X86-SKX-NEXT:    retl
   %1 = insertelement <4 x ptr> undef, ptr %ptr, i32 0
@@ -4581,7 +4581,7 @@ define void @splat_ptr_scatter(ptr %ptr, <4 x i1> %mask, <4 x i32> %val) {
 ; X64-SKX-NEXT:    vpslld $31, %xmm0, %xmm0
 ; X64-SKX-NEXT:    vpmovd2m %xmm0, %k1
 ; X64-SKX-NEXT:    vpxor %xmm0, %xmm0, %xmm0
-; X64-SKX-NEXT:    vpscatterdd %xmm1, (%rdi,%xmm0,4) {%k1}
+; X64-SKX-NEXT:    vpscatterdd %xmm1, (%rdi,%xmm0) {%k1}
 ; X64-SKX-NEXT:    retq
 ;
 ; X86-SKX-LABEL: splat_ptr_scatter:
@@ -4590,7 +4590,7 @@ define void @splat_ptr_scatter(ptr %ptr, <4 x i1> %mask, <4 x i32> %val) {
 ; X86-SKX-NEXT:    vpmovd2m %xmm0, %k1
 ; X86-SKX-NEXT:    movl {{[0-9]+}}(%esp), %eax
 ; X86-SKX-NEXT:    vpxor %xmm0, %xmm0, %xmm0
-; X86-SKX-NEXT:    vpscatterdd %xmm1, (%eax,%xmm0,4) {%k1}
+; X86-SKX-NEXT:    vpscatterdd %xmm1, (%eax,%xmm0) {%k1}
 ; X86-SKX-NEXT:    retl
   %1 = insertelement <4 x ptr> undef, ptr %ptr, i32 0
   %2 = shufflevector <4 x ptr> %1, <4 x ptr> undef, <4 x i32> zeroinitializer
diff --git a/llvm/test/Transforms/CodeGenPrepare/AArch64/gather-scatter-opt-inseltpoison.ll b/llvm/test/Transforms/CodeGenPrepare/AArch64/gather-scatter-opt-inseltpoison.ll
index 3c5c07f3516c9..6fd4d4a4c6be4 100644
--- a/llvm/test/Transforms/CodeGenPrepare/AArch64/gather-scatter-opt-inseltpoison.ll
+++ b/llvm/test/Transforms/CodeGenPrepare/AArch64/gather-scatter-opt-inseltpoison.ll
@@ -85,7 +85,8 @@ define <vscale x 4 x i32> @global_struct_splat(<vscale x 4 x i1> %mask) #0 {
 
 define <vscale x 4 x i32> @splat_ptr_gather(ptr %ptr, <vscale x 4 x i1> %mask, <vscale x 4 x i32> %passthru) #0 {
 ; CHECK-LABEL: @splat_ptr_gather(
-; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i32, ptr [[PTR:%.*]], <vscale x 4 x i64> zeroinitializer
+; CHECK-NEXT:    [[TMP3:%.*]] = insertelement <vscale x 4 x ptr> poison, ptr [[PTR:%.*]], i32 0
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <vscale x 4 x ptr> [[TMP3]], <vscale x 4 x ptr> poison, <vscale x 4 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP2:%.*]] = call <vscale x 4 x i32> @llvm.masked.gather.nxv4i32.nxv4p0(<vscale x 4 x ptr> [[TMP1]], i32 4, <vscale x 4 x i1> [[MASK:%.*]], <vscale x 4 x i32> [[PASSTHRU:%.*]])
 ; CHECK-NEXT:    ret <vscale x 4 x i32> [[TMP2]]
 ;
@@ -97,7 +98,8 @@ define <vscale x 4 x i32> @splat_ptr_gather(ptr %ptr, <vscale x 4 x i1> %mask, <
 
 define void @splat_ptr_scatter(ptr %ptr, <vscale x 4 x i1> %mask, <vscale x 4 x i32> %val) #0 {
 ; CHECK-LABEL: @splat_ptr_scatter(
-; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i32, ptr [[PTR:%.*]], <vscale x 4 x i64> zeroinitializer
+; CHECK-NEXT:    [[TMP2:%.*]] = insertelement <vscale x 4 x ptr> poison, ptr [[PTR:%.*]], i32 0
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <vscale x 4 x ptr> [[TMP2]], <vscale x 4 x ptr> poison, <vscale x 4 x i32> zeroinitializer
 ; CHECK-NEXT:    call void @llvm.masked.scatter.nxv4i32.nxv4p0(<vscale x 4 x i32> [[VAL:%.*]], <vscale x 4 x ptr> [[TMP1]], i32 4, <vscale x 4 x i1> [[MASK:%.*]])
 ; CHECK-NEXT:    ret void
 ;
diff --git a/llvm/test/Transforms/CodeGenPrepare/AArch64/gather-scatter-opt.ll b/llvm/test/Transforms/CodeGenPrepare/AArch64/gather-scatter-opt.ll
index 36cd69ed01ed9..d1843dcd23863 100644
--- a/llvm/test/Transforms/CodeGenPrepare/AArch64/gather-scatter-opt.ll
+++ b/llvm/test/Transforms/CodeGenPrepare/AArch64/gather-scatter-opt.ll
@@ -85,7 +85,8 @@ define <vscale x 4 x i32> @global_struct_splat(<vscale x 4 x i1> %mask) #0 {
 
 define <vscale x 4 x i32> @splat_ptr_gather(ptr %ptr, <vscale x 4 x i1> %mask, <vscale x 4 x i32> %passthru) #0 {
 ; CHECK-LABEL: @splat_ptr_gather(
-; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i32, ptr [[PTR:%.*]], <vscale x 4 x i64> zeroinitializer
+; CHECK-NEXT:    [[TMP3:%.*]] = insertelement <vscale x 4 x ptr> undef, ptr [[PTR:%.*]], i32 0
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <vscale x 4 x ptr> [[TMP3]], <vscale x 4 x ptr> undef, <vscale x 4 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP2:%.*]] = call <vscale x 4 x i32> @llvm.masked.gather.nxv4i32.nxv4p0(<vscale x 4 x ptr> [[TMP1]], i32 4, <vscale x 4 x i1> [[MASK:%.*]], <vscale x 4 x i32> [[PASSTHRU:%.*]])
 ; CHECK-NEXT:    ret <vscale x 4 x i32> [[TMP2]]
 ;
@@ -97,7 +98,8 @@ define <vscale x 4 x i32> @splat_ptr_gather(ptr %ptr, <vscale x 4 x i1> %mask, <
 
 define void @splat_ptr_scatter(ptr %ptr, <vscale x 4 x i1> %mask, <vscale x 4 x i32> %val) #0 {
 ; CHECK-LABEL: @splat_ptr_scatter(
-; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i32, ptr [[PTR:%.*]], <vscale x 4 x i64> zeroinitializer
+; CHECK-NEXT:    [[TMP2:%.*]] = insertelement <vscale x 4 x ptr> undef, ptr [[PTR:%.*]], i32 0
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <vscale x 4 x ptr> [[TMP2]], <vscale x 4 x ptr> undef, <vscale x 4 x i32> zeroinitializer
 ; CHECK-NEXT:    call void @llvm.masked.scatter.nxv4i32.nxv4p0(<vscale x 4 x i32> [[VAL:%.*]], <vscale x 4 x ptr> [[TMP1]], i32 4, <vscale x 4 x i1> [[MASK:%.*]])
 ; CHECK-NEXT:    ret void
 ;
diff --git a/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt-inseltpoison.ll b/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt-inseltpoison.ll
index 6ef3400812fc8..810a9d7b228c0 100644
--- a/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt-inseltpoison.ll
+++ b/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt-inseltpoison.ll
@@ -88,7 +88,8 @@ define <4 x i32> @global_struct_splat() {
 
 define <4 x i32> @splat_ptr_gather(ptr %ptr, <4 x i1> %mask, <4 x i32> %passthru) {
 ; CHECK-LABEL: @splat_ptr_gather(
-; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i32, ptr [[PTR:%.*]], <4 x i64> zeroinitializer
+; CHECK-NEXT:    [[TMP3:%.*]] = insertelement <4 x ptr> poison, ptr [[PTR:%.*]], i32 0
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x ptr> [[TMP3]], <4 x ptr> poison, <4 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP2:%.*]] = call <4 x i32> @llvm.masked.gather.v4i32.v4p0(<4 x ptr> [[TMP1]], i32 4, <4 x i1> [[MASK:%.*]], <4 x i32> [[PASSTHRU:%.*]])
 ; CHECK-NEXT:    ret <4 x i32> [[TMP2]]
 ;
@@ -100,7 +101,8 @@ define <4 x i32> @splat_ptr_gather(ptr %ptr, <4 x i1> %mask, <4 x i32> %passthru
 
 define void @splat_ptr_scatter(ptr %ptr, <4 x i1> %mask, <4 x i32> %val) {
 ; CHECK-LABEL: @splat_ptr_scatter(
-; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i32, ptr [[PTR:%.*]], <4 x i64> zeroinitializer
+; CHECK-NEXT:    [[TMP2:%.*]] = insertelement <4 x ptr> poison, ptr [[PTR:%.*]], i32 0
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x ptr> [[TMP2]], <4 x ptr> poison, <4 x i32> zeroinitializer
 ; CHECK-NEXT:    call void @llvm.masked.scatter.v4i32.v4p0(<4 x i32> [[VAL:%.*]], <4 x ptr> [[TMP1]], i32 4, <4 x i1> [[MASK:%.*]])
 ; CHECK-NEXT:    ret void
 ;
diff --git a/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt.ll b/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt.ll
index 8328708393029..cada231707171 100644
--- a/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt.ll
+++ b/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt.ll
@@ -87,7 +87,8 @@ define <4 x i32> @global_struct_splat() {
 
 define <4 x i32> @splat_ptr_gather(ptr %ptr, <4 x i1> %mask, <4 x i32> %passthru) {
 ; CHECK-LABEL: @splat_ptr_gather(
-; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i32, ptr [[PTR:%.*]], <4 x i64> zeroinitializer
+; CHECK-NEXT:    [[TMP3:%.*]] = insertelement <4 x ptr> undef, ptr [[PTR:%.*]], i32 0
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x ptr> [[TMP3]], <4 x ptr> undef, <4 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP2:%.*]] = call <4 x i32> @llvm.masked.gather.v4i32.v4p0(<4 x ptr> [[TMP1]], i32 4, <4 x i1> [[MASK:%.*]], <4 x i32> [[PASSTHRU:%.*]])
 ; CHECK-NEXT:    ret <4 x i32> [[TMP2]]
 ;
@@ -99,7 +100,8 @@ define <4 x i32> @splat_ptr_gather(ptr %ptr, <4 x i1> %mask, <4 x i32> %passthru
 
 define void @splat_ptr_scatter(ptr %ptr, <4 x i1> %mask, <4 x i32> %val) {
 ; CHECK-LABEL: @splat_ptr_scatter(
-; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i32, ptr [[PTR:%.*]], <4 x i64> zeroinitializer
+; CHECK-NEXT:    [[TMP2:%.*]] = insertelement <4 x ptr> undef, ptr [[PTR:%.*]], i32 0
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x ptr> [[TMP2]], <4 x ptr> undef, <4 x i32> zeroinitializer
 ; CHECK-NEXT:    call void @llvm.masked.scatter.v4i32.v4p0(<4 x i32> [[VAL:%.*]], <4 x ptr> [[TMP1]], i32 4, <4 x i1> [[MASK:%.*]])
 ; CHECK-NEXT:    ret void
 ;

@llvmbot (Member) commented Jun 26, 2025

@llvm/pr-subscribers-llvm-selectiondag

 ; X86-SKX-NEXT:    movl {{[0-9]+}}(%esp), %eax
 ; X86-SKX-NEXT:    vpxor %xmm0, %xmm0, %xmm0
-; X86-SKX-NEXT:    vpscatterdd %xmm1, (%eax,%xmm0,4) {%k1}
+; X86-SKX-NEXT:    vpscatterdd %xmm1, (%eax,%xmm0) {%k1}
 ; X86-SKX-NEXT:    retl
   %1 = insertelement <4 x ptr> undef, ptr %ptr, i32 0
   %2 = shufflevector <4 x ptr> %1, <4 x ptr> undef, <4 x i32> zeroinitializer
diff --git a/llvm/test/Transforms/CodeGenPrepare/AArch64/gather-scatter-opt-inseltpoison.ll b/llvm/test/Transforms/CodeGenPrepare/AArch64/gather-scatter-opt-inseltpoison.ll
index 3c5c07f3516c9..6fd4d4a4c6be4 100644
--- a/llvm/test/Transforms/CodeGenPrepare/AArch64/gather-scatter-opt-inseltpoison.ll
+++ b/llvm/test/Transforms/CodeGenPrepare/AArch64/gather-scatter-opt-inseltpoison.ll
@@ -85,7 +85,8 @@ define <vscale x 4 x i32> @global_struct_splat(<vscale x 4 x i1> %mask) #0 {
 
 define <vscale x 4 x i32> @splat_ptr_gather(ptr %ptr, <vscale x 4 x i1> %mask, <vscale x 4 x i32> %passthru) #0 {
 ; CHECK-LABEL: @splat_ptr_gather(
-; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i32, ptr [[PTR:%.*]], <vscale x 4 x i64> zeroinitializer
+; CHECK-NEXT:    [[TMP3:%.*]] = insertelement <vscale x 4 x ptr> poison, ptr [[PTR:%.*]], i32 0
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <vscale x 4 x ptr> [[TMP3]], <vscale x 4 x ptr> poison, <vscale x 4 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP2:%.*]] = call <vscale x 4 x i32> @llvm.masked.gather.nxv4i32.nxv4p0(<vscale x 4 x ptr> [[TMP1]], i32 4, <vscale x 4 x i1> [[MASK:%.*]], <vscale x 4 x i32> [[PASSTHRU:%.*]])
 ; CHECK-NEXT:    ret <vscale x 4 x i32> [[TMP2]]
 ;
@@ -97,7 +98,8 @@ define <vscale x 4 x i32> @splat_ptr_gather(ptr %ptr, <vscale x 4 x i1> %mask, <
 
 define void @splat_ptr_scatter(ptr %ptr, <vscale x 4 x i1> %mask, <vscale x 4 x i32> %val) #0 {
 ; CHECK-LABEL: @splat_ptr_scatter(
-; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i32, ptr [[PTR:%.*]], <vscale x 4 x i64> zeroinitializer
+; CHECK-NEXT:    [[TMP2:%.*]] = insertelement <vscale x 4 x ptr> poison, ptr [[PTR:%.*]], i32 0
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <vscale x 4 x ptr> [[TMP2]], <vscale x 4 x ptr> poison, <vscale x 4 x i32> zeroinitializer
 ; CHECK-NEXT:    call void @llvm.masked.scatter.nxv4i32.nxv4p0(<vscale x 4 x i32> [[VAL:%.*]], <vscale x 4 x ptr> [[TMP1]], i32 4, <vscale x 4 x i1> [[MASK:%.*]])
 ; CHECK-NEXT:    ret void
 ;
diff --git a/llvm/test/Transforms/CodeGenPrepare/AArch64/gather-scatter-opt.ll b/llvm/test/Transforms/CodeGenPrepare/AArch64/gather-scatter-opt.ll
index 36cd69ed01ed9..d1843dcd23863 100644
--- a/llvm/test/Transforms/CodeGenPrepare/AArch64/gather-scatter-opt.ll
+++ b/llvm/test/Transforms/CodeGenPrepare/AArch64/gather-scatter-opt.ll
@@ -85,7 +85,8 @@ define <vscale x 4 x i32> @global_struct_splat(<vscale x 4 x i1> %mask) #0 {
 
 define <vscale x 4 x i32> @splat_ptr_gather(ptr %ptr, <vscale x 4 x i1> %mask, <vscale x 4 x i32> %passthru) #0 {
 ; CHECK-LABEL: @splat_ptr_gather(
-; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i32, ptr [[PTR:%.*]], <vscale x 4 x i64> zeroinitializer
+; CHECK-NEXT:    [[TMP3:%.*]] = insertelement <vscale x 4 x ptr> undef, ptr [[PTR:%.*]], i32 0
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <vscale x 4 x ptr> [[TMP3]], <vscale x 4 x ptr> undef, <vscale x 4 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP2:%.*]] = call <vscale x 4 x i32> @llvm.masked.gather.nxv4i32.nxv4p0(<vscale x 4 x ptr> [[TMP1]], i32 4, <vscale x 4 x i1> [[MASK:%.*]], <vscale x 4 x i32> [[PASSTHRU:%.*]])
 ; CHECK-NEXT:    ret <vscale x 4 x i32> [[TMP2]]
 ;
@@ -97,7 +98,8 @@ define <vscale x 4 x i32> @splat_ptr_gather(ptr %ptr, <vscale x 4 x i1> %mask, <
 
 define void @splat_ptr_scatter(ptr %ptr, <vscale x 4 x i1> %mask, <vscale x 4 x i32> %val) #0 {
 ; CHECK-LABEL: @splat_ptr_scatter(
-; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i32, ptr [[PTR:%.*]], <vscale x 4 x i64> zeroinitializer
+; CHECK-NEXT:    [[TMP2:%.*]] = insertelement <vscale x 4 x ptr> undef, ptr [[PTR:%.*]], i32 0
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <vscale x 4 x ptr> [[TMP2]], <vscale x 4 x ptr> undef, <vscale x 4 x i32> zeroinitializer
 ; CHECK-NEXT:    call void @llvm.masked.scatter.nxv4i32.nxv4p0(<vscale x 4 x i32> [[VAL:%.*]], <vscale x 4 x ptr> [[TMP1]], i32 4, <vscale x 4 x i1> [[MASK:%.*]])
 ; CHECK-NEXT:    ret void
 ;
diff --git a/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt-inseltpoison.ll b/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt-inseltpoison.ll
index 6ef3400812fc8..810a9d7b228c0 100644
--- a/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt-inseltpoison.ll
+++ b/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt-inseltpoison.ll
@@ -88,7 +88,8 @@ define <4 x i32> @global_struct_splat() {
 
 define <4 x i32> @splat_ptr_gather(ptr %ptr, <4 x i1> %mask, <4 x i32> %passthru) {
 ; CHECK-LABEL: @splat_ptr_gather(
-; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i32, ptr [[PTR:%.*]], <4 x i64> zeroinitializer
+; CHECK-NEXT:    [[TMP3:%.*]] = insertelement <4 x ptr> poison, ptr [[PTR:%.*]], i32 0
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x ptr> [[TMP3]], <4 x ptr> poison, <4 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP2:%.*]] = call <4 x i32> @llvm.masked.gather.v4i32.v4p0(<4 x ptr> [[TMP1]], i32 4, <4 x i1> [[MASK:%.*]], <4 x i32> [[PASSTHRU:%.*]])
 ; CHECK-NEXT:    ret <4 x i32> [[TMP2]]
 ;
@@ -100,7 +101,8 @@ define <4 x i32> @splat_ptr_gather(ptr %ptr, <4 x i1> %mask, <4 x i32> %passthru
 
 define void @splat_ptr_scatter(ptr %ptr, <4 x i1> %mask, <4 x i32> %val) {
 ; CHECK-LABEL: @splat_ptr_scatter(
-; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i32, ptr [[PTR:%.*]], <4 x i64> zeroinitializer
+; CHECK-NEXT:    [[TMP2:%.*]] = insertelement <4 x ptr> poison, ptr [[PTR:%.*]], i32 0
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x ptr> [[TMP2]], <4 x ptr> poison, <4 x i32> zeroinitializer
 ; CHECK-NEXT:    call void @llvm.masked.scatter.v4i32.v4p0(<4 x i32> [[VAL:%.*]], <4 x ptr> [[TMP1]], i32 4, <4 x i1> [[MASK:%.*]])
 ; CHECK-NEXT:    ret void
 ;
diff --git a/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt.ll b/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt.ll
index 8328708393029..cada231707171 100644
--- a/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt.ll
+++ b/llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt.ll
@@ -87,7 +87,8 @@ define <4 x i32> @global_struct_splat() {
 
 define <4 x i32> @splat_ptr_gather(ptr %ptr, <4 x i1> %mask, <4 x i32> %passthru) {
 ; CHECK-LABEL: @splat_ptr_gather(
-; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i32, ptr [[PTR:%.*]], <4 x i64> zeroinitializer
+; CHECK-NEXT:    [[TMP3:%.*]] = insertelement <4 x ptr> undef, ptr [[PTR:%.*]], i32 0
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x ptr> [[TMP3]], <4 x ptr> undef, <4 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP2:%.*]] = call <4 x i32> @llvm.masked.gather.v4i32.v4p0(<4 x ptr> [[TMP1]], i32 4, <4 x i1> [[MASK:%.*]], <4 x i32> [[PASSTHRU:%.*]])
 ; CHECK-NEXT:    ret <4 x i32> [[TMP2]]
 ;
@@ -99,7 +100,8 @@ define <4 x i32> @splat_ptr_gather(ptr %ptr, <4 x i1> %mask, <4 x i32> %passthru
 
 define void @splat_ptr_scatter(ptr %ptr, <4 x i1> %mask, <4 x i32> %val) {
 ; CHECK-LABEL: @splat_ptr_scatter(
-; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i32, ptr [[PTR:%.*]], <4 x i64> zeroinitializer
+; CHECK-NEXT:    [[TMP2:%.*]] = insertelement <4 x ptr> undef, ptr [[PTR:%.*]], i32 0
+; CHECK-NEXT:    [[TMP1:%.*]] = shufflevector <4 x ptr> [[TMP2]], <4 x ptr> undef, <4 x i32> zeroinitializer
 ; CHECK-NEXT:    call void @llvm.masked.scatter.v4i32.v4p0(<4 x i32> [[VAL:%.*]], <4 x ptr> [[TMP1]], i32 4, <4 x i1> [[MASK:%.*]])
 ; CHECK-NEXT:    ret void
 ;

@github-actions

⚠️ undef deprecator found issues in your code. ⚠️

You can test this locally with the following command:
git diff -U0 --pickaxe-regex -S '([^a-zA-Z0-9#_-]undef[^a-zA-Z0-9_-]|UndefValue::get)' 'HEAD~1' HEAD llvm/lib/CodeGen/CodeGenPrepare.cpp llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp llvm/test/CodeGen/X86/masked_gather_scatter.ll llvm/test/Transforms/CodeGenPrepare/AArch64/gather-scatter-opt-inseltpoison.ll llvm/test/Transforms/CodeGenPrepare/AArch64/gather-scatter-opt.ll llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt-inseltpoison.ll llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt.ll

The following files introduce new uses of undef:

  • llvm/test/Transforms/CodeGenPrepare/AArch64/gather-scatter-opt.ll
  • llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt.ll

Undef is now deprecated and should only be used in the rare cases where no replacement is possible. For example, a load of uninitialized memory yields undef. You should use poison values for placeholders instead.

In tests, avoid using undef and having tests that trigger undefined behavior. If you need an operand with some unimportant value, you can add a new argument to the function and use that instead.

For example, this is considered a bad practice:

define void @fn() {
  ...
  br i1 undef, ...
}

Please use the following instead:

define void @fn(i1 %cond) {
  ...
  br i1 %cond, ...
}

Please refer to the Undefined Behavior Manual for more information.
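For the splat idiom this patch emits, the poison-based form already used by the `-inseltpoison` test variants looks like the following (sketch only; value names are illustrative):

```llvm
; Build a splat of %ptr using poison placeholders instead of undef.
%ins   = insertelement <4 x ptr> poison, ptr %ptr, i32 0
%splat = shufflevector <4 x ptr> %ins, <4 x ptr> poison, <4 x i32> zeroinitializer
```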

Base = SDB->getValue(C);

// Handle splat (possibly constant) pointer.
if (Value *ScalarV = getSplatValue(Ptr)) {

@topperc topperc Jun 26, 2025


Does this work if the splat instruction isn't in the same basic block? You have to be careful looking through instructions in SelectionDAGBuilder, because the Value won't be exported from the producing basic block.

There used to be code to find the export that I removed in 944cc5e
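A hypothetical IR sketch of the hazard (function and value names are illustrative): the splat's scalar source is an instruction defined in `%entry`, but the gather is lowered in `%body`, so calling `SDB->getValue` on the scalar there would not find an exported value:

```llvm
define <4 x i32> @cross_bb(ptr %base, i1 %c, <4 x i1> %mask, <4 x i32> %pt) {
entry:
  %p = getelementptr i32, ptr %base, i64 4    ; scalar splat source defined here
  %ins = insertelement <4 x ptr> poison, ptr %p, i32 0
  %splat = shufflevector <4 x ptr> %ins, <4 x ptr> poison, <4 x i32> zeroinitializer
  br i1 %c, label %body, label %exit
body:                                         ; gather lowered in a different block
  %g = call <4 x i32> @llvm.masked.gather.v4i32.v4p0(<4 x ptr> %splat, i32 4, <4 x i1> %mask, <4 x i32> %pt)
  ret <4 x i32> %g
exit:
  ret <4 x i32> %pt
}
```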

@preames (Collaborator, Author)

Oh yuck, I should have checked the history here more closely. No, it doesn't.

I'm going to close this and revisit from scratch. I can probably find an analogous approach, but TBD.

@preames preames closed this Jun 26, 2025