Skip to content

Commit 4836489

Browse files
committed
[wasm] Support different signature function pointers
1 parent 72b53cd commit 4836489

File tree

4 files changed

+266
-0
lines changed

4 files changed

+266
-0
lines changed

clang/lib/CodeGen/CGCall.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4932,6 +4932,26 @@ void CodeGenFunction::EmitCallArg(CallArgList &args, const Expr *E,
49324932
return;
49334933
}
49344934

4935+
// For WebAssembly target we need to create thunk functions
4936+
// to properly handle function pointers args with a different signature.
4937+
// Due to opaque pointers, this can not be handled in LLVM
4938+
// (WebAssemblyFixFunctionBitcast) anymore
4939+
if (CGM.getTriple().isWasm() && type->isFunctionPointerType()) {
4940+
if (const DeclRefExpr *DRE =
4941+
CGM.getTargetCodeGenInfo().getWasmFunctionDeclRefExpr(
4942+
E, CGM.getContext())) {
4943+
llvm::Value *V = EmitLValue(DRE).getPointer(*this);
4944+
llvm::Function *Thunk =
4945+
CGM.getTargetCodeGenInfo().getOrCreateWasmFunctionPointerThunk(
4946+
CGM, V, DRE->getDecl()->getType(), type);
4947+
if (Thunk) {
4948+
RValue R = RValue::get(Thunk);
4949+
args.add(R, type);
4950+
return;
4951+
}
4952+
}
4953+
}
4954+
49354955
args.add(EmitAnyExprToTemp(E), type);
49364956
}
49374957

clang/lib/CodeGen/CGExprConstant.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2243,6 +2243,19 @@ ConstantLValueEmitter::tryEmitBase(const APValue::LValueBase &base) {
22432243

22442244
if (const auto *FD = dyn_cast<FunctionDecl>(D)) {
22452245
llvm::Constant *C = CGM.getRawFunctionPointer(FD);
2246+
// ForWebAssembly target we need to create thunk functions
2247+
// to properly handle function pointers args with a different signature
2248+
// Due to opaque pointers, this can not be handled in LLVM
2249+
// (WebAssemblyFixFunctionBitcast) anymore
2250+
if (CGM.getTriple().isWasm() && DestType->isFunctionPointerType()) {
2251+
llvm::Function *Thunk =
2252+
CGM.getTargetCodeGenInfo().getOrCreateWasmFunctionPointerThunk(
2253+
CGM, C, D->getType(), DestType);
2254+
if (Thunk) {
2255+
C = Thunk;
2256+
}
2257+
}
2258+
22462259
if (FD->getType()->isCFIUncheckedCalleeFunctionType())
22472260
C = llvm::NoCFIValue::get(cast<llvm::GlobalValue>(C));
22482261
return PtrAuthSign(C);

clang/lib/CodeGen/TargetInfo.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,17 @@ class TargetCodeGenInfo {
421421
/// Return the WebAssembly funcref reference type.
422422
virtual llvm::Type *getWasmFuncrefReferenceType() const { return nullptr; }
423423

424+
virtual const DeclRefExpr *getWasmFunctionDeclRefExpr(const Expr *E,
425+
ASTContext &Ctx) const {
426+
return nullptr;
427+
}
428+
429+
virtual llvm::Function *getOrCreateWasmFunctionPointerThunk(
430+
CodeGenModule &CGM, llvm::Value *OriginalFnPtr, QualType SrcType,
431+
QualType DstType) const {
432+
return nullptr;
433+
}
434+
424435
/// Emit the device-side copy of the builtin surface type.
425436
virtual bool emitCUDADeviceBuiltinSurfaceDeviceCopy(CodeGenFunction &CGF,
426437
LValue Dst,

clang/lib/CodeGen/Targets/WebAssembly.cpp

Lines changed: 222 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,14 @@
99
#include "ABIInfoImpl.h"
1010
#include "TargetInfo.h"
1111

12+
#include "clang/AST/ParentMapContext.h"
13+
#include <sstream>
14+
1215
using namespace clang;
1316
using namespace clang::CodeGen;
1417

18+
#define DEBUG_TYPE "clang-target-wasm"
19+
1520
//===----------------------------------------------------------------------===//
1621
// WebAssembly ABI Implementation
1722
//
@@ -93,6 +98,112 @@ class WebAssemblyTargetCodeGenInfo final : public TargetCodeGenInfo {
9398
virtual llvm::Type *getWasmFuncrefReferenceType() const override {
9499
return llvm::Type::getWasm_FuncrefTy(getABIInfo().getVMContext());
95100
}
101+
102+
virtual const DeclRefExpr *
103+
getWasmFunctionDeclRefExpr(const Expr *E, ASTContext &Ctx) const override {
104+
// Go down in the tree until finding the DeclRefExpr
105+
const DeclRefExpr *DRE = findDeclRefExpr(E);
106+
if (!DRE)
107+
return nullptr;
108+
109+
// Final case. The argument is a declared function
110+
if (isa<FunctionDecl>(DRE->getDecl())) {
111+
return DRE;
112+
}
113+
114+
// Complex case. The argument is a variable, we need to check
115+
// every assignment of the variable and see if we are bitcasting
116+
// or not.
117+
if (const auto *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
118+
DRE = findDeclRefExprForVarUp(E, VD, Ctx);
119+
if (DRE)
120+
return DRE;
121+
122+
// If no assignment exists on every parent scope, check for the
123+
// initialization
124+
if (!DRE && VD->hasInit()) {
125+
return getWasmFunctionDeclRefExpr(VD->getInit(), Ctx);
126+
}
127+
}
128+
129+
return nullptr;
130+
}
131+
132+
virtual llvm::Function *getOrCreateWasmFunctionPointerThunk(
133+
CodeGenModule &CGM, llvm::Value *OriginalFnPtr, QualType SrcType,
134+
QualType DstType) const override {
135+
136+
// Get the signatures
137+
const FunctionProtoType *SrcProtoType = SrcType->getAs<FunctionProtoType>();
138+
const FunctionProtoType *DstProtoType = DstType->getAs<PointerType>()
139+
->getPointeeType()
140+
->getAs<FunctionProtoType>();
141+
142+
// This should only work for different number of arguments
143+
if (DstProtoType->getNumParams() <= SrcProtoType->getNumParams())
144+
return nullptr;
145+
146+
// Get the llvm function types
147+
llvm::FunctionType *DstFunctionType = llvm::cast<llvm::FunctionType>(
148+
CGM.getTypes().ConvertType(QualType(DstProtoType, 0)));
149+
llvm::FunctionType *SrcFunctionType = llvm::cast<llvm::FunctionType>(
150+
CGM.getTypes().ConvertType(QualType(SrcProtoType, 0)));
151+
152+
// Construct the Thunk function with the Target (destination) signature
153+
std::string ThunkName = getThunkName(OriginalFnPtr->getName().str(),
154+
DstProtoType, CGM.getContext());
155+
llvm::Module &M = CGM.getModule();
156+
llvm::Function *Thunk = llvm::Function::Create(
157+
DstFunctionType, llvm::Function::InternalLinkage, ThunkName, M);
158+
159+
// Build the thunk body
160+
llvm::IRBuilder<> Builder(
161+
llvm::BasicBlock::Create(M.getContext(), "entry", Thunk));
162+
163+
// Gather the arguments for calling the original function
164+
std::vector<llvm::Value *> CallArgs;
165+
unsigned CallN = SrcProtoType->getNumParams();
166+
167+
auto ArgIt = Thunk->arg_begin();
168+
for (unsigned i = 0; i < CallN && ArgIt != Thunk->arg_end(); ++i, ++ArgIt) {
169+
llvm::Value *A = &*ArgIt;
170+
CallArgs.push_back(A);
171+
}
172+
173+
// Create the call to the original function pointer
174+
llvm::CallInst *Call =
175+
Builder.CreateCall(SrcFunctionType, OriginalFnPtr, CallArgs);
176+
177+
// Handle return type
178+
llvm::Type *ThunkRetTy = DstFunctionType->getReturnType();
179+
180+
if (ThunkRetTy->isVoidTy()) {
181+
Builder.CreateRetVoid();
182+
} else {
183+
llvm::Value *Ret = Call;
184+
if (Ret->getType() != ThunkRetTy)
185+
Ret = Builder.CreateBitCast(Ret, ThunkRetTy);
186+
Builder.CreateRet(Ret);
187+
}
188+
LLVM_DEBUG(llvm::dbgs() << "getOrCreateWasmFunctionPointerThunk:"
189+
<< " from " << OriginalFnPtr->getName().str()
190+
<< " to " << ThunkName << "\n");
191+
return Thunk;
192+
}
193+
194+
private:
195+
// Build the thunk name: "%s_{type1}_{type2}_..."
196+
std::string getThunkName(std::string OrigName,
197+
const FunctionProtoType *DstProto,
198+
const ASTContext &Ctx) const;
199+
std::string sanitizeTypeString(const std::string &typeStr) const;
200+
std::string getTypeName(const QualType &qt, const ASTContext &Ctx) const;
201+
const DeclRefExpr *findDeclRefExpr(const Expr *E) const;
202+
const DeclRefExpr *findDeclRefExprForVarDown(const Stmt *Parent,
203+
const VarDecl *V,
204+
ASTContext &Ctx) const;
205+
const DeclRefExpr *findDeclRefExprForVarUp(const Expr *E, const VarDecl *V,
206+
ASTContext &Ctx) const;
96207
};
97208

98209
/// Classify argument of given type \p Ty.
@@ -173,3 +284,114 @@ CodeGen::createWebAssemblyTargetCodeGenInfo(CodeGenModule &CGM,
173284
WebAssemblyABIKind K) {
174285
return std::make_unique<WebAssemblyTargetCodeGenInfo>(CGM.getTypes(), K);
175286
}
287+
288+
// Helper to sanitize type name string for use in function name
289+
std::string WebAssemblyTargetCodeGenInfo::sanitizeTypeString(
290+
const std::string &typeStr) const {
291+
std::string s;
292+
for (char c : typeStr) {
293+
if (isalnum(c))
294+
s += c;
295+
else if (c == ' ')
296+
s += '_';
297+
else
298+
s += '_';
299+
}
300+
return s;
301+
}
302+
303+
// Helper to generate the type string from QualType
304+
std::string
305+
WebAssemblyTargetCodeGenInfo::getTypeName(const QualType &qt,
306+
const ASTContext &Ctx) const {
307+
PrintingPolicy Policy(Ctx.getLangOpts());
308+
Policy.SuppressTagKeyword = true;
309+
Policy.SuppressScope = true;
310+
Policy.AnonymousTagLocations = false;
311+
std::string typeStr = qt.getAsString(Policy);
312+
return sanitizeTypeString(typeStr);
313+
}
314+
315+
std::string
316+
WebAssemblyTargetCodeGenInfo::getThunkName(std::string OrigName,
317+
const FunctionProtoType *DstProto,
318+
const ASTContext &Ctx) const {
319+
std::ostringstream oss;
320+
oss << "__" << OrigName;
321+
for (unsigned i = 0; i < DstProto->getNumParams(); ++i) {
322+
oss << "_" << getTypeName(DstProto->getParamType(i), Ctx);
323+
}
324+
return oss.str();
325+
}
326+
327+
/// Recursively find the first DeclRefExpr in an Expr subtree.
328+
/// Returns nullptr if not found.
329+
const DeclRefExpr *
330+
WebAssemblyTargetCodeGenInfo::findDeclRefExpr(const Expr *E) const {
331+
if (!E)
332+
return nullptr;
333+
334+
// In case it is a function call, abort
335+
if (isa<CallExpr>(E))
336+
return nullptr;
337+
338+
// If this node is a DeclRefExpr, return it.
339+
if (const auto *DRE = dyn_cast<DeclRefExpr>(E))
340+
return DRE;
341+
342+
// Otherwise, recurse into children.
343+
for (const Stmt *Child : E->children()) {
344+
if (const auto *ChildExpr = dyn_cast_or_null<Expr>(Child)) {
345+
if (const DeclRefExpr *Found = findDeclRefExpr(ChildExpr))
346+
return Found;
347+
}
348+
}
349+
return nullptr;
350+
}
351+
352+
const DeclRefExpr *WebAssemblyTargetCodeGenInfo::findDeclRefExprForVarDown(
353+
const Stmt *Parent, const VarDecl *V, ASTContext &Ctx) const {
354+
if (!Parent)
355+
return nullptr;
356+
357+
// Find down every assignment of V
358+
// FIXME we need to stop before the expression where V is used
359+
const BinaryOperator *A = nullptr;
360+
for (const Stmt *Child : Parent->children()) {
361+
if (const auto *BO = dyn_cast_or_null<BinaryOperator>(Child)) {
362+
if (!BO->isAssignmentOp())
363+
continue;
364+
auto *LHS = llvm::dyn_cast<DeclRefExpr>(BO->getLHS());
365+
if (LHS && LHS->getDecl() == V) {
366+
A = BO;
367+
}
368+
}
369+
}
370+
371+
// We have an assignment of the Var, recurse in it
372+
if (A) {
373+
return getWasmFunctionDeclRefExpr(A->getRHS(), Ctx);
374+
}
375+
376+
return nullptr;
377+
}
378+
379+
const DeclRefExpr *WebAssemblyTargetCodeGenInfo::findDeclRefExprForVarUp(
380+
const Expr *E, const VarDecl *V, ASTContext &Ctx) const {
381+
const clang::Stmt *cur = E;
382+
while (cur) {
383+
auto parents = Ctx.getParentMapContext().getParents(*cur);
384+
if (parents.empty())
385+
break;
386+
const clang::Stmt *parentStmt = parents[0].get<clang::Stmt>();
387+
if (!parentStmt)
388+
break;
389+
if (const auto *CS = dyn_cast<clang::CompoundStmt>(parentStmt)) {
390+
const DeclRefExpr *DRE = findDeclRefExprForVarDown(CS, V, Ctx);
391+
if (DRE)
392+
return DRE;
393+
}
394+
cur = parentStmt;
395+
}
396+
return nullptr;
397+
}

0 commit comments

Comments
 (0)