Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 76 additions & 102 deletions llvm/lib/CodeGen/CodeGenPrepare.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6230,7 +6230,7 @@ bool CodeGenPrepare::optimizeMemoryInst(Instruction *MemoryInst, Value *Addr,

/// Rewrite GEP input to gather/scatter to enable SelectionDAGBuilder to find
/// a uniform base to use for ISD::MGATHER/MSCATTER. SelectionDAGBuilder can
/// only handle a 2 operand GEP in the same basic block or a splat constant
/// only handle a 2 operand GEP in the same basic block or a canonical splat
/// vector. The 2 operands to the GEP must have a scalar pointer and a vector
/// index.
///
Expand All @@ -6247,124 +6247,98 @@ bool CodeGenPrepare::optimizeMemoryInst(Instruction *MemoryInst, Value *Addr,
/// zero index.
bool CodeGenPrepare::optimizeGatherScatterInst(Instruction *MemoryInst,
Value *Ptr) {
Value *NewAddr;
const auto *GEP = dyn_cast<GetElementPtrInst>(Ptr);
if (!GEP)
return false;

if (const auto *GEP = dyn_cast<GetElementPtrInst>(Ptr)) {
// Don't optimize GEPs that don't have indices.
if (!GEP->hasIndices())
return false;
// Don't optimize GEPs that don't have indices.
if (!GEP->hasIndices())
return false;

// If the GEP and the gather/scatter aren't in the same BB, don't optimize.
// FIXME: We should support this by sinking the GEP.
if (MemoryInst->getParent() != GEP->getParent())
return false;
// If the GEP and the gather/scatter aren't in the same BB, don't optimize.
// FIXME: We should support this by sinking the GEP.
if (MemoryInst->getParent() != GEP->getParent())
return false;

SmallVector<Value *, 2> Ops(GEP->operands());
SmallVector<Value *, 2> Ops(GEP->operands());

bool RewriteGEP = false;
bool RewriteGEP = false;

if (Ops[0]->getType()->isVectorTy()) {
Ops[0] = getSplatValue(Ops[0]);
if (!Ops[0])
return false;
RewriteGEP = true;
}
if (Ops[0]->getType()->isVectorTy()) {
Ops[0] = getSplatValue(Ops[0]);
if (!Ops[0])
return false;
RewriteGEP = true;
}

unsigned FinalIndex = Ops.size() - 1;
unsigned FinalIndex = Ops.size() - 1;

// Ensure all but the last index is 0.
// FIXME: This isn't strictly required. All that's required is that they are
// all scalars or splats.
for (unsigned i = 1; i < FinalIndex; ++i) {
auto *C = dyn_cast<Constant>(Ops[i]);
if (!C)
return false;
if (isa<VectorType>(C->getType()))
C = C->getSplatValue();
auto *CI = dyn_cast_or_null<ConstantInt>(C);
if (!CI || !CI->isZero())
return false;
// Scalarize the index if needed.
Ops[i] = CI;
}

// Try to scalarize the final index.
if (Ops[FinalIndex]->getType()->isVectorTy()) {
if (Value *V = getSplatValue(Ops[FinalIndex])) {
auto *C = dyn_cast<ConstantInt>(V);
// Don't scalarize all zeros vector.
if (!C || !C->isZero()) {
Ops[FinalIndex] = V;
RewriteGEP = true;
}
// Ensure all but the last index is 0.
// FIXME: This isn't strictly required. All that's required is that they are
// all scalars or splats.
for (unsigned i = 1; i < FinalIndex; ++i) {
auto *C = dyn_cast<Constant>(Ops[i]);
if (!C)
return false;
if (isa<VectorType>(C->getType()))
C = C->getSplatValue();
auto *CI = dyn_cast_or_null<ConstantInt>(C);
if (!CI || !CI->isZero())
return false;
// Scalarize the index if needed.
Ops[i] = CI;
}

// Try to scalarize the final index.
if (Ops[FinalIndex]->getType()->isVectorTy()) {
if (Value *V = getSplatValue(Ops[FinalIndex])) {
auto *C = dyn_cast<ConstantInt>(V);
// Don't scalarize all zeros vector.
if (!C || !C->isZero()) {
Ops[FinalIndex] = V;
RewriteGEP = true;
}
}
}

// If we made any changes or the we have extra operands, we need to generate
// new instructions.
if (!RewriteGEP && Ops.size() == 2)
return false;
// If we made any changes or the we have extra operands, we need to generate
// new instructions.
if (!RewriteGEP && Ops.size() == 2)
return false;

auto NumElts = cast<VectorType>(Ptr->getType())->getElementCount();
auto NumElts = cast<VectorType>(Ptr->getType())->getElementCount();

IRBuilder<> Builder(MemoryInst);
IRBuilder<> Builder(MemoryInst);

Type *SourceTy = GEP->getSourceElementType();
Type *ScalarIndexTy = DL->getIndexType(Ops[0]->getType()->getScalarType());
Type *SourceTy = GEP->getSourceElementType();
Type *ScalarIndexTy = DL->getIndexType(Ops[0]->getType()->getScalarType());

// If the final index isn't a vector, emit a scalar GEP containing all ops
// and a vector GEP with all zeroes final index.
if (!Ops[FinalIndex]->getType()->isVectorTy()) {
NewAddr = Builder.CreateGEP(SourceTy, Ops[0], ArrayRef(Ops).drop_front());
auto *IndexTy = VectorType::get(ScalarIndexTy, NumElts);
auto *SecondTy = GetElementPtrInst::getIndexedType(
SourceTy, ArrayRef(Ops).drop_front());
NewAddr =
Builder.CreateGEP(SecondTy, NewAddr, Constant::getNullValue(IndexTy));
} else {
Value *Base = Ops[0];
Value *Index = Ops[FinalIndex];

// Create a scalar GEP if there are more than 2 operands.
if (Ops.size() != 2) {
// Replace the last index with 0.
Ops[FinalIndex] =
Constant::getNullValue(Ops[FinalIndex]->getType()->getScalarType());
Base = Builder.CreateGEP(SourceTy, Base, ArrayRef(Ops).drop_front());
SourceTy = GetElementPtrInst::getIndexedType(
SourceTy, ArrayRef(Ops).drop_front());
}
// If the final index isn't a vector, emit a scalar GEP containing all ops
// and a vector GEP with all zeroes final index.
Value *NewAddr;
if (!Ops[FinalIndex]->getType()->isVectorTy()) {
NewAddr = Builder.CreateGEP(SourceTy, Ops[0], ArrayRef(Ops).drop_front());
auto *IndexTy = VectorType::get(ScalarIndexTy, NumElts);
auto *SecondTy =
GetElementPtrInst::getIndexedType(SourceTy, ArrayRef(Ops).drop_front());
NewAddr =
Builder.CreateGEP(SecondTy, NewAddr, Constant::getNullValue(IndexTy));
} else {
Value *Base = Ops[0];
Value *Index = Ops[FinalIndex];

// Now create the GEP with scalar pointer and vector index.
NewAddr = Builder.CreateGEP(SourceTy, Base, Index);
// Create a scalar GEP if there are more than 2 operands.
if (Ops.size() != 2) {
// Replace the last index with 0.
Ops[FinalIndex] =
Constant::getNullValue(Ops[FinalIndex]->getType()->getScalarType());
Base = Builder.CreateGEP(SourceTy, Base, ArrayRef(Ops).drop_front());
SourceTy = GetElementPtrInst::getIndexedType(SourceTy,
ArrayRef(Ops).drop_front());
}
} else if (!isa<Constant>(Ptr)) {
// Not a GEP, maybe its a splat and we can create a GEP to enable
// SelectionDAGBuilder to use it as a uniform base.
Value *V = getSplatValue(Ptr);
if (!V)
return false;

auto NumElts = cast<VectorType>(Ptr->getType())->getElementCount();

IRBuilder<> Builder(MemoryInst);

// Emit a vector GEP with a scalar pointer and all 0s vector index.
Type *ScalarIndexTy = DL->getIndexType(V->getType()->getScalarType());
auto *IndexTy = VectorType::get(ScalarIndexTy, NumElts);
Type *ScalarTy;
if (cast<IntrinsicInst>(MemoryInst)->getIntrinsicID() ==
Intrinsic::masked_gather) {
ScalarTy = MemoryInst->getType()->getScalarType();
} else {
assert(cast<IntrinsicInst>(MemoryInst)->getIntrinsicID() ==
Intrinsic::masked_scatter);
ScalarTy = MemoryInst->getOperand(0)->getType()->getScalarType();
}
NewAddr = Builder.CreateGEP(ScalarTy, V, Constant::getNullValue(IndexTy));
} else {
// Constant, SelectionDAGBuilder knows to check if its a splat.
return false;
// Now create the GEP with scalar pointer and vector index.
NewAddr = Builder.CreateGEP(SourceTy, Base, Index);
}

MemoryInst->replaceUsesOfWith(Ptr, NewAddr);
Expand Down
11 changes: 3 additions & 8 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4888,14 +4888,9 @@ static bool getUniformBase(const Value *Ptr, SDValue &Base, SDValue &Index,

assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type");

// Handle splat constant pointer.
if (auto *C = dyn_cast<Constant>(Ptr)) {
C = C->getSplatValue();
if (!C)
return false;

Base = SDB->getValue(C);

// Handle splat (possibly constant) pointer.
if (Value *ScalarV = getSplatValue(Ptr)) {
Copy link
Collaborator

@topperc topperc Jun 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this work if the splat instruction isn't in the same basic block? You have to be careful looking through instructions in SelectinDAGBuilder because the Value won't be exported from the producing basic block.

There used to be code to find the export that I removed in 944cc5e

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yuck, I should have checked the history here more closely. No, it doesn't.

I'm going to close this and revisit from scratch. I can probably find an analogous approach, but TBD.

Base = SDB->getValue(ScalarV);
ElementCount NumElts = cast<VectorType>(Ptr->getType())->getElementCount();
EVT VT = EVT::getVectorVT(*DAG.getContext(), TLI.getPointerTy(DL), NumElts);
Index = DAG.getConstant(0, SDB->getCurSDLoc(), VT);
Expand Down
8 changes: 4 additions & 4 deletions llvm/test/CodeGen/X86/masked_gather_scatter.ll
Original file line number Diff line number Diff line change
Expand Up @@ -4475,7 +4475,7 @@ define <4 x i32> @splat_ptr_gather(ptr %ptr, <4 x i1> %mask, <4 x i32> %passthru
; X64-SKX-NEXT: vpslld $31, %xmm0, %xmm0
; X64-SKX-NEXT: vpmovd2m %xmm0, %k1
; X64-SKX-NEXT: vpxor %xmm0, %xmm0, %xmm0
; X64-SKX-NEXT: vpgatherdd (%rdi,%xmm0,4), %xmm1 {%k1}
; X64-SKX-NEXT: vpgatherdd (%rdi,%xmm0), %xmm1 {%k1}
; X64-SKX-NEXT: vmovdqa %xmm1, %xmm0
; X64-SKX-NEXT: retq
;
Expand All @@ -4485,7 +4485,7 @@ define <4 x i32> @splat_ptr_gather(ptr %ptr, <4 x i1> %mask, <4 x i32> %passthru
; X86-SKX-NEXT: vpmovd2m %xmm0, %k1
; X86-SKX-NEXT: movl {{[0-9]+}}(%esp), %eax
; X86-SKX-NEXT: vpxor %xmm0, %xmm0, %xmm0
; X86-SKX-NEXT: vpgatherdd (%eax,%xmm0,4), %xmm1 {%k1}
; X86-SKX-NEXT: vpgatherdd (%eax,%xmm0), %xmm1 {%k1}
; X86-SKX-NEXT: vmovdqa %xmm1, %xmm0
; X86-SKX-NEXT: retl
%1 = insertelement <4 x ptr> undef, ptr %ptr, i32 0
Expand Down Expand Up @@ -4581,7 +4581,7 @@ define void @splat_ptr_scatter(ptr %ptr, <4 x i1> %mask, <4 x i32> %val) {
; X64-SKX-NEXT: vpslld $31, %xmm0, %xmm0
; X64-SKX-NEXT: vpmovd2m %xmm0, %k1
; X64-SKX-NEXT: vpxor %xmm0, %xmm0, %xmm0
; X64-SKX-NEXT: vpscatterdd %xmm1, (%rdi,%xmm0,4) {%k1}
; X64-SKX-NEXT: vpscatterdd %xmm1, (%rdi,%xmm0) {%k1}
; X64-SKX-NEXT: retq
;
; X86-SKX-LABEL: splat_ptr_scatter:
Expand All @@ -4590,7 +4590,7 @@ define void @splat_ptr_scatter(ptr %ptr, <4 x i1> %mask, <4 x i32> %val) {
; X86-SKX-NEXT: vpmovd2m %xmm0, %k1
; X86-SKX-NEXT: movl {{[0-9]+}}(%esp), %eax
; X86-SKX-NEXT: vpxor %xmm0, %xmm0, %xmm0
; X86-SKX-NEXT: vpscatterdd %xmm1, (%eax,%xmm0,4) {%k1}
; X86-SKX-NEXT: vpscatterdd %xmm1, (%eax,%xmm0) {%k1}
; X86-SKX-NEXT: retl
%1 = insertelement <4 x ptr> undef, ptr %ptr, i32 0
%2 = shufflevector <4 x ptr> %1, <4 x ptr> undef, <4 x i32> zeroinitializer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ define <vscale x 4 x i32> @global_struct_splat(<vscale x 4 x i1> %mask) #0 {

define <vscale x 4 x i32> @splat_ptr_gather(ptr %ptr, <vscale x 4 x i1> %mask, <vscale x 4 x i32> %passthru) #0 {
; CHECK-LABEL: @splat_ptr_gather(
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i32, ptr [[PTR:%.*]], <vscale x 4 x i64> zeroinitializer
; CHECK-NEXT: [[TMP3:%.*]] = insertelement <vscale x 4 x ptr> poison, ptr [[PTR:%.*]], i32 0
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <vscale x 4 x ptr> [[TMP3]], <vscale x 4 x ptr> poison, <vscale x 4 x i32> zeroinitializer
; CHECK-NEXT: [[TMP2:%.*]] = call <vscale x 4 x i32> @llvm.masked.gather.nxv4i32.nxv4p0(<vscale x 4 x ptr> [[TMP1]], i32 4, <vscale x 4 x i1> [[MASK:%.*]], <vscale x 4 x i32> [[PASSTHRU:%.*]])
; CHECK-NEXT: ret <vscale x 4 x i32> [[TMP2]]
;
Expand All @@ -97,7 +98,8 @@ define <vscale x 4 x i32> @splat_ptr_gather(ptr %ptr, <vscale x 4 x i1> %mask, <

define void @splat_ptr_scatter(ptr %ptr, <vscale x 4 x i1> %mask, <vscale x 4 x i32> %val) #0 {
; CHECK-LABEL: @splat_ptr_scatter(
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i32, ptr [[PTR:%.*]], <vscale x 4 x i64> zeroinitializer
; CHECK-NEXT: [[TMP2:%.*]] = insertelement <vscale x 4 x ptr> poison, ptr [[PTR:%.*]], i32 0
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <vscale x 4 x ptr> [[TMP2]], <vscale x 4 x ptr> poison, <vscale x 4 x i32> zeroinitializer
; CHECK-NEXT: call void @llvm.masked.scatter.nxv4i32.nxv4p0(<vscale x 4 x i32> [[VAL:%.*]], <vscale x 4 x ptr> [[TMP1]], i32 4, <vscale x 4 x i1> [[MASK:%.*]])
; CHECK-NEXT: ret void
;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ define <vscale x 4 x i32> @global_struct_splat(<vscale x 4 x i1> %mask) #0 {

define <vscale x 4 x i32> @splat_ptr_gather(ptr %ptr, <vscale x 4 x i1> %mask, <vscale x 4 x i32> %passthru) #0 {
; CHECK-LABEL: @splat_ptr_gather(
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i32, ptr [[PTR:%.*]], <vscale x 4 x i64> zeroinitializer
; CHECK-NEXT: [[TMP3:%.*]] = insertelement <vscale x 4 x ptr> undef, ptr [[PTR:%.*]], i32 0
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <vscale x 4 x ptr> [[TMP3]], <vscale x 4 x ptr> undef, <vscale x 4 x i32> zeroinitializer
; CHECK-NEXT: [[TMP2:%.*]] = call <vscale x 4 x i32> @llvm.masked.gather.nxv4i32.nxv4p0(<vscale x 4 x ptr> [[TMP1]], i32 4, <vscale x 4 x i1> [[MASK:%.*]], <vscale x 4 x i32> [[PASSTHRU:%.*]])
; CHECK-NEXT: ret <vscale x 4 x i32> [[TMP2]]
;
Expand All @@ -97,7 +98,8 @@ define <vscale x 4 x i32> @splat_ptr_gather(ptr %ptr, <vscale x 4 x i1> %mask, <

define void @splat_ptr_scatter(ptr %ptr, <vscale x 4 x i1> %mask, <vscale x 4 x i32> %val) #0 {
; CHECK-LABEL: @splat_ptr_scatter(
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i32, ptr [[PTR:%.*]], <vscale x 4 x i64> zeroinitializer
; CHECK-NEXT: [[TMP2:%.*]] = insertelement <vscale x 4 x ptr> undef, ptr [[PTR:%.*]], i32 0
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <vscale x 4 x ptr> [[TMP2]], <vscale x 4 x ptr> undef, <vscale x 4 x i32> zeroinitializer
; CHECK-NEXT: call void @llvm.masked.scatter.nxv4i32.nxv4p0(<vscale x 4 x i32> [[VAL:%.*]], <vscale x 4 x ptr> [[TMP1]], i32 4, <vscale x 4 x i1> [[MASK:%.*]])
; CHECK-NEXT: ret void
;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ define <4 x i32> @global_struct_splat() {

define <4 x i32> @splat_ptr_gather(ptr %ptr, <4 x i1> %mask, <4 x i32> %passthru) {
; CHECK-LABEL: @splat_ptr_gather(
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i32, ptr [[PTR:%.*]], <4 x i64> zeroinitializer
; CHECK-NEXT: [[TMP3:%.*]] = insertelement <4 x ptr> poison, ptr [[PTR:%.*]], i32 0
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <4 x ptr> [[TMP3]], <4 x ptr> poison, <4 x i32> zeroinitializer
; CHECK-NEXT: [[TMP2:%.*]] = call <4 x i32> @llvm.masked.gather.v4i32.v4p0(<4 x ptr> [[TMP1]], i32 4, <4 x i1> [[MASK:%.*]], <4 x i32> [[PASSTHRU:%.*]])
; CHECK-NEXT: ret <4 x i32> [[TMP2]]
;
Expand All @@ -100,7 +101,8 @@ define <4 x i32> @splat_ptr_gather(ptr %ptr, <4 x i1> %mask, <4 x i32> %passthru

define void @splat_ptr_scatter(ptr %ptr, <4 x i1> %mask, <4 x i32> %val) {
; CHECK-LABEL: @splat_ptr_scatter(
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i32, ptr [[PTR:%.*]], <4 x i64> zeroinitializer
; CHECK-NEXT: [[TMP2:%.*]] = insertelement <4 x ptr> poison, ptr [[PTR:%.*]], i32 0
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <4 x ptr> [[TMP2]], <4 x ptr> poison, <4 x i32> zeroinitializer
; CHECK-NEXT: call void @llvm.masked.scatter.v4i32.v4p0(<4 x i32> [[VAL:%.*]], <4 x ptr> [[TMP1]], i32 4, <4 x i1> [[MASK:%.*]])
; CHECK-NEXT: ret void
;
Expand Down
6 changes: 4 additions & 2 deletions llvm/test/Transforms/CodeGenPrepare/X86/gather-scatter-opt.ll
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ define <4 x i32> @global_struct_splat() {

define <4 x i32> @splat_ptr_gather(ptr %ptr, <4 x i1> %mask, <4 x i32> %passthru) {
; CHECK-LABEL: @splat_ptr_gather(
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i32, ptr [[PTR:%.*]], <4 x i64> zeroinitializer
; CHECK-NEXT: [[TMP3:%.*]] = insertelement <4 x ptr> undef, ptr [[PTR:%.*]], i32 0
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <4 x ptr> [[TMP3]], <4 x ptr> undef, <4 x i32> zeroinitializer
; CHECK-NEXT: [[TMP2:%.*]] = call <4 x i32> @llvm.masked.gather.v4i32.v4p0(<4 x ptr> [[TMP1]], i32 4, <4 x i1> [[MASK:%.*]], <4 x i32> [[PASSTHRU:%.*]])
; CHECK-NEXT: ret <4 x i32> [[TMP2]]
;
Expand All @@ -99,7 +100,8 @@ define <4 x i32> @splat_ptr_gather(ptr %ptr, <4 x i1> %mask, <4 x i32> %passthru

define void @splat_ptr_scatter(ptr %ptr, <4 x i1> %mask, <4 x i32> %val) {
; CHECK-LABEL: @splat_ptr_scatter(
; CHECK-NEXT: [[TMP1:%.*]] = getelementptr i32, ptr [[PTR:%.*]], <4 x i64> zeroinitializer
; CHECK-NEXT: [[TMP2:%.*]] = insertelement <4 x ptr> undef, ptr [[PTR:%.*]], i32 0
; CHECK-NEXT: [[TMP1:%.*]] = shufflevector <4 x ptr> [[TMP2]], <4 x ptr> undef, <4 x i32> zeroinitializer
; CHECK-NEXT: call void @llvm.masked.scatter.v4i32.v4p0(<4 x i32> [[VAL:%.*]], <4 x ptr> [[TMP1]], i32 4, <4 x i1> [[MASK:%.*]])
; CHECK-NEXT: ret void
;
Expand Down
Loading