Skip to content

Commit 7c1b948

Browse files
authored
[SPIR-V] Fix some GEP legalization (llvm#150943)
Pointers and GEPs are untyped. SPIR-V requires a structured OpAccessChain. This means the backend has to determine a good way to retrieve the structured access from an untyped GEP. This is not a trivial problem, and it needs to be addressed to have a robust compiler. The issue is that other workstreams rely on the access chain deduction to work. So we have 2 options: - pause all dependent work until we have a good chain deduction. - submit this limited fix so we can work on both this and other features in parallel. The choice we want to make is #2: submitting this **knowing this is not a good** fix. It only increases the number of patterns we can work with, thus allowing others to continue working on other parts of the backend. This patch as-is has many limitations: - It cannot robustly determine the depth of the structured access from a GEP. Fixing this would require looking ahead at the full GEP chain. - It cannot always figure out the correct access indices, especially with dynamic indices. This will require frontend collaboration. Because we know this is a temporary hack, this patch only impacts the logical SPIR-V target. Physical SPIR-V, which can rely on pointer casts, remains on the old method. Related to llvm#145002
1 parent 09dc08b commit 7c1b948

File tree

6 files changed

+356
-5
lines changed

6 files changed

+356
-5
lines changed

llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp

Lines changed: 177 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,42 @@ class SPIRVEmitIntrinsics
194194

195195
void useRoundingMode(ConstrainedFPIntrinsic *FPI, IRBuilder<> &B);
196196

197+
// Tries to walk the type accessed by the given GEP instruction.
198+
// For each nested type access, one of the 2 callbacks is called:
199+
// - OnLiteralIndexing when the index is a known constant value.
200+
// Parameters:
201+
// PointedType: the type pointed to as a result of this indexing.
202+
// Index: index of the element in the parent type.
203+
// If the parent type is an array, this is the index in the array.
204+
// If the parent type is a struct, this is the field index.
205+
// - OnDynamicIndexing when the index is a non-constant value.
206+
// This callback is only called when indexing into an array.
207+
// Parameters:
208+
// ElementType: the type of the elements stored in the parent array.
209+
// Offset: the Value* containing the byte offset into the array.
210+
// Returns true if an error occurred during the walk, false otherwise.
211+
bool walkLogicalAccessChain(
212+
GetElementPtrInst &GEP,
213+
const std::function<void(Type *PointedType, uint64_t Index)>
214+
&OnLiteralIndexing,
215+
const std::function<void(Type *ElementType, Value *Offset)>
216+
&OnDynamicIndexing);
217+
218+
// Returns the type accessed using the given GEP instruction by relying
219+
// on the GEP type.
220+
// FIXME: GEP types are not supposed to be used to retrieve the pointed
221+
// type. This must be fixed.
222+
Type *getGEPType(GetElementPtrInst *GEP);
223+
224+
// Returns the type accessed using the given GEP instruction by walking
225+
// the source type using the GEP indices.
226+
// FIXME: without help from the frontend, this method cannot reliably retrieve
227+
// the stored type, nor can robustly determine the depth of the type
228+
// we are accessing.
229+
Type *getGEPTypeLogical(GetElementPtrInst *GEP);
230+
231+
Instruction *buildLogicalAccessChainFromGEP(GetElementPtrInst &GEP);
232+
197233
public:
198234
static char ID;
199235
SPIRVEmitIntrinsics(SPIRVTargetMachine *TM = nullptr)
@@ -246,6 +282,17 @@ bool expectIgnoredInIRTranslation(const Instruction *I) {
246282
}
247283
}
248284

285+
// Returns the source pointer from `I` ignoring intermediate ptrcast.
286+
Value *getPointerRoot(Value *I) {
287+
if (auto *II = dyn_cast<IntrinsicInst>(I)) {
288+
if (II->getIntrinsicID() == Intrinsic::spv_ptrcast) {
289+
Value *V = II->getArgOperand(0);
290+
return getPointerRoot(V);
291+
}
292+
}
293+
return I;
294+
}
295+
249296
} // namespace
250297

251298
char SPIRVEmitIntrinsics::ID = 0;
@@ -555,7 +602,111 @@ void SPIRVEmitIntrinsics::maybeAssignPtrType(Type *&Ty, Value *Op, Type *RefTy,
555602
Ty = RefTy;
556603
}
557604

558-
Type *getGEPType(GetElementPtrInst *Ref) {
605+
// Walks the structured type addressed by a byte-indexed (i8) GEP.
//
// The type to walk is the deduced element type of the GEP's pointer operand,
// looked up through any intermediate spv_ptrcast.
//  - Non-constant index: accepted only when that type is an array, in which
//    case OnDynamicIndexing(ElementType, ByteOffsetValue) is called once.
//  - Constant index: the byte offset is consumed level by level through nested
//    arrays/structs, and OnLiteralIndexing(PointedType, Index) is called for
//    every level that is entered. At least one level is always entered, even
//    when the offset is 0.
//
// Returns true if the walk had to give up (the offset cannot be expressed as
// a structured access), false on success.
bool SPIRVEmitIntrinsics::walkLogicalAccessChain(
    GetElementPtrInst &GEP,
    const std::function<void(Type *, uint64_t)> &OnLiteralIndexing,
    const std::function<void(Type *, Value *)> &OnDynamicIndexing) {
  // We only rewrite i8* GEPs; others should be left as-is.
  // A valid i8* GEP always has a single index.
  assert(GEP.getSourceElementType() ==
         IntegerType::getInt8Ty(CurrF->getContext()));
  assert(GEP.getNumIndices() == 1);

  const DataLayout &DL = CurrF->getDataLayout();
  Value *Src = getPointerRoot(GEP.getPointerOperand());
  Type *CurType = deduceElementType(Src, true);

  Value *Operand = *GEP.idx_begin();
  auto *CI = dyn_cast<ConstantInt>(Operand);
  if (!CI) {
    // Operand is not constant. Either we have an array and accept it, or we
    // give up.
    auto *AT = dyn_cast<ArrayType>(CurType);
    if (!AT)
      return true;
    OnDynamicIndexing(AT->getElementType(), Operand);
    return false;
  }

  // GEP indices are signed. A negative byte offset cannot be mapped onto a
  // structured access, and zero-extending it would yield a bogus huge offset.
  if (CI->isNegative())
    return true;
  uint64_t Offset = CI->getZExtValue();

  do {
    if (auto *AT = dyn_cast<ArrayType>(CurType)) {
      Type *EltType = AT->getElementType();
      // Consecutive array elements are laid out at their alloc size, which
      // includes padding (e.g. <3 x float> or i24); this is the stride GEP
      // itself uses, so it is the one needed to invert a byte offset.
      const uint64_t EltStride = DL.getTypeAllocSize(EltType).getFixedValue();
      // Zero-sized elements cannot be indexed by a byte offset; this also
      // protects the division below.
      if (EltStride == 0)
        return true;
      assert(Offset < AT->getNumElements() * EltStride);
      const uint64_t Index = Offset / EltStride;
      Offset -= Index * EltStride;
      CurType = EltType;
      OnLiteralIndexing(CurType, Index);
    } else if (auto *ST = dyn_cast<StructType>(CurType)) {
      // Opaque structs have no layout and empty structs have no field that
      // could contain the offset.
      if (ST->isOpaque() || ST->getNumElements() == 0)
        return true;
      const StructLayout *STL = DL.getStructLayout(ST);
      assert(Offset < STL->getSizeInBytes());
      const unsigned Element = STL->getElementContainingOffset(Offset);
      Offset -= STL->getElementOffset(Element);
      CurType = ST->getElementType(Element);
      OnLiteralIndexing(CurType, Element);
    } else {
      // Vector type indexing should not use GEP.
      // So if we have an index left, something is wrong. Giving up.
      return true;
    }
  } while (Offset > 0);

  return false;
}
658+
659+
// Replaces a byte-indexed (i8) GEP with an `spv_gep` intrinsic call whose
// indices describe a structured access into the pointed-to type.
//
// The index list starts with a constant i32 0, followed by one index per
// level reported by walkLogicalAccessChain: literal indices are emitted as
// i64 constants, a dynamic byte offset is converted to an element index with
// an unsigned division by the element stride.
//
// Returns the new intrinsic call (all uses of `GEP` are replaced and `GEP`
// is erased), or nullptr when the access cannot be expressed as a structured
// chain. In the nullptr case no IR has been modified, so the caller can fall
// back to the generic GEP lowering.
Instruction *
SPIRVEmitIntrinsics::buildLogicalAccessChainFromGEP(GetElementPtrInst &GEP) {
  const DataLayout &DL = CurrF->getDataLayout();
  IRBuilder<> B(GEP.getParent());
  B.SetInsertPoint(&GEP);

  SmallVector<Value *, 4> Indices;
  Indices.push_back(B.getInt32(0));

  // Set when the dynamic offset cannot be turned into an element index.
  bool BadDynamicStride = false;
  const bool WalkFailed = walkLogicalAccessChain(
      GEP,
      [&Indices, &B](Type * /*EltType*/, uint64_t Index) {
        Indices.push_back(B.getInt64(Index));
      },
      [&Indices, &B, &DL, &BadDynamicStride](Type *EltType, Value *Offset) {
        // Array elements are spaced by their alloc size (padding included).
        const uint64_t EltStride =
            DL.getTypeAllocSize(EltType).getFixedValue();
        if (EltStride == 0) {
          // Never emit a division by zero; report the failure instead.
          BadDynamicStride = true;
          return;
        }
        Indices.push_back(B.CreateUDiv(
            Offset, ConstantInt::get(Offset->getType(), EltStride,
                                     /* Signed= */ false)));
      });

  // A failed walk leaves `Indices` truncated: building the access chain
  // anyway would silently drop (part of) the byte offset. No instruction has
  // been created on the failure paths, so bailing out here is side-effect
  // free.
  if (WalkFailed || BadDynamicStride)
    return nullptr;

  SmallVector<Type *, 2> Types = {GEP.getType(), GEP.getOperand(0)->getType()};
  SmallVector<Value *, 4> Args;
  Args.push_back(B.getInt1(GEP.isInBounds()));
  Args.push_back(GEP.getOperand(0));
  llvm::append_range(Args, Indices);
  auto *NewI = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});
  replaceAllUsesWithAndErase(B, &GEP, NewI);
  return NewI;
}
691+
692+
// Deduces the type accessed by an i8 GEP by walking the pointed-to type with
// walkLogicalAccessChain and keeping the type reported by the last callback.
// If the walk reports an error, or never invokes a callback, the GEP's own
// result element type is returned instead.
Type *SPIRVEmitIntrinsics::getGEPTypeLogical(GetElementPtrInst *GEP) {
  Type *const Fallback = GEP->getResultElementType();
  Type *Innermost = Fallback;

  // Both callbacks only record the type of the level that was just entered.
  const auto OnLiteral = [&Innermost](Type *EltType, uint64_t /*Index*/) {
    Innermost = EltType;
  };
  const auto OnDynamic = [&Innermost](Type *EltType, Value * /*Offset*/) {
    Innermost = EltType;
  };

  if (walkLogicalAccessChain(*GEP, OnLiteral, OnDynamic))
    return Fallback;
  return Innermost;
}
702+
703+
Type *SPIRVEmitIntrinsics::getGEPType(GetElementPtrInst *Ref) {
704+
if (Ref->getSourceElementType() ==
705+
IntegerType::getInt8Ty(CurrF->getContext()) &&
706+
TM->getSubtargetImpl()->isLogicalSPIRV()) {
707+
return getGEPTypeLogical(Ref);
708+
}
709+
559710
Type *Ty = nullptr;
560711
// TODO: not sure if GetElementPtrInst::getTypeAtIndex() does anything
561712
// useful here
@@ -1395,6 +1546,13 @@ Instruction *SPIRVEmitIntrinsics::visitSwitchInst(SwitchInst &I) {
13951546
}
13961547

13971548
Instruction *SPIRVEmitIntrinsics::visitGetElementPtrInst(GetElementPtrInst &I) {
1549+
if (I.getSourceElementType() == IntegerType::getInt8Ty(CurrF->getContext()) &&
1550+
TM->getSubtargetImpl()->isLogicalSPIRV()) {
1551+
Instruction *Result = buildLogicalAccessChainFromGEP(I);
1552+
if (Result)
1553+
return Result;
1554+
}
1555+
13981556
IRBuilder<> B(I.getParent());
13991557
B.SetInsertPoint(&I);
14001558
SmallVector<Type *, 2> Types = {I.getType(), I.getOperand(0)->getType()};
@@ -1588,7 +1746,24 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
15881746
}
15891747
if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(I)) {
15901748
Value *Pointer = GEPI->getPointerOperand();
1591-
Type *OpTy = GEPI->getSourceElementType();
1749+
Type *OpTy = nullptr;
1750+
1751+
// Knowing the accessed type is mandatory for logical SPIR-V. Sadly,
1752+
// the GEP source element type should not be used for this purpose, and
1753+
// the alternative type-scavenging method is not working.
1754+
// Physical SPIR-V can work around this, but not logical, hence still
1755+
// try to rely on the broken type scavenging for logical.
1756+
bool IsRewrittenGEP =
1757+
GEPI->getSourceElementType() == IntegerType::getInt8Ty(I->getContext());
1758+
if (IsRewrittenGEP && TM->getSubtargetImpl()->isLogicalSPIRV()) {
1759+
Value *Src = getPointerRoot(Pointer);
1760+
OpTy = GR->findDeducedElementType(Src);
1761+
}
1762+
1763+
// In all cases, fall back to the GEP type if type scavenging failed.
1764+
if (!OpTy)
1765+
OpTy = GEPI->getSourceElementType();
1766+
15921767
replacePointerOperandWithPtrCast(I, Pointer, OpTy, 0, B);
15931768
if (isNestedPointer(OpTy))
15941769
insertTodoType(Pointer);

llvm/test/CodeGen/SPIRV/llvm-intrinsics/lifetime.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ define spir_func void @foo(ptr noundef byval(%tprange) align 8 %_arg_UserRange)
3333
%RoundedRangeKernel = alloca %tprange, align 8
3434
call void @llvm.lifetime.start.p0(i64 72, ptr nonnull %RoundedRangeKernel)
3535
call void @llvm.memcpy.p0.p0.i64(ptr align 8 %RoundedRangeKernel, ptr align 8 %_arg_UserRange, i64 16, i1 false)
36-
%KernelFunc = getelementptr inbounds i8, ptr %RoundedRangeKernel, i64 16
36+
%KernelFunc = getelementptr inbounds i8, ptr %RoundedRangeKernel, i64 8
3737
call void @llvm.lifetime.end.p0(i64 72, ptr nonnull %RoundedRangeKernel)
3838
ret void
3939
}
@@ -55,7 +55,7 @@ define spir_func void @bar(ptr noundef byval(%tprange) align 8 %_arg_UserRange)
5555
%RoundedRangeKernel = alloca %tprange, align 8
5656
call void @llvm.lifetime.start.p0(i64 -1, ptr nonnull %RoundedRangeKernel)
5757
call void @llvm.memcpy.p0.p0.i64(ptr align 8 %RoundedRangeKernel, ptr align 8 %_arg_UserRange, i64 16, i1 false)
58-
%KernelFunc = getelementptr inbounds i8, ptr %RoundedRangeKernel, i64 16
58+
%KernelFunc = getelementptr inbounds i8, ptr %RoundedRangeKernel, i64 8
5959
call void @llvm.lifetime.end.p0(i64 -1, ptr nonnull %RoundedRangeKernel)
6060
ret void
6161
}

llvm/test/CodeGen/SPIRV/logical-struct-access.ll

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
1+
; RUN: llc -O0 -mtriple=spirv-unknown-vulkan1.3-compute %s -o - -print-after-all | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-vulkan1.3-compute %s -o - -filetype=obj | spirv-val %}
23

34
; CHECK-DAG: [[uint:%[0-9]+]] = OpTypeInt 32 0
45

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
; RUN: llc -verify-machineinstrs -O3 -mtriple=spirv1.6-unknown-vulkan1.3-compute %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O3 -mtriple=spirv1.6-unknown-vulkan1.3-compute %s -o - -filetype=obj | spirv-val %}
3+
4+
%struct.S1 = type { <4 x i32>, [10 x <4 x float>], <4 x float> }
5+
%struct.S2 = type { <4 x float>, <4 x i32> }
6+
7+
@.str = private unnamed_addr constant [3 x i8] c"In\00", align 1
8+
9+
; Constant-offset case: a single i8 GEP with a literal byte offset into a
; %struct.S1 storage buffer element must be lowered to a structured access
; chain, and the loaded value is returned.
define <4 x float> @main() {
10+
entry:
11+
%0 = tail call target("spirv.VulkanBuffer", [0 x %struct.S1], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0s_struct.S1s_12_0t(i32 0, i32 1, i32 1, i32 0, i1 false, ptr nonnull @.str)
12+
%3 = tail call noundef align 1 dereferenceable(192) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0s_struct.S1s_12_0t(target("spirv.VulkanBuffer", [0 x %struct.S1], 12, 0) %0, i32 0)
13+
14+
; CHECK-DAG: %[[#ulong:]] = OpTypeInt 64 0
15+
; CHECK-DAG: %[[#ulong_1:]] = OpConstant %[[#ulong]] 1
16+
; CHECK-DAG: %[[#ulong_3:]] = OpConstant %[[#ulong]] 3
17+
18+
; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
19+
; CHECK-DAG: %[[#uint_0:]] = OpConstant %[[#uint]] 0
20+
; CHECK-DAG: %[[#uint_10:]] = OpConstant %[[#uint]] 10
21+
22+
; CHECK-DAG: %[[#float:]] = OpTypeFloat 32
23+
; CHECK-DAG: %[[#v4f:]] = OpTypeVector %[[#float]] 4
24+
; CHECK-DAG: %[[#arr_v4f:]] = OpTypeArray %[[#v4f]] %[[#uint_10]]
25+
; CHECK-DAG: %[[#S1:]] = OpTypeStruct %[[#]] %[[#arr_v4f]] %[[#]]
26+
; CHECK-DAG: %[[#sb_S1:]] = OpTypePointer StorageBuffer %[[#S1]]
27+
; CHECK-DAG: %[[#sb_v4f:]] = OpTypePointer StorageBuffer %[[#v4f]]
28+
29+
; CHECK: %[[#tmp:]] = OpAccessChain %[[#sb_S1]] %[[#]] %[[#uint_0]] %[[#uint_0]]
30+
; CHECK: %[[#ptr:]] = OpInBoundsAccessChain %[[#sb_v4f]] %[[#tmp]] %[[#ulong_1]] %[[#ulong_3]]
31+
; This rewritten GEP combined all constant indices into a single value.
32+
; We should make sure the correct indices are retrieved.
; Assuming 16-byte vectors, byte 64 is 16 bytes (field 0) + 3 * 16 bytes into
; field 1, i.e. element 3 of the array: the expected indices are 1 and 3.
33+
%arrayidx.i = getelementptr inbounds nuw i8, ptr addrspace(11) %3, i64 64
34+
35+
; CHECK: OpLoad %[[#v4f]] %[[#ptr]]
36+
%4 = load <4 x float>, ptr addrspace(11) %arrayidx.i, align 1
37+
38+
ret <4 x float> %4
39+
}
40+
41+
declare i32 @llvm.spv.flattened.thread.id.in.group()
42+
declare target("spirv.VulkanBuffer", [0 x %struct.S1], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0s_struct.S1s_12_0t(i32, i32, i32, i32, i1, ptr)
43+
declare ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0s_struct.S1s_12_0t(target("spirv.VulkanBuffer", [0 x %struct.S1], 12, 0), i32)
44+
45+
attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
46+
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
; RUN: llc -verify-machineinstrs -O3 -mtriple=spirv1.6-unknown-vulkan1.3-compute %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O3 -mtriple=spirv1.6-unknown-vulkan1.3-compute %s -o - -filetype=obj | spirv-val %}
3+
4+
%struct.S1 = type { <4 x i32>, [10 x <4 x float>], <4 x float> }
5+
%struct.S2 = type { <4 x float>, <4 x i32> }
6+
7+
@.str = private unnamed_addr constant [3 x i8] c"In\00", align 1
8+
9+
; Dynamic-offset case: a constant i8 GEP first selects the array field of
; %struct.S1, then a second i8 GEP indexes that array with a runtime byte
; offset (%index * 16). The backend is expected to divide the byte offset by
; 16 to recover the element index.
define <4 x float> @main(i32 %index) {
10+
entry:
11+
%0 = tail call target("spirv.VulkanBuffer", [0 x %struct.S1], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0s_struct.S1s_12_0t(i32 0, i32 1, i32 1, i32 0, i1 false, ptr nonnull @.str)
12+
%3 = tail call noundef align 1 dereferenceable(192) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0s_struct.S1s_12_0t(target("spirv.VulkanBuffer", [0 x %struct.S1], 12, 0) %0, i32 0)
13+
14+
; CHECK-DAG: %[[#ulong:]] = OpTypeInt 64 0
15+
; CHECK-DAG: %[[#ulong_1:]] = OpConstant %[[#ulong]] 1
16+
17+
; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
18+
; CHECK-DAG: %[[#uint_0:]] = OpConstant %[[#uint]] 0
19+
; CHECK-DAG: %[[#uint_10:]] = OpConstant %[[#uint]] 10
20+
; CHECK-DAG: %[[#uint_16:]] = OpConstant %[[#uint]] 16
21+
22+
; CHECK-DAG: %[[#float:]] = OpTypeFloat 32
23+
; CHECK-DAG: %[[#v4f:]] = OpTypeVector %[[#float]] 4
24+
; CHECK-DAG: %[[#arr_v4f:]] = OpTypeArray %[[#v4f]] %[[#uint_10]]
25+
; CHECK-DAG: %[[#S1:]] = OpTypeStruct %[[#]] %[[#arr_v4f]] %[[#]]
26+
; CHECK-DAG: %[[#sb_S1:]] = OpTypePointer StorageBuffer %[[#S1]]
27+
; CHECK-DAG: %[[#sb_arr_v4f:]] = OpTypePointer StorageBuffer %[[#arr_v4f]]
28+
; CHECK-DAG: %[[#sb_v4f:]] = OpTypePointer StorageBuffer %[[#v4f]]
29+
30+
; CHECK: %[[#a:]] = OpAccessChain %[[#sb_S1]] %[[#]] %[[#uint_0]] %[[#uint_0]]
31+
; CHECK: %[[#b:]] = OpInBoundsAccessChain %[[#sb_arr_v4f]] %[[#a]] %[[#ulong_1]]
32+
%4 = getelementptr inbounds nuw i8, ptr addrspace(11) %3, i64 16
33+
34+
; CHECK: %[[#offset:]] = OpIMul %[[#]] %[[#]] %[[#uint_16]]
35+
; Offset is computed in bytes. Make sure we convert it back to an index.
36+
%offset = mul i32 %index, 16
37+
38+
; CHECK: %[[#index:]] = OpUDiv %[[#]] %[[#offset]] %[[#uint_16]]
39+
; CHECK: %[[#c:]] = OpInBoundsAccessChain %[[#sb_v4f]] %[[#b]] %[[#index]]
40+
%5 = getelementptr inbounds nuw i8, ptr addrspace(11) %4, i32 %offset
41+
42+
; CHECK: OpLoad %[[#v4f]] %[[#c]]
43+
%6 = load <4 x float>, ptr addrspace(11) %5, align 1
44+
45+
ret <4 x float> %6
46+
}
47+
48+
declare i32 @llvm.spv.flattened.thread.id.in.group()
49+
declare target("spirv.VulkanBuffer", [0 x %struct.S1], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0s_struct.S1s_12_0t(i32, i32, i32, i32, i1, ptr)
50+
declare ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0s_struct.S1s_12_0t(target("spirv.VulkanBuffer", [0 x %struct.S1], 12, 0), i32)
51+
52+
attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
53+
54+

0 commit comments

Comments
 (0)