Skip to content

Commit 8001ca2

Browse files
[Backport to 18][LLVM->SPIRV] Cast the GEP base pointer to source type upon mismatch (#3255) (#3642)
The source element type used in a GEP may differ from the actual type of the pointer operand (e.g., ptr i8 vs. ptr [N x T]). This mismatch can lead to incorrect address computations during translation to SPIR-V of GEP used in constexpr context, which requires that pointer types match the type of the object being accessed. This patch inserts an explicit bitcast to convert the GEP pointer operand to the expected type, derived from the GEP’s source element type, before emitting an PtrAccessChain. This ensures the resulting SPIR-V instruction has a correctly typed base pointer and produces valid indexing behavior. For example: Before this change, the following GEP was translated incorrectly: getelementptr(i8, ptr addrspace(1) @a_var, i64 2) Whereas this nearly equivalent GEP was handled correctly: getelementptr inbounds ([2 x i8], ptr @a_var, i64 0, i64 1) Previously, the first form was incorrectly interpreted as: getelementptr inbounds ([2 x i8], ptr @a_var, i64 0, i64 2) (cherry picked from commit 1be9366) Co-authored-by: Karol Zwolak <karolzwolak7@gmail.com>
1 parent 7171801 commit 8001ca2

File tree

2 files changed

+45
-0
lines changed

2 files changed

+45
-0
lines changed

lib/SPIRV/SPIRVWriter.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1449,6 +1449,19 @@ SPIRVValue *LLVMToSPIRVBase::transConstant(Value *V) {
14491449
for (unsigned I = 0, E = GEP->getNumIndices(); I != E; ++I)
14501450
Indices.push_back(transValue(GEP->getOperand(I + 1), nullptr));
14511451
auto *TransPointerOperand = transValue(GEP->getPointerOperand(), nullptr);
1452+
1453+
// Determine the expected pointer type from the GEP source element type.
1454+
Type *GepSourceElemTy = GEP->getSourceElementType();
1455+
SPIRVType *ExpectedPtrTy =
1456+
transPointerType(GepSourceElemTy, GEP->getPointerAddressSpace());
1457+
1458+
// Ensure the base pointer's type matches the GEP's effective source
1459+
// element type.
1460+
if (TransPointerOperand->getType() != ExpectedPtrTy) {
1461+
TransPointerOperand = BM->addUnaryInst(OpBitcast, ExpectedPtrTy,
1462+
TransPointerOperand, nullptr);
1463+
}
1464+
14521465
SPIRVType *TranslatedTy = transScavengedType(GEP);
14531466
return BM->addPtrAccessChainInst(TranslatedTy, TransPointerOperand,
14541467
Indices, nullptr, GEP->isInBounds());
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
; RUN: llvm-as %s -o %t.bc
2+
; RUN: llvm-spirv %t.bc -o %t.spv
3+
; RUN: spirv-val %t.spv
4+
; RUN: llvm-spirv %t.spv -o %t.spt --to-text
5+
; RUN: llvm-spirv -r %t.spv -o %t.rev.bc
6+
; RUN: llvm-dis %t.rev.bc
7+
; RUN: FileCheck %s --input-file %t.spt -check-prefix=CHECK-SPIRV
8+
; RUN: FileCheck %s --input-file %t.rev.ll -check-prefix=CHECK-LLVM
9+
10+
; Make sure that when the GEP operand type doesn't match the source element type (here operand a_var is [2 x i16], but the source element is i8),
11+
; we cast the operand to the source element pointer type (a_var to i8*).
12+
13+
target datalayout = "e-p:32:32-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024"
14+
target triple = "spir-unknown-unknown"
15+
16+
; CHECK-SPIRV-DAG: Name [[A_VAR:[0-9]+]] "a_var"
17+
; CHECK-SPIRV-DAG: Name [[GLOBAL_PTR:[0-9]+]] "global_ptr"
18+
19+
; CHECK-SPIRV-DAG: TypeArray [[ARRAY_TYPE:[0-9]+]] [[USHORT_TYPE:[0-9]+]] [[CONST_2:[0-9]+]]
20+
; CHECK-SPIRV-DAG: TypePointer [[ARRAY_PTR_TYPE:[0-9]+]] 5 [[ARRAY_TYPE]]
21+
22+
; CHECK-SPIRV-DAG: Variable [[ARRAY_PTR_TYPE]] [[A_VAR]] 5 [[INIT_ID:[0-9]+]]
23+
; CHECK-SPIRV-DAG: SpecConstantOp [[I8PTR:[0-9]+]] [[BITCAST:[0-9]+]] 124 [[A_VAR]]
24+
; CHECK-SPIRV-DAG: SpecConstantOp [[I8PTR]] [[PTRCHAIN:[0-9]+]] 67 [[BITCAST]] [[INDEX_ID:[0-9]+]]
25+
; CHECK-SPIRV-DAG: TypePointer [[PTR_PTR_TYPE:[0-9]+]] 5 [[I8PTR]]
26+
; CHECK-SPIRV-DAG: Variable [[PTR_PTR_TYPE]] [[GLOBAL_PTR]] 5 [[PTRCHAIN]]
27+
28+
; CHECK-LLVM: @global_ptr = addrspace(1) global ptr addrspace(1) getelementptr (i8, ptr addrspace(1) @a_var, i64 2), align 8
29+
; CHECK-LLVM-NOT: @global_ptr = addrspace(1) global ptr addrspace(1) getelementptr ([2 x i16], ptr addrspace(1) @a_var, i64 2), align 8
30+
31+
@a_var = dso_local addrspace(1) global [2 x i16] [i16 4, i16 5], align 2
32+
@global_ptr = dso_local addrspace(1) global ptr addrspace(1) getelementptr (i8, ptr addrspace(1) @a_var, i64 2), align 8

0 commit comments

Comments
 (0)