Skip to content

Commit 1e42a1b

Browse files
karolzwolakvmaksimo
authored andcommitted
Cast the GEP base pointer to source type upon mismatch (#3255)
The source element type used in a GEP may differ from the actual type of the pointer operand (e.g., ptr i8 vs. ptr [N x T]). This mismatch can lead to incorrect address computations during translation to SPIR-V of GEP used in constexpr context, which requires that pointer types match the type of the object being accessed. This patch inserts an explicit bitcast to convert the GEP pointer operand to the expected type, derived from the GEP’s source element type, before emitting an PtrAccessChain. This ensures the resulting SPIR-V instruction has a correctly typed base pointer and produces valid indexing behavior. For example: Before this change, the following GEP was translated incorrectly: getelementptr(i8, ptr addrspace(1) @a_var, i64 2) Whereas this nearly equivalent GEP was handled correctly: getelementptr inbounds ([2 x i8], ptr @a_var, i64 0, i64 1) Previously, the first form was incorrectly interpreted as: getelementptr inbounds ([2 x i8], ptr @a_var, i64 0, i64 2) Original commit: KhronosGroup/SPIRV-LLVM-Translator@1be936678fd8bbe
1 parent 115b84e commit 1e42a1b

File tree

2 files changed

+44
-0
lines changed

2 files changed

+44
-0
lines changed

llvm-spirv/lib/SPIRV/SPIRVWriter.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1480,6 +1480,18 @@ SPIRVValue *LLVMToSPIRVBase::transConstant(Value *V) {
14801480
if (auto *ConstUE = dyn_cast<ConstantExpr>(V)) {
14811481
if (auto *GEP = dyn_cast<GEPOperator>(ConstUE)) {
14821482
auto *TransPointerOperand = transValue(GEP->getPointerOperand(), nullptr);
1483+
// Determine the expected pointer type from the GEP source element type.
1484+
Type *GepSourceElemTy = GEP->getSourceElementType();
1485+
SPIRVType *ExpectedPtrTy =
1486+
transPointerType(GepSourceElemTy, GEP->getPointerAddressSpace());
1487+
1488+
// Ensure the base pointer's type matches the GEP's effective source
1489+
// element type.
1490+
if (TransPointerOperand->getType() != ExpectedPtrTy) {
1491+
TransPointerOperand = BM->addUnaryInst(OpBitcast, ExpectedPtrTy,
1492+
TransPointerOperand, nullptr);
1493+
}
1494+
14831495
std::vector<SPIRVWord> Ops = {TransPointerOperand->getId()};
14841496
for (unsigned I = 0, E = GEP->getNumIndices(); I != E; ++I)
14851497
Ops.push_back(transValue(GEP->getOperand(I + 1), nullptr)->getId());
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
; RUN: llvm-as %s -o %t.bc
2+
; RUN: llvm-spirv %t.bc -o %t.spv
3+
; RUN: spirv-val %t.spv
4+
; RUN: llvm-spirv %t.spv -o %t.spt --to-text
5+
; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
6+
; RUN: llvm-dis %t.rev.bc
7+
; RUN: FileCheck %s --input-file %t.spt -check-prefix=CHECK-SPIRV
8+
; RUN: FileCheck %s --input-file %t.rev.ll -check-prefix=CHECK-LLVM
9+
10+
; Make sure that when the GEP operand type doesn't match the source element type (here operand a_var is [2 x i16], but the source element is i8),
11+
; we cast the operand to the source element pointer type (a_var to i8*).
12+
13+
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
14+
target triple = "spir-unknown-unknown"
15+
16+
; CHECK-SPIRV-DAG: Name [[A_VAR:[0-9]+]] "a_var"
17+
; CHECK-SPIRV-DAG: Name [[GLOBAL_PTR:[0-9]+]] "global_ptr"
18+
19+
; CHECK-SPIRV-DAG: TypeArray [[ARRAY_TYPE:[0-9]+]] [[USHORT_TYPE:[0-9]+]] [[CONST_2:[0-9]+]]
20+
; CHECK-SPIRV-DAG: TypePointer [[ARRAY_PTR_TYPE:[0-9]+]] 5 [[ARRAY_TYPE]]
21+
22+
; CHECK-SPIRV-DAG: Variable [[ARRAY_PTR_TYPE]] [[A_VAR]] 5 [[INIT_ID:[0-9]+]]
23+
; CHECK-SPIRV-DAG: SpecConstantOp [[I8PTR:[0-9]+]] [[BITCAST:[0-9]+]] 124 [[A_VAR]]
24+
; CHECK-SPIRV-DAG: SpecConstantOp [[I8PTR]] [[PTRCHAIN:[0-9]+]] 67 [[BITCAST]] [[INDEX_ID:[0-9]+]]
25+
; CHECK-SPIRV-DAG: TypePointer [[PTR_PTR_TYPE:[0-9]+]] 5 [[I8PTR]]
26+
; CHECK-SPIRV-DAG: Variable [[PTR_PTR_TYPE]] [[GLOBAL_PTR]] 5 [[PTRCHAIN]]
27+
28+
; CHECK-LLVM: @global_ptr = addrspace(1) global ptr addrspace(1) getelementptr (i8, ptr addrspace(1) @a_var, i64 2), align 8
29+
; CHECK-LLVM-NOT: @global_ptr = addrspace(1) global ptr addrspace(1) getelementptr ([2 x i16], ptr addrspace(1) @a_var, i64 2), align 8
30+
31+
@a_var = dso_local addrspace(1) global [2 x i16] [i16 4, i16 5], align 2
32+
@global_ptr = dso_local addrspace(1) global ptr addrspace(1) getelementptr (i8, ptr addrspace(1) @a_var, i64 2), align 8

0 commit comments

Comments
 (0)