Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions clang/include/clang/Basic/LangOptions.def
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,8 @@ LANGOPT(EnableLifetimeSafety, 1, 0, NotCompatible, "Experimental lifetime safety

LANGOPT(PreserveVec3Type, 1, 0, NotCompatible, "Preserve 3-component vector type")

LANGOPT(WasmFixFunctionBitcasts, 1, 0, Compatible, "Enable auto-generation of thunks for mismatched function pointer casts in WebAssembly")

#undef LANGOPT
#undef ENUM_LANGOPT
#undef VALUE_LANGOPT
6 changes: 6 additions & 0 deletions clang/include/clang/Driver/Options.td
Original file line number Diff line number Diff line change
Expand Up @@ -9457,6 +9457,7 @@ def fvk_use_scalar_layout
: DXCFlag<"fvk-use-scalar-layout">,
HelpText<"Use scalar memory layout for Vulkan resources.">;

// WebAssembly-only Options
def no_wasm_opt : Flag<["--"], "no-wasm-opt">,
Group<m_Group>,
HelpText<"Disable the wasm-opt optimizer">,
Expand All @@ -9465,3 +9466,8 @@ def wasm_opt : Flag<["--"], "wasm-opt">,
Group<m_Group>,
HelpText<"Enable the wasm-opt optimizer (default)">,
MarshallingInfoNegativeFlag<LangOpts<"NoWasmOpt">>;
def fwasm_fix_function_bitcasts : Flag<["-"], "fwasm-fix-function-bitcasts">,
Group<f_Group>,
HelpText<"Enable auto-generation of thunks for mismatched function pointer casts in WebAssembly">,
Visibility<[ClangOption, CC1Option]>,
MarshallingInfoFlag<LangOpts<"WasmFixFunctionBitcasts">>;
21 changes: 21 additions & 0 deletions clang/lib/CodeGen/CGCall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4932,6 +4932,27 @@ void CodeGenFunction::EmitCallArg(CallArgList &args, const Expr *E,
return;
}

// For WebAssembly target we need to create thunk functions
// to properly handle function pointers args with a different signature.
// Due to opaque pointers, this can not be handled in LLVM
// (WebAssemblyFixFunctionBitcast) anymore
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Talking about opaque pointers isn't really useful for a code comment, because it's not about the current state of the code.

if (CGM.getTriple().isWasm() && CGM.getLangOpts().WasmFixFunctionBitcasts &&
type->isFunctionPointerType()) {
if (const DeclRefExpr *DRE =
CGM.getTargetCodeGenInfo().getWasmFunctionDeclRefExpr(
E, CGM.getContext())) {
llvm::Value *V = EmitLValue(DRE).getPointer(*this);
llvm::Function *Thunk =
CGM.getTargetCodeGenInfo().getOrCreateWasmFunctionPointerThunk(
CGM, V, DRE->getDecl()->getType(), type);
if (Thunk) {
RValue R = RValue::get(Thunk);
args.add(R, type);
return;
}
}
}

args.add(EmitAnyExprToTemp(E), type);
}

Expand Down
15 changes: 15 additions & 0 deletions clang/lib/CodeGen/CGExprConstant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2243,6 +2243,21 @@ ConstantLValueEmitter::tryEmitBase(const APValue::LValueBase &base) {

if (const auto *FD = dyn_cast<FunctionDecl>(D)) {
llvm::Constant *C = CGM.getRawFunctionPointer(FD);
// ForWebAssembly target we need to create thunk functions
// to properly handle function pointers args with a different signature
// Due to opaque pointers, this can not be handled in LLVM
// (WebAssemblyFixFunctionBitcast) anymore
if (CGM.getTriple().isWasm() &&
CGM.getLangOpts().WasmFixFunctionBitcasts &&
DestType->isFunctionPointerType()) {
llvm::Function *Thunk =
CGM.getTargetCodeGenInfo().getOrCreateWasmFunctionPointerThunk(
CGM, C, D->getType(), DestType);
if (Thunk) {
C = Thunk;
}
}

if (FD->getType()->isCFIUncheckedCalleeFunctionType())
C = llvm::NoCFIValue::get(cast<llvm::GlobalValue>(C));
return PtrAuthSign(C);
Expand Down
11 changes: 11 additions & 0 deletions clang/lib/CodeGen/TargetInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,17 @@ class TargetCodeGenInfo {
/// Return the WebAssembly funcref reference type.
virtual llvm::Type *getWasmFuncrefReferenceType() const { return nullptr; }

virtual const DeclRefExpr *getWasmFunctionDeclRefExpr(const Expr *E,
ASTContext &Ctx) const {
return nullptr;
}

virtual llvm::Function *getOrCreateWasmFunctionPointerThunk(
CodeGenModule &CGM, llvm::Value *OriginalFnPtr, QualType SrcType,
QualType DstType) const {
return nullptr;
}

/// Emit the device-side copy of the builtin surface type.
virtual bool emitCUDADeviceBuiltinSurfaceDeviceCopy(CodeGenFunction &CGF,
LValue Dst,
Expand Down
244 changes: 244 additions & 0 deletions clang/lib/CodeGen/Targets/WebAssembly.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,15 @@

#include "ABIInfoImpl.h"
#include "TargetInfo.h"
#include "llvm/ADT/StringMap.h"

#include "clang/AST/ParentMapContext.h"

using namespace clang;
using namespace clang::CodeGen;

#define DEBUG_TYPE "clang-target-wasm"

//===----------------------------------------------------------------------===//
// WebAssembly ABI Implementation
//
Expand Down Expand Up @@ -52,6 +57,7 @@ class WebAssemblyTargetCodeGenInfo final : public TargetCodeGenInfo {
: TargetCodeGenInfo(std::make_unique<WebAssemblyABIInfo>(CGT, K)) {
SwiftInfo =
std::make_unique<SwiftABIInfo>(CGT, /*SwiftErrorInRegister=*/false);
ThunkCache = llvm::StringMap<llvm::Function *>();
}

void setTargetAttributes(const Decl *D, llvm::GlobalValue *GV,
Expand Down Expand Up @@ -93,6 +99,127 @@ class WebAssemblyTargetCodeGenInfo final : public TargetCodeGenInfo {
virtual llvm::Type *getWasmFuncrefReferenceType() const override {
return llvm::Type::getWasm_FuncrefTy(getABIInfo().getVMContext());
}

virtual const DeclRefExpr *
getWasmFunctionDeclRefExpr(const Expr *E, ASTContext &Ctx) const override {
// Go down in the tree until finding the DeclRefExpr
const DeclRefExpr *DRE = findDeclRefExpr(E);
if (!DRE)
return nullptr;

// Final case. The argument is a declared function
if (isa<FunctionDecl>(DRE->getDecl())) {
return DRE;
}

// Complex case. The argument is a variable, we need to check
// every assignment of the variable and see if we are bitcasting
// or not.
if (const auto *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
DRE = findDeclRefExprForVarUp(E, VD, Ctx);
if (DRE)
return DRE;

// If no assignment exists on every parent scope, check for the
// initialization
if (!DRE && VD->hasInit()) {
return getWasmFunctionDeclRefExpr(VD->getInit(), Ctx);
}
}

return nullptr;
}

virtual llvm::Function *getOrCreateWasmFunctionPointerThunk(
CodeGenModule &CGM, llvm::Value *OriginalFnPtr, QualType SrcType,
QualType DstType) const override {

// Get the signatures
const FunctionProtoType *SrcProtoType = SrcType->getAs<FunctionProtoType>();
const FunctionProtoType *DstProtoType = DstType->getAs<PointerType>()
->getPointeeType()
->getAs<FunctionProtoType>();

// This should only work for different number of arguments
if (DstProtoType->getNumParams() <= SrcProtoType->getNumParams())
return nullptr;

// Get the llvm function types
llvm::FunctionType *DstFunctionType = llvm::cast<llvm::FunctionType>(
CGM.getTypes().ConvertType(QualType(DstProtoType, 0)));
llvm::FunctionType *SrcFunctionType = llvm::cast<llvm::FunctionType>(
CGM.getTypes().ConvertType(QualType(SrcProtoType, 0)));

// Construct the Thunk function with the Target (destination) signature
std::string ThunkName = getThunkName(OriginalFnPtr->getName().str(),
DstProtoType, CGM.getContext());
// Check if we already have a thunk for this function
if (auto It = ThunkCache.find(ThunkName); It != ThunkCache.end()) {
LLVM_DEBUG(llvm::dbgs() << "getOrCreateWasmFunctionPointerThunk: "
<< "found existing thunk for "
<< OriginalFnPtr->getName().str() << " as "
<< ThunkName << "\n");
return It->second;
}

// Create the thunk function
llvm::Module &M = CGM.getModule();
llvm::Function *Thunk = llvm::Function::Create(
DstFunctionType, llvm::Function::InternalLinkage, ThunkName, M);

// Build the thunk body
llvm::IRBuilder<> Builder(
llvm::BasicBlock::Create(M.getContext(), "entry", Thunk));

// Gather the arguments for calling the original function
std::vector<llvm::Value *> CallArgs;
unsigned CallN = SrcProtoType->getNumParams();

auto ArgIt = Thunk->arg_begin();
for (unsigned i = 0; i < CallN && ArgIt != Thunk->arg_end(); ++i, ++ArgIt) {
llvm::Value *A = &*ArgIt;
CallArgs.push_back(A);
}

// Create the call to the original function pointer
llvm::CallInst *Call =
Builder.CreateCall(SrcFunctionType, OriginalFnPtr, CallArgs);

// Handle return type
llvm::Type *ThunkRetTy = DstFunctionType->getReturnType();

if (ThunkRetTy->isVoidTy()) {
Builder.CreateRetVoid();
} else {
llvm::Value *Ret = Call;
if (Ret->getType() != ThunkRetTy)
Ret = Builder.CreateBitCast(Ret, ThunkRetTy);
Builder.CreateRet(Ret);
}
LLVM_DEBUG(llvm::dbgs() << "getOrCreateWasmFunctionPointerThunk:"
<< " from " << OriginalFnPtr->getName().str()
<< " to " << ThunkName << "\n");
// Cache the thunk
ThunkCache[ThunkName] = Thunk;
return Thunk;
}

private:
// The thunk cache
mutable llvm::StringMap<llvm::Function *> ThunkCache;
// Build the thunk name: "%s_{OrigName}_{WasmSig}"
std::string getThunkName(std::string OrigName,
const FunctionProtoType *DstProto,
const ASTContext &Ctx) const;
char getTypeSig(const QualType &Ty, const ASTContext &Ctx) const;
std::string sanitizeTypeString(const std::string &typeStr) const;
std::string getTypeName(const QualType &qt, const ASTContext &Ctx) const;
const DeclRefExpr *findDeclRefExpr(const Expr *E) const;
const DeclRefExpr *findDeclRefExprForVarDown(const Stmt *Parent,
const VarDecl *V,
ASTContext &Ctx) const;
const DeclRefExpr *findDeclRefExprForVarUp(const Expr *E, const VarDecl *V,
ASTContext &Ctx) const;
};

/// Classify argument of given type \p Ty.
Expand Down Expand Up @@ -173,3 +300,120 @@ CodeGen::createWebAssemblyTargetCodeGenInfo(CodeGenModule &CGM,
WebAssemblyABIKind K) {
return std::make_unique<WebAssemblyTargetCodeGenInfo>(CGM.getTypes(), K);
}

// Helper to get the type signature character for a given QualType
// Returns a character that represents the given QualType in a wasm signature.
// See getInvokeSig() in WebAssemblyAsmPrinter for related logic.
char WebAssemblyTargetCodeGenInfo::getTypeSig(const QualType &Ty,
const ASTContext &Ctx) const {
if (Ty->isAnyPointerType()) {
return Ctx.getTypeSize(Ctx.VoidPtrTy) == 32 ? 'i' : 'j';
}
if (Ty->isIntegerType()) {
return Ctx.getTypeSize(Ty) <= 32 ? 'i' : 'j';
}
if (Ty->isFloatingType()) {
return Ctx.getTypeSize(Ty) <= 32 ? 'f' : 'd';
}
if (Ty->isVectorType()) {
return 'V';
}
if (Ty->isWebAssemblyTableType()) {
return 'F';
}
if (Ty->isWebAssemblyExternrefType()) {
return 'X';
}

llvm_unreachable("Unhandled QualType");
}

std::string
WebAssemblyTargetCodeGenInfo::getThunkName(std::string OrigName,
const FunctionProtoType *DstProto,
const ASTContext &Ctx) const {

std::string ThunkName = "__" + OrigName + "_";
QualType RetTy = DstProto->getReturnType();
if (RetTy->isVoidType()) {
ThunkName += 'v';
} else {
ThunkName += getTypeSig(RetTy, Ctx);
}
for (unsigned i = 0; i < DstProto->getNumParams(); ++i) {
ThunkName += getTypeSig(DstProto->getParamType(i), Ctx);
}
return ThunkName;
}

/// Recursively find the first DeclRefExpr in an Expr subtree.
/// Returns nullptr if not found.
const DeclRefExpr *
WebAssemblyTargetCodeGenInfo::findDeclRefExpr(const Expr *E) const {
if (!E)
return nullptr;

// In case it is a function call, abort
if (isa<CallExpr>(E))
return nullptr;

// If this node is a DeclRefExpr, return it.
if (const auto *DRE = dyn_cast<DeclRefExpr>(E))
return DRE;

// Otherwise, recurse into children.
for (const Stmt *Child : E->children()) {
if (const auto *ChildExpr = dyn_cast_or_null<Expr>(Child)) {
if (const DeclRefExpr *Found = findDeclRefExpr(ChildExpr))
return Found;
}
}
return nullptr;
}

const DeclRefExpr *WebAssemblyTargetCodeGenInfo::findDeclRefExprForVarDown(
const Stmt *Parent, const VarDecl *V, ASTContext &Ctx) const {
if (!Parent)
return nullptr;

// Find down every assignment of V
// FIXME we need to stop before the expression where V is used
const BinaryOperator *A = nullptr;
for (const Stmt *Child : Parent->children()) {
if (const auto *BO = dyn_cast_or_null<BinaryOperator>(Child)) {
if (!BO->isAssignmentOp())
continue;
auto *LHS = llvm::dyn_cast<DeclRefExpr>(BO->getLHS());
if (LHS && LHS->getDecl() == V) {
A = BO;
}
}
}

// We have an assignment of the Var, recurse in it
if (A) {
return getWasmFunctionDeclRefExpr(A->getRHS(), Ctx);
}

return nullptr;
}

const DeclRefExpr *WebAssemblyTargetCodeGenInfo::findDeclRefExprForVarUp(
const Expr *E, const VarDecl *V, ASTContext &Ctx) const {
const clang::Stmt *cur = E;
while (cur) {
auto parents = Ctx.getParentMapContext().getParents(*cur);
if (parents.empty())
break;
const clang::Stmt *parentStmt = parents[0].get<clang::Stmt>();
if (!parentStmt)
break;
if (const auto *CS = dyn_cast<clang::CompoundStmt>(parentStmt)) {
const DeclRefExpr *DRE = findDeclRefExprForVarDown(CS, V, Ctx);
if (DRE)
return DRE;
}
cur = parentStmt;
}
return nullptr;
}
4 changes: 4 additions & 0 deletions clang/lib/Driver/ToolChains/WebAssembly.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,10 @@ void WebAssembly::addClangTargetOptions(const ArgList &DriverArgs,
CC1Args.push_back("-wasm-enable-eh");
}

if (DriverArgs.getLastArg(options::OPT_fwasm_fix_function_bitcasts)) {
CC1Args.push_back("-fwasm-fix-function-bitcasts");
}

for (const Arg *A : DriverArgs.filtered(options::OPT_mllvm)) {
StringRef Opt = A->getValue(0);
if (Opt.starts_with("-emscripten-cxx-exceptions-allowed")) {
Expand Down
Loading