Skip to content

Commit 55f03c2

Browse files
committed
[LLVM-14] Recover non-struct pointee types from demangling func args
Generalize getParameterTypes to return Type* instead of StructType* so that primitive pointee types (e.g. bfloat16) are recovered from demangling function arguments, not just struct types. Add StructType* overload wrapper for backward compatibility. AI-assisted: Claude Sonnet 4.6 (commercial SaaS)
1 parent 30d6593 commit 55f03c2

File tree

2 files changed

+16
-6
lines changed

2 files changed

+16
-6
lines changed

lib/SPIRV/SPIRVInternal.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1024,9 +1024,12 @@ bool containsUnsignedAtomicType(StringRef Name);
10241024
std::string mangleBuiltin(StringRef UniqName, ArrayRef<Type *> ArgTypes,
10251025
BuiltinFuncMangleInfo *BtnInfo);
10261026

1027-
/// Extract the pointee types of arguments from a mangled function name. If the
1028-
/// corresponding type is not a pointer to a struct type, its value will be a
1029-
/// nullptr instead.
1027+
/// Extract the pointee types of arguments from a mangled function name,
1028+
/// including non-struct types such as bfloat16. Unknown types are nullptr.
1029+
void getParameterTypes(Function *F, SmallVectorImpl<Type *> &ArgTys);
1030+
1031+
/// Struct-typed variant of getParameterTypes. Non-struct pointee types (e.g.
1032+
/// bfloat) are returned as nullptr.
10301033
void getParameterTypes(Function *F, SmallVectorImpl<StructType *> &ArgTys);
10311034
inline void getParameterTypes(CallInst *CI,
10321035
SmallVectorImpl<StructType *> &ArgTys) {

lib/SPIRV/SPIRVUtil.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -771,7 +771,7 @@ static std::string demangleBuiltinOpenCLTypeName(StringRef MangledStructName) {
771771
return LlvmStructName;
772772
}
773773

774-
void getParameterTypes(Function *F, SmallVectorImpl<StructType *> &ArgTys) {
774+
void getParameterTypes(Function *F, SmallVectorImpl<Type *> &ArgTys) {
775775
// If there's no mangled name, we can't do anything. Also, if there's no
776776
// parameters, do nothing.
777777
StringRef Name = F->getName();
@@ -791,7 +791,7 @@ void getParameterTypes(Function *F, SmallVectorImpl<StructType *> &ArgTys) {
791791
assert(!HasSret && &Arg == F->getArg(0) &&
792792
"sret parameter should only appear on the first argument");
793793
HasSret = true;
794-
ArgTys.push_back(dyn_cast<StructType>(Ty));
794+
ArgTys.push_back(Ty);
795795
} else {
796796
ArgTys.push_back(nullptr);
797797
}
@@ -858,7 +858,7 @@ void getParameterTypes(Function *F, SmallVectorImpl<StructType *> &ArgTys) {
858858
}
859859

860860
for (StringRef Arg : ArgParams) {
861-
StructType *Pointee = nullptr;
861+
Type *Pointee = nullptr;
862862
if (Arg.endswith("*") && !Arg.endswith("**")) {
863863
// Strip off address space and other qualifiers.
864864
StringRef MangledStructName = Arg.split(' ').first;
@@ -889,6 +889,13 @@ void getParameterTypes(Function *F, SmallVectorImpl<StructType *> &ArgTys) {
889889
free(Buf);
890890
}
891891

892+
void getParameterTypes(Function *F, SmallVectorImpl<StructType *> &ArgTys) {
893+
SmallVector<Type *, 8> Tys;
894+
getParameterTypes(F, Tys);
895+
for (Type *T : Tys)
896+
ArgTys.push_back(dyn_cast_or_null<StructType>(T));
897+
}
898+
892899
CallInst *mutateCallInst(
893900
Module *M, CallInst *CI,
894901
std::function<std::string(CallInst *, std::vector<Value *> &)> ArgMutate,

0 commit comments

Comments
 (0)