Skip to content

[clang][WebAssembly] Handle casted function pointers with different number of arguments #153168

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions clang/lib/CodeGen/CGCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4932,6 +4932,26 @@ void CodeGenFunction::EmitCallArg(CallArgList &args, const Expr *E,
return;
}

// For WebAssembly target we need to create thunk functions
// to properly handle function pointers args with a different signature.
// Due to opaque pointers, this can not be handled in LLVM
// (WebAssemblyFixFunctionBitcast) anymore
if (CGM.getTriple().isWasm() && type->isFunctionPointerType()) {
if (const DeclRefExpr *DRE =
CGM.getTargetCodeGenInfo().getWasmFunctionDeclRefExpr(
E, CGM.getContext())) {
llvm::Value *V = EmitLValue(DRE).getPointer(*this);
llvm::Function *Thunk =
CGM.getTargetCodeGenInfo().getOrCreateWasmFunctionPointerThunk(
CGM, V, DRE->getDecl()->getType(), type);
if (Thunk) {
RValue R = RValue::get(Thunk);
args.add(R, type);
return;
}
}
}

args.add(EmitAnyExprToTemp(E), type);
}

Expand Down
13 changes: 13 additions & 0 deletions clang/lib/CodeGen/CGExprConstant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2243,6 +2243,19 @@ ConstantLValueEmitter::tryEmitBase(const APValue::LValueBase &base) {

if (const auto *FD = dyn_cast<FunctionDecl>(D)) {
llvm::Constant *C = CGM.getRawFunctionPointer(FD);
// ForWebAssembly target we need to create thunk functions
// to properly handle function pointers args with a different signature
// Due to opaque pointers, this can not be handled in LLVM
// (WebAssemblyFixFunctionBitcast) anymore
if (CGM.getTriple().isWasm() && DestType->isFunctionPointerType()) {
llvm::Function *Thunk =
CGM.getTargetCodeGenInfo().getOrCreateWasmFunctionPointerThunk(
CGM, C, D->getType(), DestType);
if (Thunk) {
C = Thunk;
}
}

if (FD->getType()->isCFIUncheckedCalleeFunctionType())
C = llvm::NoCFIValue::get(cast<llvm::GlobalValue>(C));
return PtrAuthSign(C);
Expand Down
11 changes: 11 additions & 0 deletions clang/lib/CodeGen/TargetInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,17 @@ class TargetCodeGenInfo {
/// Return the WebAssembly funcref reference type.
virtual llvm::Type *getWasmFuncrefReferenceType() const { return nullptr; }

virtual const DeclRefExpr *getWasmFunctionDeclRefExpr(const Expr *E,
ASTContext &Ctx) const {
return nullptr;
}

virtual llvm::Function *getOrCreateWasmFunctionPointerThunk(
CodeGenModule &CGM, llvm::Value *OriginalFnPtr, QualType SrcType,
QualType DstType) const {
return nullptr;
}

/// Emit the device-side copy of the builtin surface type.
virtual bool emitCUDADeviceBuiltinSurfaceDeviceCopy(CodeGenFunction &CGF,
LValue Dst,
Expand Down
244 changes: 244 additions & 0 deletions clang/lib/CodeGen/Targets/WebAssembly.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,15 @@

#include "ABIInfoImpl.h"
#include "TargetInfo.h"
#include "llvm/ADT/StringMap.h"

#include "clang/AST/ParentMapContext.h"

using namespace clang;
using namespace clang::CodeGen;

#define DEBUG_TYPE "clang-target-wasm"

//===----------------------------------------------------------------------===//
// WebAssembly ABI Implementation
//
Expand Down Expand Up @@ -52,6 +57,7 @@ class WebAssemblyTargetCodeGenInfo final : public TargetCodeGenInfo {
: TargetCodeGenInfo(std::make_unique<WebAssemblyABIInfo>(CGT, K)) {
SwiftInfo =
std::make_unique<SwiftABIInfo>(CGT, /*SwiftErrorInRegister=*/false);
ThunkCache = llvm::StringMap<llvm::Function *>();
}

void setTargetAttributes(const Decl *D, llvm::GlobalValue *GV,
Expand Down Expand Up @@ -93,6 +99,127 @@ class WebAssemblyTargetCodeGenInfo final : public TargetCodeGenInfo {
virtual llvm::Type *getWasmFuncrefReferenceType() const override {
return llvm::Type::getWasm_FuncrefTy(getABIInfo().getVMContext());
}

virtual const DeclRefExpr *
getWasmFunctionDeclRefExpr(const Expr *E, ASTContext &Ctx) const override {
// Go down in the tree until finding the DeclRefExpr
const DeclRefExpr *DRE = findDeclRefExpr(E);
if (!DRE)
return nullptr;

// Final case. The argument is a declared function
if (isa<FunctionDecl>(DRE->getDecl())) {
return DRE;
}

// Complex case. The argument is a variable, we need to check
// every assignment of the variable and see if we are bitcasting
// or not.
if (const auto *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
DRE = findDeclRefExprForVarUp(E, VD, Ctx);
if (DRE)
return DRE;

// If no assignment exists on every parent scope, check for the
// initialization
if (!DRE && VD->hasInit()) {
return getWasmFunctionDeclRefExpr(VD->getInit(), Ctx);
}
}

return nullptr;
}

virtual llvm::Function *getOrCreateWasmFunctionPointerThunk(
CodeGenModule &CGM, llvm::Value *OriginalFnPtr, QualType SrcType,
QualType DstType) const override {

// Get the signatures
const FunctionProtoType *SrcProtoType = SrcType->getAs<FunctionProtoType>();
const FunctionProtoType *DstProtoType = DstType->getAs<PointerType>()
->getPointeeType()
->getAs<FunctionProtoType>();

// This should only work for different number of arguments
if (DstProtoType->getNumParams() <= SrcProtoType->getNumParams())
return nullptr;

// Get the llvm function types
llvm::FunctionType *DstFunctionType = llvm::cast<llvm::FunctionType>(
CGM.getTypes().ConvertType(QualType(DstProtoType, 0)));
llvm::FunctionType *SrcFunctionType = llvm::cast<llvm::FunctionType>(
CGM.getTypes().ConvertType(QualType(SrcProtoType, 0)));

// Construct the Thunk function with the Target (destination) signature
std::string ThunkName = getThunkName(OriginalFnPtr->getName().str(),
DstProtoType, CGM.getContext());
// Check if we already have a thunk for this function
if (auto It = ThunkCache.find(ThunkName); It != ThunkCache.end()) {
LLVM_DEBUG(llvm::dbgs() << "getOrCreateWasmFunctionPointerThunk: "
<< "found existing thunk for "
<< OriginalFnPtr->getName().str() << " as "
<< ThunkName << "\n");
return It->second;
}

// Create the thunk function
llvm::Module &M = CGM.getModule();
llvm::Function *Thunk = llvm::Function::Create(
DstFunctionType, llvm::Function::InternalLinkage, ThunkName, M);

// Build the thunk body
llvm::IRBuilder<> Builder(
llvm::BasicBlock::Create(M.getContext(), "entry", Thunk));

// Gather the arguments for calling the original function
std::vector<llvm::Value *> CallArgs;
unsigned CallN = SrcProtoType->getNumParams();

auto ArgIt = Thunk->arg_begin();
for (unsigned i = 0; i < CallN && ArgIt != Thunk->arg_end(); ++i, ++ArgIt) {
llvm::Value *A = &*ArgIt;
CallArgs.push_back(A);
}

// Create the call to the original function pointer
llvm::CallInst *Call =
Builder.CreateCall(SrcFunctionType, OriginalFnPtr, CallArgs);

// Handle return type
llvm::Type *ThunkRetTy = DstFunctionType->getReturnType();

if (ThunkRetTy->isVoidTy()) {
Builder.CreateRetVoid();
} else {
llvm::Value *Ret = Call;
if (Ret->getType() != ThunkRetTy)
Ret = Builder.CreateBitCast(Ret, ThunkRetTy);
Builder.CreateRet(Ret);
}
LLVM_DEBUG(llvm::dbgs() << "getOrCreateWasmFunctionPointerThunk:"
<< " from " << OriginalFnPtr->getName().str()
<< " to " << ThunkName << "\n");
// Cache the thunk
ThunkCache[ThunkName] = Thunk;
return Thunk;
}

private:
// The thunk cache
mutable llvm::StringMap<llvm::Function *> ThunkCache;
// Build the thunk name: "%s_{OrigName}_{WasmSig}"
std::string getThunkName(std::string OrigName,
const FunctionProtoType *DstProto,
const ASTContext &Ctx) const;
char getTypeSig(const QualType &Ty, const ASTContext &Ctx) const;
std::string sanitizeTypeString(const std::string &typeStr) const;
std::string getTypeName(const QualType &qt, const ASTContext &Ctx) const;
const DeclRefExpr *findDeclRefExpr(const Expr *E) const;
const DeclRefExpr *findDeclRefExprForVarDown(const Stmt *Parent,
const VarDecl *V,
ASTContext &Ctx) const;
const DeclRefExpr *findDeclRefExprForVarUp(const Expr *E, const VarDecl *V,
ASTContext &Ctx) const;
};

/// Classify argument of given type \p Ty.
Expand Down Expand Up @@ -173,3 +300,120 @@ CodeGen::createWebAssemblyTargetCodeGenInfo(CodeGenModule &CGM,
WebAssemblyABIKind K) {
return std::make_unique<WebAssemblyTargetCodeGenInfo>(CGM.getTypes(), K);
}

// Helper to get the type signature character for a given QualType
// Returns a character that represents the given QualType in a wasm signature.
// See getInvokeSig() in WebAssemblyAsmPrinter for related logic.
char WebAssemblyTargetCodeGenInfo::getTypeSig(const QualType &Ty,
const ASTContext &Ctx) const {
if (Ty->isAnyPointerType()) {
return Ctx.getTypeSize(Ctx.VoidPtrTy) == 32 ? 'i' : 'j';
}
if (Ty->isIntegerType()) {
return Ctx.getTypeSize(Ty) <= 32 ? 'i' : 'j';
}
if (Ty->isFloatingType()) {
return Ctx.getTypeSize(Ty) <= 32 ? 'f' : 'd';
}
if (Ty->isVectorType()) {
return 'V';
}
if (Ty->isWebAssemblyTableType()) {
return 'F';
}
if (Ty->isWebAssemblyExternrefType()) {
return 'X';
}

llvm_unreachable("Unhandled QualType");
}

std::string
WebAssemblyTargetCodeGenInfo::getThunkName(std::string OrigName,
const FunctionProtoType *DstProto,
const ASTContext &Ctx) const {

std::string ThunkName = "__" + OrigName + "_";
QualType RetTy = DstProto->getReturnType();
if (RetTy->isVoidType()) {
ThunkName += 'v';
} else {
ThunkName += getTypeSig(RetTy, Ctx);
}
for (unsigned i = 0; i < DstProto->getNumParams(); ++i) {
ThunkName += getTypeSig(DstProto->getParamType(i), Ctx);
}
return ThunkName;
}

/// Recursively find the first DeclRefExpr in an Expr subtree.
/// Returns nullptr if not found.
const DeclRefExpr *
WebAssemblyTargetCodeGenInfo::findDeclRefExpr(const Expr *E) const {
if (!E)
return nullptr;

// In case it is a function call, abort
if (isa<CallExpr>(E))
return nullptr;

// If this node is a DeclRefExpr, return it.
if (const auto *DRE = dyn_cast<DeclRefExpr>(E))
return DRE;

// Otherwise, recurse into children.
for (const Stmt *Child : E->children()) {
if (const auto *ChildExpr = dyn_cast_or_null<Expr>(Child)) {
if (const DeclRefExpr *Found = findDeclRefExpr(ChildExpr))
return Found;
}
}
return nullptr;
}

const DeclRefExpr *WebAssemblyTargetCodeGenInfo::findDeclRefExprForVarDown(
const Stmt *Parent, const VarDecl *V, ASTContext &Ctx) const {
if (!Parent)
return nullptr;

// Find down every assignment of V
// FIXME we need to stop before the expression where V is used
const BinaryOperator *A = nullptr;
for (const Stmt *Child : Parent->children()) {
if (const auto *BO = dyn_cast_or_null<BinaryOperator>(Child)) {
if (!BO->isAssignmentOp())
continue;
auto *LHS = llvm::dyn_cast<DeclRefExpr>(BO->getLHS());
if (LHS && LHS->getDecl() == V) {
A = BO;
}
}
}

// We have an assignment of the Var, recurse in it
if (A) {
return getWasmFunctionDeclRefExpr(A->getRHS(), Ctx);
}

return nullptr;
}

const DeclRefExpr *WebAssemblyTargetCodeGenInfo::findDeclRefExprForVarUp(
const Expr *E, const VarDecl *V, ASTContext &Ctx) const {
const clang::Stmt *cur = E;
while (cur) {
auto parents = Ctx.getParentMapContext().getParents(*cur);
if (parents.empty())
break;
const clang::Stmt *parentStmt = parents[0].get<clang::Stmt>();
if (!parentStmt)
break;
if (const auto *CS = dyn_cast<clang::CompoundStmt>(parentStmt)) {
const DeclRefExpr *DRE = findDeclRefExprForVarDown(CS, V, Ctx);
if (DRE)
return DRE;
}
cur = parentStmt;
}
return nullptr;
}
25 changes: 25 additions & 0 deletions clang/test/CodeGenWebAssembly/function-pointer-arg.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// REQUIRES: webassembly-registered-target
// RUN: %clang_cc1 -triple wasm32-unknown-unknown -emit-llvm -O0 -o - %s | FileCheck %s

// Test of function pointer bitcast in a function argument with different argument number in wasm32

#define FUNCTION_POINTER(f) ((FunctionPointer)(f))
typedef int (*FunctionPointer)(int a, int b);

int fp_as_arg(FunctionPointer fp, int a, int b) {
return fp(a, b);
}

int fp_less(int a) {
return a;
}

// CHECK-LABEL: @test
// CHECK: call i32 @fp_as_arg(ptr noundef @__fp_less_iii, i32 noundef 10, i32 noundef 20)
void test() {
fp_as_arg(FUNCTION_POINTER(fp_less), 10, 20);
}

// CHECK: define internal i32 @__fp_less_iii(i32 %0, i32 %1)
// CHECK: %2 = call i32 @fp_less(i32 %0)
// CHECK: ret i32 %2
30 changes: 30 additions & 0 deletions clang/test/CodeGenWebAssembly/function-pointer-field.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// REQUIRES: webassembly-registered-target
// RUN: %clang_cc1 -triple wasm32-unknown-unknown -emit-llvm -O0 -o - %s | FileCheck %s

// Test of function pointer bitcast in a struct field with different argument number in wasm32

#define FUNCTION_POINTER(f) ((FunctionPointer)(f))
typedef int (*FunctionPointer)(int a, int b);

// CHECK: @__const.test.sfp = private unnamed_addr constant %struct._StructWithFunctionPointer { ptr @__fp_less_iii }, align 4

typedef struct _StructWithFunctionPointer {
FunctionPointer fp;
} StructWithFunctionPointer;

int fp_less(int a) {
return a;
}

// CHECK-LABEL: @test
void test() {
StructWithFunctionPointer sfp = {
FUNCTION_POINTER(fp_less)
};

int a1 = sfp.fp(10, 20);
}

// CHECK: define internal i32 @__fp_less_iii(i32 %0, i32 %1)
// CHECK: %2 = call i32 @fp_less(i32 %0)
// CHECK: ret i32 %2