diff --git a/clang/lib/CodeGen/CGCall.cpp b/clang/lib/CodeGen/CGCall.cpp
index b959982809911..8eb3afee08ff5 100644
--- a/clang/lib/CodeGen/CGCall.cpp
+++ b/clang/lib/CodeGen/CGCall.cpp
@@ -4932,6 +4932,26 @@ void CodeGenFunction::EmitCallArg(CallArgList &args, const Expr *E,
     return;
   }
 
+  // For WebAssembly target we need to create thunk functions
+  // to properly handle function pointers args with a different signature.
+  // Due to opaque pointers, this can not be handled in LLVM
+  // (WebAssemblyFixFunctionBitcast) anymore
+  if (CGM.getTriple().isWasm() && type->isFunctionPointerType()) {
+    if (const DeclRefExpr *DRE =
+            CGM.getTargetCodeGenInfo().getWasmFunctionDeclRefExpr(
+                E, CGM.getContext())) {
+      llvm::Value *V = EmitLValue(DRE).getPointer(*this);
+      llvm::Function *Thunk =
+          CGM.getTargetCodeGenInfo().getOrCreateWasmFunctionPointerThunk(
+              CGM, V, DRE->getDecl()->getType(), type);
+      if (Thunk) {
+        RValue R = RValue::get(Thunk);
+        args.add(R, type);
+        return;
+      }
+    }
+  }
+
   args.add(EmitAnyExprToTemp(E), type);
 }
 
diff --git a/clang/lib/CodeGen/CGExprConstant.cpp b/clang/lib/CodeGen/CGExprConstant.cpp
index a96c1518d2a1d..956f780ad71e3 100644
--- a/clang/lib/CodeGen/CGExprConstant.cpp
+++ b/clang/lib/CodeGen/CGExprConstant.cpp
@@ -2243,6 +2243,19 @@ ConstantLValueEmitter::tryEmitBase(const APValue::LValueBase &base) {
 
     if (const auto *FD = dyn_cast<FunctionDecl>(D)) {
       llvm::Constant *C = CGM.getRawFunctionPointer(FD);
+      // For WebAssembly target we need to create thunk functions
+      // to properly handle function pointers args with a different signature.
+      // Due to opaque pointers, this can not be handled in LLVM
+      // (WebAssemblyFixFunctionBitcast) anymore
+      if (CGM.getTriple().isWasm() && DestType->isFunctionPointerType()) {
+        llvm::Function *Thunk =
+            CGM.getTargetCodeGenInfo().getOrCreateWasmFunctionPointerThunk(
+                CGM, C, D->getType(), DestType);
+        if (Thunk) {
+          C = Thunk;
+        }
+      }
+
       if (FD->getType()->isCFIUncheckedCalleeFunctionType())
         C = llvm::NoCFIValue::get(cast<llvm::GlobalValue>(C));
       return PtrAuthSign(C);
diff --git a/clang/lib/CodeGen/TargetInfo.h b/clang/lib/CodeGen/TargetInfo.h
index d0edae1295094..962178b2f549c 100644
--- a/clang/lib/CodeGen/TargetInfo.h
+++ b/clang/lib/CodeGen/TargetInfo.h
@@ -421,6 +421,23 @@ class TargetCodeGenInfo {
   /// Return the WebAssembly funcref reference type.
   virtual llvm::Type *getWasmFuncrefReferenceType() const { return nullptr; }
 
+  /// Return the DeclRefExpr naming the function that the function-pointer
+  /// expression \p E refers to, or null if it cannot be determined.
+  /// The default implementation returns null.
+  virtual const DeclRefExpr *getWasmFunctionDeclRefExpr(const Expr *E,
+                                                        ASTContext &Ctx) const {
+    return nullptr;
+  }
+
+  /// Return a thunk with the signature of \p DstType that forwards to
+  /// \p OriginalFnPtr (of type \p SrcType), or null if no thunk is created.
+  /// The default implementation returns null.
+  virtual llvm::Function *getOrCreateWasmFunctionPointerThunk(
+      CodeGenModule &CGM, llvm::Value *OriginalFnPtr, QualType SrcType,
+      QualType DstType) const {
+    return nullptr;
+  }
+
   /// Emit the device-side copy of the builtin surface type.
   virtual bool emitCUDADeviceBuiltinSurfaceDeviceCopy(CodeGenFunction &CGF,
                                                       LValue Dst,
diff --git a/clang/lib/CodeGen/Targets/WebAssembly.cpp b/clang/lib/CodeGen/Targets/WebAssembly.cpp
index ac8dcd2a0540a..4aa62948c4685 100644
--- a/clang/lib/CodeGen/Targets/WebAssembly.cpp
+++ b/clang/lib/CodeGen/Targets/WebAssembly.cpp
@@ -8,10 +8,15 @@
 
 #include "ABIInfoImpl.h"
 #include "TargetInfo.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringMap.h"
+#include "llvm/Support/Debug.h"
 
 using namespace clang;
 using namespace clang::CodeGen;
 
+#define DEBUG_TYPE "clang-target-wasm"
+
 //===----------------------------------------------------------------------===//
 // WebAssembly ABI Implementation
 //
@@ -93,6 +98,162 @@ class WebAssemblyTargetCodeGenInfo final : public TargetCodeGenInfo {
   virtual llvm::Type *getWasmFuncrefReferenceType() const override {
     return llvm::Type::getWasm_FuncrefTy(getABIInfo().getVMContext());
   }
+
+  /// Return the DeclRefExpr naming the function that the function-pointer
+  /// expression \p E designates, or null if \p E is anything other than a
+  /// (possibly parenthesized, cast, or address-of) reference to a function.
+  ///
+  /// Only direct references are recognized on purpose: EmitCallArg emits the
+  /// thunk instead of evaluating \p E, so \p E must not have side effects or
+  /// depend on control flow, which rules out looking through variables.
+  const DeclRefExpr *
+  getWasmFunctionDeclRefExpr(const Expr *E, ASTContext &Ctx) const override {
+    if (!E)
+      return nullptr;
+
+    const Expr *Inner = E->IgnoreParenCasts();
+    // `&f` designates the same function as `f`.
+    if (const auto *UO = dyn_cast<UnaryOperator>(Inner))
+      if (UO->getOpcode() == UO_AddrOf)
+        Inner = UO->getSubExpr()->IgnoreParenCasts();
+
+    const auto *DRE = dyn_cast<DeclRefExpr>(Inner);
+    if (DRE && isa<FunctionDecl>(DRE->getDecl()))
+      return DRE;
+    return nullptr;
+  }
+
+  /// Return an internal thunk with the signature of \p DstType that calls
+  /// \p OriginalFnPtr (declared with type \p SrcType) with its leading
+  /// arguments and ignores the remaining ones.
+  ///
+  /// Returns null, meaning "use the function pointer as is", unless
+  /// \p DstType takes more parameters than \p SrcType and the shared
+  /// arguments and the return value can be forwarded unchanged.
+  llvm::Function *getOrCreateWasmFunctionPointerThunk(
+      CodeGenModule &CGM, llvm::Value *OriginalFnPtr, QualType SrcType,
+      QualType DstType) const override {
+    // Both sides need a prototype: unprototyped C functions (`int f()`) and
+    // pointers to them have no parameter list to compare.
+    const auto *SrcProto = SrcType->getAs<FunctionProtoType>();
+    const auto *DstPtr = DstType->getAs<PointerType>();
+    const auto *DstProto =
+        DstPtr ? DstPtr->getPointeeType()->getAs<FunctionProtoType>()
+               : nullptr;
+    if (!SrcProto || !DstProto || !OriginalFnPtr->hasName())
+      return nullptr;
+
+    // This should only work for different number of arguments. Variadic
+    // callees may read the arguments a thunk would drop, so skip them.
+    const unsigned NumSrcParams = SrcProto->getNumParams();
+    if (DstProto->getNumParams() <= NumSrcParams || SrcProto->isVariadic() ||
+        DstProto->isVariadic())
+      return nullptr;
+
+    // Get the llvm function types. dyn_cast rather than cast so that an
+    // unexpected lowering bails out instead of asserting; the parameter
+    // counts are compared because the thunk indexes IR arguments below.
+    auto *DstFnTy = llvm::dyn_cast<llvm::FunctionType>(
+        CGM.getTypes().ConvertType(QualType(DstProto, 0)));
+    auto *SrcFnTy = llvm::dyn_cast<llvm::FunctionType>(
+        CGM.getTypes().ConvertType(QualType(SrcProto, 0)));
+    if (!DstFnTy || !SrcFnTy || SrcFnTy->getNumParams() != NumSrcParams ||
+        DstFnTy->getNumParams() != DstProto->getNumParams())
+      return nullptr;
+
+    // The thunk forwards values untouched, so the IR types have to agree;
+    // otherwise the call or return below would be ill-typed.
+    llvm::Type *ThunkRetTy = DstFnTy->getReturnType();
+    if (!ThunkRetTy->isVoidTy() && ThunkRetTy != SrcFnTy->getReturnType())
+      return nullptr;
+    for (unsigned I = 0; I != NumSrcParams; ++I)
+      if (SrcFnTy->getParamType(I) != DstFnTy->getParamType(I))
+        return nullptr;
+
+    // Construct the Thunk function with the Target (destination) signature.
+    // An empty name means the signature contains a type we cannot encode.
+    const std::string ThunkName = getThunkName(
+        OriginalFnPtr->getName().str(), DstProto, CGM.getContext());
+    if (ThunkName.empty())
+      return nullptr;
+
+    // Check if we already have a thunk for this function
+    if (auto It = ThunkCache.find(ThunkName); It != ThunkCache.end()) {
+      LLVM_DEBUG(llvm::dbgs() << "getOrCreateWasmFunctionPointerThunk: "
+                              << "found existing thunk for "
+                              << OriginalFnPtr->getName() << " as "
+                              << ThunkName << "\n");
+      return It->second;
+    }
+
+    // Create the thunk function and its only basic block
+    llvm::Module &M = CGM.getModule();
+    llvm::Function *Thunk = llvm::Function::Create(
+        DstFnTy, llvm::Function::InternalLinkage, ThunkName, M);
+    llvm::IRBuilder<> Builder(
+        llvm::BasicBlock::Create(M.getContext(), "entry", Thunk));
+
+    // Call the original function with the leading thunk arguments
+    llvm::SmallVector<llvm::Value *, 8> CallArgs;
+    for (unsigned I = 0; I != NumSrcParams; ++I)
+      CallArgs.push_back(Thunk->getArg(I));
+    llvm::CallInst *Call =
+        Builder.CreateCall(SrcFnTy, OriginalFnPtr, CallArgs);
+
+    // Handle return type
+    if (ThunkRetTy->isVoidTy())
+      Builder.CreateRetVoid();
+    else
+      Builder.CreateRet(Call);
+
+    LLVM_DEBUG(llvm::dbgs() << "getOrCreateWasmFunctionPointerThunk:"
+                            << " from " << OriginalFnPtr->getName() << " to "
+                            << ThunkName << "\n");
+    // Cache the thunk
+    ThunkCache[ThunkName] = Thunk;
+    return Thunk;
+  }
+
+private:
+  // Thunks created so far, keyed by thunk name
+  mutable llvm::StringMap<llvm::Function *> ThunkCache;
+
+  /// Return the character that represents \p Ty in a wasm signature, or
+  /// '\0' if \p Ty is not one of the types handled below.
+  /// See getInvokeSig() in WebAssemblyAsmPrinter for related logic.
+  char getTypeSig(const QualType &Ty, const ASTContext &Ctx) const {
+    if (Ty->isAnyPointerType())
+      return Ctx.getTypeSize(Ctx.VoidPtrTy) == 32 ? 'i' : 'j';
+    if (Ty->isIntegerType() && Ctx.getTypeSize(Ty) <= 64)
+      return Ctx.getTypeSize(Ty) <= 32 ? 'i' : 'j';
+    if (Ty->isRealFloatingType() && Ctx.getTypeSize(Ty) <= 64)
+      return Ctx.getTypeSize(Ty) <= 32 ? 'f' : 'd';
+    if (Ty->isVectorType())
+      return 'V';
+    if (Ty->isWebAssemblyTableType())
+      return 'F';
+    if (Ty->isWebAssemblyExternrefType())
+      return 'X';
+    // Records, complex numbers, scalars wider than 64 bits, ...
+    return '\0';
+  }
+
+  /// Build the thunk name "__{OrigName}_{WasmSig}", where WasmSig is the
+  /// return type ('v' for void) followed by the parameter types of
+  /// \p DstProto. Returns an empty string if a type cannot be encoded.
+  std::string getThunkName(std::string OrigName,
+                           const FunctionProtoType *DstProto,
+                           const ASTContext &Ctx) const {
+    std::string Sig;
+    QualType RetTy = DstProto->getReturnType();
+    Sig += RetTy->isVoidType() ? 'v' : getTypeSig(RetTy, Ctx);
+    for (QualType ParamTy : DstProto->param_types())
+      Sig += getTypeSig(ParamTy, Ctx);
+    // getTypeSig reports a type it cannot encode as '\0'.
+    if (Sig.find('\0') != std::string::npos)
+      return std::string();
+    return "__" + OrigName + "_" + Sig;
+  }
 };
 
 /// Classify argument of given type \p Ty.
diff --git a/clang/test/CodeGenWebAssembly/function-pointer-arg.c b/clang/test/CodeGenWebAssembly/function-pointer-arg.c
new file mode 100644
index 0000000000000..ff7b4186bbf7b
--- /dev/null
+++ b/clang/test/CodeGenWebAssembly/function-pointer-arg.c
@@ -0,0 +1,25 @@
+// REQUIRES: webassembly-registered-target
+// RUN: %clang_cc1 -triple wasm32-unknown-unknown -emit-llvm -O0 -o - %s | FileCheck %s
+
+// Test of function pointer bitcast in a function argument with different argument number in wasm32
+
+#define FUNCTION_POINTER(f) ((FunctionPointer)(f))
+typedef int (*FunctionPointer)(int a, int b);
+
+int fp_as_arg(FunctionPointer fp, int a, int b) {
+  return fp(a, b);
+}
+
+int fp_less(int a) {
+  return a;
+}
+
+// CHECK-LABEL: @test
+// CHECK: call i32 @fp_as_arg(ptr noundef @__fp_less_iii, i32 noundef 10, i32 noundef 20)
+void test() {
+  fp_as_arg(FUNCTION_POINTER(fp_less), 10, 20);
+}
+
+// CHECK: define internal i32 @__fp_less_iii(i32 %0, i32 %1)
+// CHECK: %2 = call i32 @fp_less(i32 %0)
+// CHECK: ret i32 %2
diff --git a/clang/test/CodeGenWebAssembly/function-pointer-field.c b/clang/test/CodeGenWebAssembly/function-pointer-field.c
new file mode 100644
index 0000000000000..103a265ebf5fb
--- /dev/null
+++ b/clang/test/CodeGenWebAssembly/function-pointer-field.c
@@ -0,0 +1,30 @@
+// REQUIRES: webassembly-registered-target
+// RUN: %clang_cc1 -triple wasm32-unknown-unknown -emit-llvm -O0 -o - %s | FileCheck %s
+
+// Test of function pointer bitcast in a struct field with different argument number in wasm32
+
+#define FUNCTION_POINTER(f) ((FunctionPointer)(f))
+typedef int (*FunctionPointer)(int a, int b);
+
+// CHECK: @__const.test.sfp = private unnamed_addr constant %struct._StructWithFunctionPointer { ptr @__fp_less_iii }, align 4
+
+typedef struct _StructWithFunctionPointer {
+  FunctionPointer fp;
+} StructWithFunctionPointer;
+
+int fp_less(int a) {
+  return a;
+}
+
+// CHECK-LABEL: @test
+void test() {
+  StructWithFunctionPointer sfp = {
+    FUNCTION_POINTER(fp_less)
+  };
+
+  int a1 = sfp.fp(10, 20);
+}
+
+// CHECK: define internal i32 @__fp_less_iii(i32 %0, i32 %1)
+// CHECK: %2 = call i32 @fp_less(i32 %0)
+// CHECK: ret i32 %2