99#include " ABIInfoImpl.h"
1010#include " TargetInfo.h"
1111
12+ #include " clang/AST/ParentMapContext.h"
13+ #include < sstream>
14+
1215using namespace clang ;
1316using namespace clang ::CodeGen;
1417
18+ #define DEBUG_TYPE " clang-target-wasm"
19+
1520// ===----------------------------------------------------------------------===//
1621// WebAssembly ABI Implementation
1722//
@@ -93,6 +98,112 @@ class WebAssemblyTargetCodeGenInfo final : public TargetCodeGenInfo {
9398 virtual llvm::Type *getWasmFuncrefReferenceType () const override {
9499 return llvm::Type::getWasm_FuncrefTy (getABIInfo ().getVMContext ());
95100 }
101+
102+ virtual const DeclRefExpr *
103+ getWasmFunctionDeclRefExpr (const Expr *E, ASTContext &Ctx) const override {
104+ // Go down in the tree until finding the DeclRefExpr
105+ const DeclRefExpr *DRE = findDeclRefExpr (E);
106+ if (!DRE)
107+ return nullptr ;
108+
109+ // Final case. The argument is a declared function
110+ if (isa<FunctionDecl>(DRE->getDecl ())) {
111+ return DRE;
112+ }
113+
114+ // Complex case. The argument is a variable, we need to check
115+ // every assignment of the variable and see if we are bitcasting
116+ // or not.
117+ if (const auto *VD = dyn_cast<VarDecl>(DRE->getDecl ())) {
118+ DRE = findDeclRefExprForVarUp (E, VD, Ctx);
119+ if (DRE)
120+ return DRE;
121+
122+ // If no assignment exists on every parent scope, check for the
123+ // initialization
124+ if (!DRE && VD->hasInit ()) {
125+ return getWasmFunctionDeclRefExpr (VD->getInit (), Ctx);
126+ }
127+ }
128+
129+ return nullptr ;
130+ }
131+
132+ virtual llvm::Function *getOrCreateWasmFunctionPointerThunk (
133+ CodeGenModule &CGM, llvm::Value *OriginalFnPtr, QualType SrcType,
134+ QualType DstType) const override {
135+
136+ // Get the signatures
137+ const FunctionProtoType *SrcProtoType = SrcType->getAs <FunctionProtoType>();
138+ const FunctionProtoType *DstProtoType = DstType->getAs <PointerType>()
139+ ->getPointeeType ()
140+ ->getAs <FunctionProtoType>();
141+
142+ // This should only work for different number of arguments
143+ if (DstProtoType->getNumParams () <= SrcProtoType->getNumParams ())
144+ return nullptr ;
145+
146+ // Get the llvm function types
147+ llvm::FunctionType *DstFunctionType = llvm::cast<llvm::FunctionType>(
148+ CGM.getTypes ().ConvertType (QualType (DstProtoType, 0 )));
149+ llvm::FunctionType *SrcFunctionType = llvm::cast<llvm::FunctionType>(
150+ CGM.getTypes ().ConvertType (QualType (SrcProtoType, 0 )));
151+
152+ // Construct the Thunk function with the Target (destination) signature
153+ std::string ThunkName = getThunkName (OriginalFnPtr->getName ().str (),
154+ DstProtoType, CGM.getContext ());
155+ llvm::Module &M = CGM.getModule ();
156+ llvm::Function *Thunk = llvm::Function::Create (
157+ DstFunctionType, llvm::Function::InternalLinkage, ThunkName, M);
158+
159+ // Build the thunk body
160+ llvm::IRBuilder<> Builder (
161+ llvm::BasicBlock::Create (M.getContext (), " entry" , Thunk));
162+
163+ // Gather the arguments for calling the original function
164+ std::vector<llvm::Value *> CallArgs;
165+ unsigned CallN = SrcProtoType->getNumParams ();
166+
167+ auto ArgIt = Thunk->arg_begin ();
168+ for (unsigned i = 0 ; i < CallN && ArgIt != Thunk->arg_end (); ++i, ++ArgIt) {
169+ llvm::Value *A = &*ArgIt;
170+ CallArgs.push_back (A);
171+ }
172+
173+ // Create the call to the original function pointer
174+ llvm::CallInst *Call =
175+ Builder.CreateCall (SrcFunctionType, OriginalFnPtr, CallArgs);
176+
177+ // Handle return type
178+ llvm::Type *ThunkRetTy = DstFunctionType->getReturnType ();
179+
180+ if (ThunkRetTy->isVoidTy ()) {
181+ Builder.CreateRetVoid ();
182+ } else {
183+ llvm::Value *Ret = Call;
184+ if (Ret->getType () != ThunkRetTy)
185+ Ret = Builder.CreateBitCast (Ret, ThunkRetTy);
186+ Builder.CreateRet (Ret);
187+ }
188+ LLVM_DEBUG (llvm::dbgs () << " getOrCreateWasmFunctionPointerThunk:"
189+ << " from " << OriginalFnPtr->getName ().str ()
190+ << " to " << ThunkName << " \n " );
191+ return Thunk;
192+ }
193+
194+ private:
195+ // Build the thunk name: "%s_{type1}_{type2}_..."
196+ std::string getThunkName (std::string OrigName,
197+ const FunctionProtoType *DstProto,
198+ const ASTContext &Ctx) const ;
199+ std::string sanitizeTypeString (const std::string &typeStr) const ;
200+ std::string getTypeName (const QualType &qt, const ASTContext &Ctx) const ;
201+ const DeclRefExpr *findDeclRefExpr (const Expr *E) const ;
202+ const DeclRefExpr *findDeclRefExprForVarDown (const Stmt *Parent,
203+ const VarDecl *V,
204+ ASTContext &Ctx) const ;
205+ const DeclRefExpr *findDeclRefExprForVarUp (const Expr *E, const VarDecl *V,
206+ ASTContext &Ctx) const ;
96207};
97208
98209// / Classify argument of given type \p Ty.
@@ -173,3 +284,114 @@ CodeGen::createWebAssemblyTargetCodeGenInfo(CodeGenModule &CGM,
173284 WebAssemblyABIKind K) {
174285 return std::make_unique<WebAssemblyTargetCodeGenInfo>(CGM.getTypes (), K);
175286}
287+
288+ // Helper to sanitize type name string for use in function name
289+ std::string WebAssemblyTargetCodeGenInfo::sanitizeTypeString (
290+ const std::string &typeStr) const {
291+ std::string s;
292+ for (char c : typeStr) {
293+ if (isalnum (c))
294+ s += c;
295+ else if (c == ' ' )
296+ s += ' _' ;
297+ else
298+ s += ' _' ;
299+ }
300+ return s;
301+ }
302+
303+ // Helper to generate the type string from QualType
304+ std::string
305+ WebAssemblyTargetCodeGenInfo::getTypeName (const QualType &qt,
306+ const ASTContext &Ctx) const {
307+ PrintingPolicy Policy (Ctx.getLangOpts ());
308+ Policy.SuppressTagKeyword = true ;
309+ Policy.SuppressScope = true ;
310+ Policy.AnonymousTagLocations = false ;
311+ std::string typeStr = qt.getAsString (Policy);
312+ return sanitizeTypeString (typeStr);
313+ }
314+
315+ std::string
316+ WebAssemblyTargetCodeGenInfo::getThunkName (std::string OrigName,
317+ const FunctionProtoType *DstProto,
318+ const ASTContext &Ctx) const {
319+ std::ostringstream oss;
320+ oss << " __" << OrigName;
321+ for (unsigned i = 0 ; i < DstProto->getNumParams (); ++i) {
322+ oss << " _" << getTypeName (DstProto->getParamType (i), Ctx);
323+ }
324+ return oss.str ();
325+ }
326+
327+ // / Recursively find the first DeclRefExpr in an Expr subtree.
328+ // / Returns nullptr if not found.
329+ const DeclRefExpr *
330+ WebAssemblyTargetCodeGenInfo::findDeclRefExpr (const Expr *E) const {
331+ if (!E)
332+ return nullptr ;
333+
334+ // In case it is a function call, abort
335+ if (isa<CallExpr>(E))
336+ return nullptr ;
337+
338+ // If this node is a DeclRefExpr, return it.
339+ if (const auto *DRE = dyn_cast<DeclRefExpr>(E))
340+ return DRE;
341+
342+ // Otherwise, recurse into children.
343+ for (const Stmt *Child : E->children ()) {
344+ if (const auto *ChildExpr = dyn_cast_or_null<Expr>(Child)) {
345+ if (const DeclRefExpr *Found = findDeclRefExpr (ChildExpr))
346+ return Found;
347+ }
348+ }
349+ return nullptr ;
350+ }
351+
352+ const DeclRefExpr *WebAssemblyTargetCodeGenInfo::findDeclRefExprForVarDown (
353+ const Stmt *Parent, const VarDecl *V, ASTContext &Ctx) const {
354+ if (!Parent)
355+ return nullptr ;
356+
357+ // Find down every assignment of V
358+ // FIXME we need to stop before the expression where V is used
359+ const BinaryOperator *A = nullptr ;
360+ for (const Stmt *Child : Parent->children ()) {
361+ if (const auto *BO = dyn_cast_or_null<BinaryOperator>(Child)) {
362+ if (!BO->isAssignmentOp ())
363+ continue ;
364+ auto *LHS = llvm::dyn_cast<DeclRefExpr>(BO->getLHS ());
365+ if (LHS && LHS->getDecl () == V) {
366+ A = BO;
367+ }
368+ }
369+ }
370+
371+ // We have an assignment of the Var, recurse in it
372+ if (A) {
373+ return getWasmFunctionDeclRefExpr (A->getRHS (), Ctx);
374+ }
375+
376+ return nullptr ;
377+ }
378+
379+ const DeclRefExpr *WebAssemblyTargetCodeGenInfo::findDeclRefExprForVarUp (
380+ const Expr *E, const VarDecl *V, ASTContext &Ctx) const {
381+ const clang::Stmt *cur = E;
382+ while (cur) {
383+ auto parents = Ctx.getParentMapContext ().getParents (*cur);
384+ if (parents.empty ())
385+ break ;
386+ const clang::Stmt *parentStmt = parents[0 ].get <clang::Stmt>();
387+ if (!parentStmt)
388+ break ;
389+ if (const auto *CS = dyn_cast<clang::CompoundStmt>(parentStmt)) {
390+ const DeclRefExpr *DRE = findDeclRefExprForVarDown (CS, V, Ctx);
391+ if (DRE)
392+ return DRE;
393+ }
394+ cur = parentStmt;
395+ }
396+ return nullptr ;
397+ }
0 commit comments