9
9
#include " ABIInfoImpl.h"
10
10
#include " TargetInfo.h"
11
11
12
+ #include " clang/AST/ParentMapContext.h"
13
+ #include < sstream>
14
+
12
15
using namespace clang ;
13
16
using namespace clang ::CodeGen;
14
17
18
+ #define DEBUG_TYPE " clang-target-wasm"
19
+
15
20
// ===----------------------------------------------------------------------===//
16
21
// WebAssembly ABI Implementation
17
22
//
@@ -93,6 +98,112 @@ class WebAssemblyTargetCodeGenInfo final : public TargetCodeGenInfo {
93
98
virtual llvm::Type *getWasmFuncrefReferenceType () const override {
94
99
return llvm::Type::getWasm_FuncrefTy (getABIInfo ().getVMContext ());
95
100
}
101
+
102
+ virtual const DeclRefExpr *
103
+ getWasmFunctionDeclRefExpr (const Expr *E, ASTContext &Ctx) const override {
104
+ // Go down in the tree until finding the DeclRefExpr
105
+ const DeclRefExpr *DRE = findDeclRefExpr (E);
106
+ if (!DRE)
107
+ return nullptr ;
108
+
109
+ // Final case. The argument is a declared function
110
+ if (isa<FunctionDecl>(DRE->getDecl ())) {
111
+ return DRE;
112
+ }
113
+
114
+ // Complex case. The argument is a variable, we need to check
115
+ // every assignment of the variable and see if we are bitcasting
116
+ // or not.
117
+ if (const auto *VD = dyn_cast<VarDecl>(DRE->getDecl ())) {
118
+ DRE = findDeclRefExprForVarUp (E, VD, Ctx);
119
+ if (DRE)
120
+ return DRE;
121
+
122
+ // If no assignment exists on every parent scope, check for the
123
+ // initialization
124
+ if (!DRE && VD->hasInit ()) {
125
+ return getWasmFunctionDeclRefExpr (VD->getInit (), Ctx);
126
+ }
127
+ }
128
+
129
+ return nullptr ;
130
+ }
131
+
132
+ virtual llvm::Function *getOrCreateWasmFunctionPointerThunk (
133
+ CodeGenModule &CGM, llvm::Value *OriginalFnPtr, QualType SrcType,
134
+ QualType DstType) const override {
135
+
136
+ // Get the signatures
137
+ const FunctionProtoType *SrcProtoType = SrcType->getAs <FunctionProtoType>();
138
+ const FunctionProtoType *DstProtoType = DstType->getAs <PointerType>()
139
+ ->getPointeeType ()
140
+ ->getAs <FunctionProtoType>();
141
+
142
+ // This should only work for different number of arguments
143
+ if (DstProtoType->getNumParams () <= SrcProtoType->getNumParams ())
144
+ return nullptr ;
145
+
146
+ // Get the llvm function types
147
+ llvm::FunctionType *DstFunctionType = llvm::cast<llvm::FunctionType>(
148
+ CGM.getTypes ().ConvertType (QualType (DstProtoType, 0 )));
149
+ llvm::FunctionType *SrcFunctionType = llvm::cast<llvm::FunctionType>(
150
+ CGM.getTypes ().ConvertType (QualType (SrcProtoType, 0 )));
151
+
152
+ // Construct the Thunk function with the Target (destination) signature
153
+ std::string ThunkName = getThunkName (OriginalFnPtr->getName ().str (),
154
+ DstProtoType, CGM.getContext ());
155
+ llvm::Module &M = CGM.getModule ();
156
+ llvm::Function *Thunk = llvm::Function::Create (
157
+ DstFunctionType, llvm::Function::InternalLinkage, ThunkName, M);
158
+
159
+ // Build the thunk body
160
+ llvm::IRBuilder<> Builder (
161
+ llvm::BasicBlock::Create (M.getContext (), " entry" , Thunk));
162
+
163
+ // Gather the arguments for calling the original function
164
+ std::vector<llvm::Value *> CallArgs;
165
+ unsigned CallN = SrcProtoType->getNumParams ();
166
+
167
+ auto ArgIt = Thunk->arg_begin ();
168
+ for (unsigned i = 0 ; i < CallN && ArgIt != Thunk->arg_end (); ++i, ++ArgIt) {
169
+ llvm::Value *A = &*ArgIt;
170
+ CallArgs.push_back (A);
171
+ }
172
+
173
+ // Create the call to the original function pointer
174
+ llvm::CallInst *Call =
175
+ Builder.CreateCall (SrcFunctionType, OriginalFnPtr, CallArgs);
176
+
177
+ // Handle return type
178
+ llvm::Type *ThunkRetTy = DstFunctionType->getReturnType ();
179
+
180
+ if (ThunkRetTy->isVoidTy ()) {
181
+ Builder.CreateRetVoid ();
182
+ } else {
183
+ llvm::Value *Ret = Call;
184
+ if (Ret->getType () != ThunkRetTy)
185
+ Ret = Builder.CreateBitCast (Ret, ThunkRetTy);
186
+ Builder.CreateRet (Ret);
187
+ }
188
+ LLVM_DEBUG (llvm::dbgs () << " getOrCreateWasmFunctionPointerThunk:"
189
+ << " from " << OriginalFnPtr->getName ().str ()
190
+ << " to " << ThunkName << " \n " );
191
+ return Thunk;
192
+ }
193
+
194
+ private:
195
+ // Build the thunk name: "%s_{type1}_{type2}_..."
196
+ std::string getThunkName (std::string OrigName,
197
+ const FunctionProtoType *DstProto,
198
+ const ASTContext &Ctx) const ;
199
+ std::string sanitizeTypeString (const std::string &typeStr) const ;
200
+ std::string getTypeName (const QualType &qt, const ASTContext &Ctx) const ;
201
+ const DeclRefExpr *findDeclRefExpr (const Expr *E) const ;
202
+ const DeclRefExpr *findDeclRefExprForVarDown (const Stmt *Parent,
203
+ const VarDecl *V,
204
+ ASTContext &Ctx) const ;
205
+ const DeclRefExpr *findDeclRefExprForVarUp (const Expr *E, const VarDecl *V,
206
+ ASTContext &Ctx) const ;
96
207
};
97
208
98
209
// / Classify argument of given type \p Ty.
@@ -173,3 +284,114 @@ CodeGen::createWebAssemblyTargetCodeGenInfo(CodeGenModule &CGM,
173
284
WebAssemblyABIKind K) {
174
285
return std::make_unique<WebAssemblyTargetCodeGenInfo>(CGM.getTypes (), K);
175
286
}
287
+
288
+ // Helper to sanitize type name string for use in function name
289
+ std::string WebAssemblyTargetCodeGenInfo::sanitizeTypeString (
290
+ const std::string &typeStr) const {
291
+ std::string s;
292
+ for (char c : typeStr) {
293
+ if (isalnum (c))
294
+ s += c;
295
+ else if (c == ' ' )
296
+ s += ' _' ;
297
+ else
298
+ s += ' _' ;
299
+ }
300
+ return s;
301
+ }
302
+
303
+ // Helper to generate the type string from QualType
304
+ std::string
305
+ WebAssemblyTargetCodeGenInfo::getTypeName (const QualType &qt,
306
+ const ASTContext &Ctx) const {
307
+ PrintingPolicy Policy (Ctx.getLangOpts ());
308
+ Policy.SuppressTagKeyword = true ;
309
+ Policy.SuppressScope = true ;
310
+ Policy.AnonymousTagLocations = false ;
311
+ std::string typeStr = qt.getAsString (Policy);
312
+ return sanitizeTypeString (typeStr);
313
+ }
314
+
315
+ std::string
316
+ WebAssemblyTargetCodeGenInfo::getThunkName (std::string OrigName,
317
+ const FunctionProtoType *DstProto,
318
+ const ASTContext &Ctx) const {
319
+ std::ostringstream oss;
320
+ oss << " __" << OrigName;
321
+ for (unsigned i = 0 ; i < DstProto->getNumParams (); ++i) {
322
+ oss << " _" << getTypeName (DstProto->getParamType (i), Ctx);
323
+ }
324
+ return oss.str ();
325
+ }
326
+
327
+ // / Recursively find the first DeclRefExpr in an Expr subtree.
328
+ // / Returns nullptr if not found.
329
+ const DeclRefExpr *
330
+ WebAssemblyTargetCodeGenInfo::findDeclRefExpr (const Expr *E) const {
331
+ if (!E)
332
+ return nullptr ;
333
+
334
+ // In case it is a function call, abort
335
+ if (isa<CallExpr>(E))
336
+ return nullptr ;
337
+
338
+ // If this node is a DeclRefExpr, return it.
339
+ if (const auto *DRE = dyn_cast<DeclRefExpr>(E))
340
+ return DRE;
341
+
342
+ // Otherwise, recurse into children.
343
+ for (const Stmt *Child : E->children ()) {
344
+ if (const auto *ChildExpr = dyn_cast_or_null<Expr>(Child)) {
345
+ if (const DeclRefExpr *Found = findDeclRefExpr (ChildExpr))
346
+ return Found;
347
+ }
348
+ }
349
+ return nullptr ;
350
+ }
351
+
352
+ const DeclRefExpr *WebAssemblyTargetCodeGenInfo::findDeclRefExprForVarDown (
353
+ const Stmt *Parent, const VarDecl *V, ASTContext &Ctx) const {
354
+ if (!Parent)
355
+ return nullptr ;
356
+
357
+ // Find down every assignment of V
358
+ // FIXME we need to stop before the expression where V is used
359
+ const BinaryOperator *A = nullptr ;
360
+ for (const Stmt *Child : Parent->children ()) {
361
+ if (const auto *BO = dyn_cast_or_null<BinaryOperator>(Child)) {
362
+ if (!BO->isAssignmentOp ())
363
+ continue ;
364
+ auto *LHS = llvm::dyn_cast<DeclRefExpr>(BO->getLHS ());
365
+ if (LHS && LHS->getDecl () == V) {
366
+ A = BO;
367
+ }
368
+ }
369
+ }
370
+
371
+ // We have an assignment of the Var, recurse in it
372
+ if (A) {
373
+ return getWasmFunctionDeclRefExpr (A->getRHS (), Ctx);
374
+ }
375
+
376
+ return nullptr ;
377
+ }
378
+
379
+ const DeclRefExpr *WebAssemblyTargetCodeGenInfo::findDeclRefExprForVarUp (
380
+ const Expr *E, const VarDecl *V, ASTContext &Ctx) const {
381
+ const clang::Stmt *cur = E;
382
+ while (cur) {
383
+ auto parents = Ctx.getParentMapContext ().getParents (*cur);
384
+ if (parents.empty ())
385
+ break ;
386
+ const clang::Stmt *parentStmt = parents[0 ].get <clang::Stmt>();
387
+ if (!parentStmt)
388
+ break ;
389
+ if (const auto *CS = dyn_cast<clang::CompoundStmt>(parentStmt)) {
390
+ const DeclRefExpr *DRE = findDeclRefExprForVarDown (CS, V, Ctx);
391
+ if (DRE)
392
+ return DRE;
393
+ }
394
+ cur = parentStmt;
395
+ }
396
+ return nullptr ;
397
+ }
0 commit comments