Skip to content

Commit 513334f

Browse files
authored
[NFC][SPIRV] Fix function type recovery (llvm#165934)
Due to limitations in GISel / IRTranslator, the SPIR-V BE replaces aggregate function args with `i32` placeholders, which are subsequently used to retrieve the original type after IR translation, from metadata. Due to what appears to be an oversight, the current implementation only handles a single mutation, as it does not traverse the metadata, but rather only takes the first operand. This patch addresses that limitation by correctly iterating the metadata.
1 parent 6a275de commit 513334f

File tree

2 files changed

+29
-15
lines changed

2 files changed

+29
-15
lines changed

llvm/lib/Target/SPIRV/SPIRVCallLowering.cpp

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -149,23 +149,23 @@ static FunctionType *getOriginalFunctionType(const Function &F) {
149149
return isa<MDString>(N->getOperand(0)) &&
150150
cast<MDString>(N->getOperand(0))->getString() == F.getName();
151151
});
152-
// TODO: probably one function can have numerous type mutations,
153-
// so we should support this.
154152
if (ThisFuncMDIt != NamedMD->op_end()) {
155153
auto *ThisFuncMD = *ThisFuncMDIt;
156-
MDNode *MD = dyn_cast<MDNode>(ThisFuncMD->getOperand(1));
157-
assert(MD && "MDNode operand is expected");
158-
ConstantInt *Const = getConstInt(MD, 0);
159-
if (Const) {
160-
auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(1));
161-
assert(CMeta && "ConstantAsMetadata operand is expected");
162-
assert(Const->getSExtValue() >= -1);
163-
// Currently -1 indicates return value, greater values mean
164-
// argument numbers.
165-
if (Const->getSExtValue() == -1)
166-
RetTy = CMeta->getType();
167-
else
168-
ArgTypes[Const->getSExtValue()] = CMeta->getType();
154+
for (unsigned I = 1; I != ThisFuncMD->getNumOperands(); ++I) {
155+
MDNode *MD = dyn_cast<MDNode>(ThisFuncMD->getOperand(I));
156+
assert(MD && "MDNode operand is expected");
157+
ConstantInt *Const = getConstInt(MD, 0);
158+
if (Const) {
159+
auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(1));
160+
assert(CMeta && "ConstantAsMetadata operand is expected");
161+
assert(Const->getSExtValue() >= -1);
162+
// Currently -1 indicates return value, greater values mean
163+
// argument numbers.
164+
if (Const->getSExtValue() == -1)
165+
RetTy = CMeta->getType();
166+
else
167+
ArgTypes[Const->getSExtValue()] = CMeta->getType();
168+
}
169169
}
170170
}
171171

llvm/test/CodeGen/SPIRV/pointers/composite-fun-fix-ptr-arg.ll

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,28 @@
1010

1111
; CHECK-DAG: %[[#Int8:]] = OpTypeInt 8 0
1212
; CHECK-DAG: %[[#Half:]] = OpTypeFloat 16
13+
; CHECK-DAG: %[[#Float:]] = OpTypeFloat 32
1314
; CHECK-DAG: %[[#Struct:]] = OpTypeStruct %[[#Half]]
1415
; CHECK-DAG: %[[#Void:]] = OpTypeVoid
1516
; CHECK-DAG: %[[#PtrInt8:]] = OpTypePointer CrossWorkgroup %[[#Int8:]]
1617
; CHECK-DAG: %[[#FooType:]] = OpTypeFunction %[[#Void]] %[[#PtrInt8]] %[[#Struct]]
1718
; CHECK-DAG: %[[#Int64:]] = OpTypeInt 64 0
1819
; CHECK-DAG: %[[#PtrInt64:]] = OpTypePointer CrossWorkgroup %[[#Int64]]
1920
; CHECK-DAG: %[[#BarType:]] = OpTypeFunction %[[#Void]] %[[#PtrInt64]] %[[#Struct]]
21+
; CHECK-DAG: %[[#BazType:]] = OpTypeFunction %[[#Void]] %[[#PtrInt8]] %[[#Struct]] %[[#Int8]] %[[#Struct]] %[[#Float]] %[[#Struct]]
2022
; CHECK: OpFunction %[[#Void]] None %[[#FooType]]
2123
; CHECK: OpFunctionParameter %[[#PtrInt8]]
2224
; CHECK: OpFunctionParameter %[[#Struct]]
2325
; CHECK: OpFunction %[[#Void]] None %[[#BarType]]
2426
; CHECK: OpFunctionParameter %[[#PtrInt64]]
2527
; CHECK: OpFunctionParameter %[[#Struct]]
28+
; CHECK: OpFunction %[[#Void]] None %[[#BazType]]
29+
; CHECK: OpFunctionParameter %[[#PtrInt8]]
30+
; CHECK: OpFunctionParameter %[[#Struct]]
31+
; CHECK: OpFunctionParameter %[[#Int8]]
32+
; CHECK: OpFunctionParameter %[[#Struct]]
33+
; CHECK: OpFunctionParameter %[[#Float]]
34+
; CHECK: OpFunctionParameter %[[#Struct]]
2635

2736
%t_half = type { half }
2837

@@ -38,4 +47,9 @@ entry:
3847
ret void
3948
}
4049

50+
define spir_kernel void @baz(ptr addrspace(1) %a, %t_half %b, i8 %c, %t_half %d, float %e, %t_half %f) {
51+
entry:
52+
ret void
53+
}
54+
4155
declare spir_func %t_half @_Z29__spirv_SpecConstantComposite(half)

0 commit comments

Comments
 (0)