[clang][WebAssembly] Handle casted function pointers with different number of arguments #153168
@llvm/pr-subscribers-clang-codegen @llvm/pr-subscribers-backend-webassembly
Author: Jorge Zapata (turran)
Changes
This overcomes a limitation of the LLVM WebAssemblyFixFunctionBitcasts pass that appeared with the introduction of opaque pointers. Function pointers no longer carry a prototype in LLVM IR, so the fix has to be made in Clang. This solves many open issues, like:
Full diff: https://github.com/llvm/llvm-project/pull/153168.diff 6 Files Affected:
diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp
index b959982809911..8eb3afee08ff5 100644
--- a/clang/lib/CodeGen/CGCall.cpp
+++ b/clang/lib/CodeGen/CGCall.cpp
@@ -4932,6 +4932,26 @@ void CodeGenFunction::EmitCallArg(CallArgList &args, const Expr *E,
return;
}
+ // For WebAssembly target we need to create thunk functions
+ // to properly handle function pointers args with a different signature.
+ // Due to opaque pointers, this can not be handled in LLVM
+ // (WebAssemblyFixFunctionBitcast) anymore
+ if (CGM.getTriple().isWasm() && type->isFunctionPointerType()) {
+ if (const DeclRefExpr *DRE =
+ CGM.getTargetCodeGenInfo().getWasmFunctionDeclRefExpr(
+ E, CGM.getContext())) {
+ llvm::Value *V = EmitLValue(DRE).getPointer(*this);
+ llvm::Function *Thunk =
+ CGM.getTargetCodeGenInfo().getOrCreateWasmFunctionPointerThunk(
+ CGM, V, DRE->getDecl()->getType(), type);
+ if (Thunk) {
+ RValue R = RValue::get(Thunk);
+ args.add(R, type);
+ return;
+ }
+ }
+ }
+
args.add(EmitAnyExprToTemp(E), type);
}
diff --git a/clang/lib/CodeGen/CGExprConstant.cpp b/clang/lib/CodeGen/CGExprConstant.cpp
index a96c1518d2a1d..956f780ad71e3 100644
--- a/clang/lib/CodeGen/CGExprConstant.cpp
+++ b/clang/lib/CodeGen/CGExprConstant.cpp
@@ -2243,6 +2243,19 @@ ConstantLValueEmitter::tryEmitBase(const APValue::LValueBase &base) {
if (const auto *FD = dyn_cast<FunctionDecl>(D)) {
llvm::Constant *C = CGM.getRawFunctionPointer(FD);
+ // For WebAssembly target we need to create thunk functions
+ // to properly handle function pointers args with a different signature
+ // Due to opaque pointers, this can not be handled in LLVM
+ // (WebAssemblyFixFunctionBitcast) anymore
+ if (CGM.getTriple().isWasm() && DestType->isFunctionPointerType()) {
+ llvm::Function *Thunk =
+ CGM.getTargetCodeGenInfo().getOrCreateWasmFunctionPointerThunk(
+ CGM, C, D->getType(), DestType);
+ if (Thunk) {
+ C = Thunk;
+ }
+ }
+
if (FD->getType()->isCFIUncheckedCalleeFunctionType())
C = llvm::NoCFIValue::get(cast<llvm::GlobalValue>(C));
return PtrAuthSign(C);
diff --git a/clang/lib/CodeGen/TargetInfo.h b/clang/lib/CodeGen/TargetInfo.h
index d0edae1295094..962178b2f549c 100644
--- a/clang/lib/CodeGen/TargetInfo.h
+++ b/clang/lib/CodeGen/TargetInfo.h
@@ -421,6 +421,17 @@ class TargetCodeGenInfo {
/// Return the WebAssembly funcref reference type.
virtual llvm::Type *getWasmFuncrefReferenceType() const { return nullptr; }
+ virtual const DeclRefExpr *getWasmFunctionDeclRefExpr(const Expr *E,
+ ASTContext &Ctx) const {
+ return nullptr;
+ }
+
+ virtual llvm::Function *getOrCreateWasmFunctionPointerThunk(
+ CodeGenModule &CGM, llvm::Value *OriginalFnPtr, QualType SrcType,
+ QualType DstType) const {
+ return nullptr;
+ }
+
/// Emit the device-side copy of the builtin surface type.
virtual bool emitCUDADeviceBuiltinSurfaceDeviceCopy(CodeGenFunction &CGF,
LValue Dst,
diff --git a/clang/lib/CodeGen/Targets/WebAssembly.cpp b/clang/lib/CodeGen/Targets/WebAssembly.cpp
index ac8dcd2a0540a..4aa62948c4685 100644
--- a/clang/lib/CodeGen/Targets/WebAssembly.cpp
+++ b/clang/lib/CodeGen/Targets/WebAssembly.cpp
@@ -8,10 +8,15 @@
#include "ABIInfoImpl.h"
#include "TargetInfo.h"
+#include "llvm/ADT/StringMap.h"
+
+#include "clang/AST/ParentMapContext.h"
using namespace clang;
using namespace clang::CodeGen;
+#define DEBUG_TYPE "clang-target-wasm"
+
//===----------------------------------------------------------------------===//
// WebAssembly ABI Implementation
//
@@ -52,6 +57,7 @@ class WebAssemblyTargetCodeGenInfo final : public TargetCodeGenInfo {
: TargetCodeGenInfo(std::make_unique<WebAssemblyABIInfo>(CGT, K)) {
SwiftInfo =
std::make_unique<SwiftABIInfo>(CGT, /*SwiftErrorInRegister=*/false);
+ ThunkCache = llvm::StringMap<llvm::Function *>();
}
void setTargetAttributes(const Decl *D, llvm::GlobalValue *GV,
@@ -93,6 +99,127 @@ class WebAssemblyTargetCodeGenInfo final : public TargetCodeGenInfo {
virtual llvm::Type *getWasmFuncrefReferenceType() const override {
return llvm::Type::getWasm_FuncrefTy(getABIInfo().getVMContext());
}
+
+ virtual const DeclRefExpr *
+ getWasmFunctionDeclRefExpr(const Expr *E, ASTContext &Ctx) const override {
+ // Go down in the tree until finding the DeclRefExpr
+ const DeclRefExpr *DRE = findDeclRefExpr(E);
+ if (!DRE)
+ return nullptr;
+
+ // Final case. The argument is a declared function
+ if (isa<FunctionDecl>(DRE->getDecl())) {
+ return DRE;
+ }
+
+ // Complex case. The argument is a variable, we need to check
+ // every assignment of the variable and see if we are bitcasting
+ // or not.
+ if (const auto *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
+ DRE = findDeclRefExprForVarUp(E, VD, Ctx);
+ if (DRE)
+ return DRE;
+
+ // If no assignment exists on every parent scope, check for the
+ // initialization
+ if (!DRE && VD->hasInit()) {
+ return getWasmFunctionDeclRefExpr(VD->getInit(), Ctx);
+ }
+ }
+
+ return nullptr;
+ }
+
+ virtual llvm::Function *getOrCreateWasmFunctionPointerThunk(
+ CodeGenModule &CGM, llvm::Value *OriginalFnPtr, QualType SrcType,
+ QualType DstType) const override {
+
+ // Get the signatures
+ const FunctionProtoType *SrcProtoType = SrcType->getAs<FunctionProtoType>();
+ const FunctionProtoType *DstProtoType = DstType->getAs<PointerType>()
+ ->getPointeeType()
+ ->getAs<FunctionProtoType>();
+
+ // This should only work for different number of arguments
+ if (DstProtoType->getNumParams() <= SrcProtoType->getNumParams())
+ return nullptr;
+
+ // Get the llvm function types
+ llvm::FunctionType *DstFunctionType = llvm::cast<llvm::FunctionType>(
+ CGM.getTypes().ConvertType(QualType(DstProtoType, 0)));
+ llvm::FunctionType *SrcFunctionType = llvm::cast<llvm::FunctionType>(
+ CGM.getTypes().ConvertType(QualType(SrcProtoType, 0)));
+
+ // Construct the Thunk function with the Target (destination) signature
+ std::string ThunkName = getThunkName(OriginalFnPtr->getName().str(),
+ DstProtoType, CGM.getContext());
+ // Check if we already have a thunk for this function
+ if (auto It = ThunkCache.find(ThunkName); It != ThunkCache.end()) {
+ LLVM_DEBUG(llvm::dbgs() << "getOrCreateWasmFunctionPointerThunk: "
+ << "found existing thunk for "
+ << OriginalFnPtr->getName().str() << " as "
+ << ThunkName << "\n");
+ return It->second;
+ }
+
+ // Create the thunk function
+ llvm::Module &M = CGM.getModule();
+ llvm::Function *Thunk = llvm::Function::Create(
+ DstFunctionType, llvm::Function::InternalLinkage, ThunkName, M);
+
+ // Build the thunk body
+ llvm::IRBuilder<> Builder(
+ llvm::BasicBlock::Create(M.getContext(), "entry", Thunk));
+
+ // Gather the arguments for calling the original function
+ std::vector<llvm::Value *> CallArgs;
+ unsigned CallN = SrcProtoType->getNumParams();
+
+ auto ArgIt = Thunk->arg_begin();
+ for (unsigned i = 0; i < CallN && ArgIt != Thunk->arg_end(); ++i, ++ArgIt) {
+ llvm::Value *A = &*ArgIt;
+ CallArgs.push_back(A);
+ }
+
+ // Create the call to the original function pointer
+ llvm::CallInst *Call =
+ Builder.CreateCall(SrcFunctionType, OriginalFnPtr, CallArgs);
+
+ // Handle return type
+ llvm::Type *ThunkRetTy = DstFunctionType->getReturnType();
+
+ if (ThunkRetTy->isVoidTy()) {
+ Builder.CreateRetVoid();
+ } else {
+ llvm::Value *Ret = Call;
+ if (Ret->getType() != ThunkRetTy)
+ Ret = Builder.CreateBitCast(Ret, ThunkRetTy);
+ Builder.CreateRet(Ret);
+ }
+ LLVM_DEBUG(llvm::dbgs() << "getOrCreateWasmFunctionPointerThunk:"
+ << " from " << OriginalFnPtr->getName().str()
+ << " to " << ThunkName << "\n");
+ // Cache the thunk
+ ThunkCache[ThunkName] = Thunk;
+ return Thunk;
+ }
+
+private:
+ // The thunk cache
+ mutable llvm::StringMap<llvm::Function *> ThunkCache;
+ // Build the thunk name: "__{OrigName}_{WasmSig}"
+ std::string getThunkName(std::string OrigName,
+ const FunctionProtoType *DstProto,
+ const ASTContext &Ctx) const;
+ char getTypeSig(const QualType &Ty, const ASTContext &Ctx) const;
+ std::string sanitizeTypeString(const std::string &typeStr) const;
+ std::string getTypeName(const QualType &qt, const ASTContext &Ctx) const;
+ const DeclRefExpr *findDeclRefExpr(const Expr *E) const;
+ const DeclRefExpr *findDeclRefExprForVarDown(const Stmt *Parent,
+ const VarDecl *V,
+ ASTContext &Ctx) const;
+ const DeclRefExpr *findDeclRefExprForVarUp(const Expr *E, const VarDecl *V,
+ ASTContext &Ctx) const;
};
/// Classify argument of given type \p Ty.
@@ -173,3 +300,120 @@ CodeGen::createWebAssemblyTargetCodeGenInfo(CodeGenModule &CGM,
WebAssemblyABIKind K) {
return std::make_unique<WebAssemblyTargetCodeGenInfo>(CGM.getTypes(), K);
}
+
+// Helper to get the type signature character for a given QualType
+// Returns a character that represents the given QualType in a wasm signature.
+// See getInvokeSig() in WebAssemblyAsmPrinter for related logic.
+char WebAssemblyTargetCodeGenInfo::getTypeSig(const QualType &Ty,
+ const ASTContext &Ctx) const {
+ if (Ty->isAnyPointerType()) {
+ return Ctx.getTypeSize(Ctx.VoidPtrTy) == 32 ? 'i' : 'j';
+ }
+ if (Ty->isIntegerType()) {
+ return Ctx.getTypeSize(Ty) <= 32 ? 'i' : 'j';
+ }
+ if (Ty->isFloatingType()) {
+ return Ctx.getTypeSize(Ty) <= 32 ? 'f' : 'd';
+ }
+ if (Ty->isVectorType()) {
+ return 'V';
+ }
+ if (Ty->isWebAssemblyTableType()) {
+ return 'F';
+ }
+ if (Ty->isWebAssemblyExternrefType()) {
+ return 'X';
+ }
+
+ llvm_unreachable("Unhandled QualType");
+}
+
+std::string
+WebAssemblyTargetCodeGenInfo::getThunkName(std::string OrigName,
+ const FunctionProtoType *DstProto,
+ const ASTContext &Ctx) const {
+
+ std::string ThunkName = "__" + OrigName + "_";
+ QualType RetTy = DstProto->getReturnType();
+ if (RetTy->isVoidType()) {
+ ThunkName += 'v';
+ } else {
+ ThunkName += getTypeSig(RetTy, Ctx);
+ }
+ for (unsigned i = 0; i < DstProto->getNumParams(); ++i) {
+ ThunkName += getTypeSig(DstProto->getParamType(i), Ctx);
+ }
+ return ThunkName;
+}
+
+/// Recursively find the first DeclRefExpr in an Expr subtree.
+/// Returns nullptr if not found.
+const DeclRefExpr *
+WebAssemblyTargetCodeGenInfo::findDeclRefExpr(const Expr *E) const {
+ if (!E)
+ return nullptr;
+
+ // In case it is a function call, abort
+ if (isa<CallExpr>(E))
+ return nullptr;
+
+ // If this node is a DeclRefExpr, return it.
+ if (const auto *DRE = dyn_cast<DeclRefExpr>(E))
+ return DRE;
+
+ // Otherwise, recurse into children.
+ for (const Stmt *Child : E->children()) {
+ if (const auto *ChildExpr = dyn_cast_or_null<Expr>(Child)) {
+ if (const DeclRefExpr *Found = findDeclRefExpr(ChildExpr))
+ return Found;
+ }
+ }
+ return nullptr;
+}
+
+const DeclRefExpr *WebAssemblyTargetCodeGenInfo::findDeclRefExprForVarDown(
+ const Stmt *Parent, const VarDecl *V, ASTContext &Ctx) const {
+ if (!Parent)
+ return nullptr;
+
+ // Find down every assignment of V
+ // FIXME we need to stop before the expression where V is used
+ const BinaryOperator *A = nullptr;
+ for (const Stmt *Child : Parent->children()) {
+ if (const auto *BO = dyn_cast_or_null<BinaryOperator>(Child)) {
+ if (!BO->isAssignmentOp())
+ continue;
+ auto *LHS = llvm::dyn_cast<DeclRefExpr>(BO->getLHS());
+ if (LHS && LHS->getDecl() == V) {
+ A = BO;
+ }
+ }
+ }
+
+ // We have an assignment of the Var, recurse in it
+ if (A) {
+ return getWasmFunctionDeclRefExpr(A->getRHS(), Ctx);
+ }
+
+ return nullptr;
+}
+
+const DeclRefExpr *WebAssemblyTargetCodeGenInfo::findDeclRefExprForVarUp(
+ const Expr *E, const VarDecl *V, ASTContext &Ctx) const {
+ const clang::Stmt *cur = E;
+ while (cur) {
+ auto parents = Ctx.getParentMapContext().getParents(*cur);
+ if (parents.empty())
+ break;
+ const clang::Stmt *parentStmt = parents[0].get<clang::Stmt>();
+ if (!parentStmt)
+ break;
+ if (const auto *CS = dyn_cast<clang::CompoundStmt>(parentStmt)) {
+ const DeclRefExpr *DRE = findDeclRefExprForVarDown(CS, V, Ctx);
+ if (DRE)
+ return DRE;
+ }
+ cur = parentStmt;
+ }
+ return nullptr;
+}
diff --git a/clang/test/CodeGenWebAssembly/function-pointer-arg.c b/clang/test/CodeGenWebAssembly/function-pointer-arg.c
new file mode 100644
index 0000000000000..ff7b4186bbf7b
--- /dev/null
+++ b/clang/test/CodeGenWebAssembly/function-pointer-arg.c
@@ -0,0 +1,25 @@
+// REQUIRES: webassembly-registered-target
+// RUN: %clang_cc1 -triple wasm32-unknown-unknown -emit-llvm -O0 -o - %s | FileCheck %s
+
+// Test of function pointer bitcast in a function argument with different argument number in wasm32
+
+#define FUNCTION_POINTER(f) ((FunctionPointer)(f))
+typedef int (*FunctionPointer)(int a, int b);
+
+int fp_as_arg(FunctionPointer fp, int a, int b) {
+ return fp(a, b);
+}
+
+int fp_less(int a) {
+ return a;
+}
+
+// CHECK-LABEL: @test
+// CHECK: call i32 @fp_as_arg(ptr noundef @__fp_less_iii, i32 noundef 10, i32 noundef 20)
+void test() {
+ fp_as_arg(FUNCTION_POINTER(fp_less), 10, 20);
+}
+
+// CHECK: define internal i32 @__fp_less_iii(i32 %0, i32 %1)
+// CHECK: %2 = call i32 @fp_less(i32 %0)
+// CHECK: ret i32 %2
\ No newline at end of file
diff --git a/clang/test/CodeGenWebAssembly/function-pointer-field.c b/clang/test/CodeGenWebAssembly/function-pointer-field.c
new file mode 100644
index 0000000000000..103a265ebf5fb
--- /dev/null
+++ b/clang/test/CodeGenWebAssembly/function-pointer-field.c
@@ -0,0 +1,30 @@
+// REQUIRES: webassembly-registered-target
+// RUN: %clang_cc1 -triple wasm32-unknown-unknown -emit-llvm -O0 -o - %s | FileCheck %s
+
+// Test of function pointer bitcast in a struct field with different argument number in wasm32
+
+#define FUNCTION_POINTER(f) ((FunctionPointer)(f))
+typedef int (*FunctionPointer)(int a, int b);
+
+// CHECK: @__const.test.sfp = private unnamed_addr constant %struct._StructWithFunctionPointer { ptr @__fp_less_iii }, align 4
+
+typedef struct _StructWithFunctionPointer {
+ FunctionPointer fp;
+} StructWithFunctionPointer;
+
+int fp_less(int a) {
+ return a;
+}
+
+// CHECK-LABEL: @test
+void test() {
+ StructWithFunctionPointer sfp = {
+ FUNCTION_POINTER(fp_less)
+ };
+
+ int a1 = sfp.fp(10, 20);
+}
+
+// CHECK: define internal i32 @__fp_less_iii(i32 %0, i32 %1)
+// CHECK: %2 = call i32 @fp_less(i32 %0)
+// CHECK: ret i32 %2
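For readers skimming the diff, the thunk that both tests check for can be approximated by hand. The sketch below is an illustration, not the code Clang emits: `fp_less_thunk_iii` is a hypothetical stand-in for the generated `__fp_less_iii` (renamed because identifiers with a leading double underscore are reserved), and the `iii` suffix mirrors the signature encoding in `getThunkName` (an i32 return followed by two i32 parameters).

```cpp
#include <cassert>

// Destination pointer type used by both tests: two int parameters.
using FunctionPointer = int (*)(int a, int b);

// The real callee takes a single parameter.
int fp_less(int a) { return a; }

// Hypothetical hand-written equivalent of the generated thunk. It has the
// destination signature and forwards only the leading arguments that the
// original function accepts; the extra argument is dropped.
int fp_less_thunk_iii(int a, int /*b*/) { return fp_less(a); }

// Same shape as fp_as_arg in function-pointer-arg.c: the call site always
// passes two arguments, which is what a WebAssembly indirect call requires
// the callee's signature to match.
int fp_as_arg(FunctionPointer fp, int a, int b) { return fp(a, b); }
```

Passing `fp_less_thunk_iii` to `fp_as_arg` needs no cast at all, which is the effect the patch aims for when it replaces the casted `fp_less` with the thunk.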
Thanks for working on this. If I understand correctly, this will silently paper over function argument-count mismatches in signatures? That might help some code patterns, but there is a potential issue of silent code size bloat if the developer did not intend to write a mismatching signature in the first place: if they did not notice the mismatch, Clang would silently generate thunks they never wanted.
In our use case of #64526, I still wish there was a diagnostics feature that could be used like
The C++ standard has exactly one rule for the result of function pointer casts: "converting a prvalue of type 'pointer to T1' to the type 'pointer to T2' (where T1 and T2 are function types) and back to its original type yields the original pointer value". This trivially violates that rule. (C has a similar rule.)
So we can't do this. At least, not by default.
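A small sketch of why the round-trip rule is violated, under the assumption that the cast is modelled by hand: `fp_less_thunk` is a hypothetical stand-in for the generated thunk, and assigning its address plays the role of the rewritten cast. The final comparison relies on distinct functions having distinct addresses, which holds in practice but is stated here as an assumption.

```cpp
#include <cassert>

using OneArg = int (*)(int);
using TwoArgs = int (*)(int, int);

int fp_less(int a) { return a; }

// Hypothetical stand-in for the compiler-generated thunk.
int fp_less_thunk(int a, int /*b*/) { return fp_less(a); }

// What the language guarantees: cast to another function pointer type and
// back, and you get the original pointer value.
bool round_trip_plain() {
  TwoArgs casted = reinterpret_cast<TwoArgs>(&fp_less);
  OneArg back = reinterpret_cast<OneArg>(casted);
  return back == &fp_less;
}

// Model of the patch: the first cast yields the thunk's address instead, so
// casting back no longer recovers the address of fp_less.
bool round_trip_with_thunk() {
  TwoArgs casted = &fp_less_thunk; // models the rewritten cast
  OneArg back = reinterpret_cast<OneArg>(casted);
  return back == &fp_less;
}
```

Code that casts a callback to a common pointer type for storage and later casts it back before calling or comparing it would observe this difference.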
Thanks @juj
I see. IIRC, Clang already has a general (not wasm-specific) flag to detect bad signature casts; I saw it while investigating the code. But maybe it misses some specific use case?
Thanks @efriedma-quic
I see. I wasn't aware of this rule. What would be the preferred way to handle this? With a specific flag, I guess?
Maybe a flag. But I'd prefer to explore the design space first. What patterns do we care about? Can we detect all the relevant cases at compile-time? Can we generalize this by performing some checks at runtime?
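To make the runtime-check question concrete, here is one possible shape, written as ordinary user-level code rather than anything Clang or the WebAssembly backend does today. Everything in it is hypothetical (`TaggedFn`, `tag1`, `tag2`, `call_with_two`): the pointer travels with its arity, and the caller dispatches on that tag instead of relying on a cast.

```cpp
#include <cassert>

// Generic function pointer type used only for storage; it is always cast
// back to the original type before being called.
using GenericFn = void (*)();

// Hypothetical tagged pointer: the function address plus its real arity.
struct TaggedFn {
  GenericFn raw;
  unsigned arity;
};

int fp_less(int a) { return a; }
int add(int a, int b) { return a + b; }

TaggedFn tag1(int (*f)(int)) { return {reinterpret_cast<GenericFn>(f), 1}; }
TaggedFn tag2(int (*f)(int, int)) { return {reinterpret_cast<GenericFn>(f), 2}; }

// Calls fn with up to two arguments, checking the arity at runtime.
// Returns false (and leaves out untouched) for an unsupported arity.
bool call_with_two(const TaggedFn &fn, int a, int b, int &out) {
  switch (fn.arity) {
  case 1:
    out = reinterpret_cast<int (*)(int)>(fn.raw)(a);
    return true;
  case 2:
    out = reinterpret_cast<int (*)(int, int)>(fn.raw)(a, b);
    return true;
  default:
    return false;
  }
}
```

Because each pointer is cast back to its original type before the call, the round-trip rule is respected; the cost is an extra word per pointer and a branch per call.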
What I would like is a new flag exactly like the current I.e. currently Like was pointed out, of course detecting all scenarios that result in the Wasm VM trapping at runtime is not possible (e.g. casting
In our case, I am interested in getting warnings emitted. So the usual In this PR, the idea of transparently supporting the cast via a thunk was brought up; that might be something to consider under a
I added a couple of tests that represent the main use cases this PR addresses: a cast in the initializer of a function-pointer struct field, and a cast in a call argument. There may be other relevant cases, but for my own scenario these are enough.
Okay, I've re-read the issue you mentioned. Indeed, it makes sense to provide a diagnostic for such use cases. Maybe as part of another PR? This PR is aimed at a different number of arguments rather than different argument types.