-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[OpenMP][clang] Indirect and Virtual function call mapping from host to device #159857
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6583,6 +6583,26 @@ RValue CodeGenFunction::EmitCall(QualType CalleeType, | |
| Address(Handle, Handle->getType(), CGM.getPointerAlign())); | ||
| Callee.setFunctionPointer(Stub); | ||
| } | ||
|
|
||
| // Check whether the associated CallExpr is in the set OMPTargetCalls. | ||
| // If YES, insert a call to devicertl function __llvm_omp_indirect_call_lookup | ||
| // | ||
| // This is used for the indriect function Case, virtual function case is | ||
| // handled in ItaniumCXXABI.cpp | ||
| if (getLangOpts().OpenMPIsTargetDevice && CGM.OMPTargetCalls.contains(E)) { | ||
| auto *PtrTy = CGM.VoidPtrTy; | ||
| llvm::Type *RtlFnArgs[] = {PtrTy}; | ||
| llvm::FunctionCallee DeviceRtlFn = CGM.CreateRuntimeFunction( | ||
| llvm::FunctionType::get(PtrTy, RtlFnArgs, false), | ||
| "__llvm_omp_indirect_call_lookup"); | ||
|
||
| llvm::Value *Func = Callee.getFunctionPointer(); | ||
| llvm::Type *BackupTy = Func->getType(); | ||
| Func = Builder.CreatePointerBitCastOrAddrSpaceCast(Func, PtrTy); | ||
| Func = EmitRuntimeCall(DeviceRtlFn, {Func}); | ||
| Func = Builder.CreatePointerBitCastOrAddrSpaceCast(Func, BackupTy); | ||
| Callee.setFunctionPointer(Func); | ||
| } | ||
|
|
||
| llvm::CallBase *LocalCallOrInvoke = nullptr; | ||
| RValue Call = EmitCall(FnInfo, Callee, ReturnValue, Args, &LocalCallOrInvoke, | ||
| E == MustTailCall, E->getExprLoc()); | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -24,6 +24,7 @@ | |||||||||||
| #include "clang/AST/OpenMPClause.h" | ||||||||||||
| #include "clang/AST/StmtOpenMP.h" | ||||||||||||
| #include "clang/AST/StmtVisitor.h" | ||||||||||||
| #include "clang/AST/RecursiveASTVisitor.h" | ||||||||||||
| #include "clang/Basic/OpenMPKinds.h" | ||||||||||||
| #include "clang/Basic/SourceManager.h" | ||||||||||||
| #include "clang/CodeGen/ConstantInitBuilder.h" | ||||||||||||
|
|
@@ -1771,12 +1772,126 @@ void CGOpenMPRuntime::emitDeclareTargetFunction(const FunctionDecl *FD, | |||||||||||
| Addr->setVisibility(llvm::GlobalValue::ProtectedVisibility); | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| // Register the indirect Vtable: | ||||||||||||
| // This is similar to OMPTargetGlobalVarEntryIndirect, except that the | ||||||||||||
| // size field refers to the size of memory pointed to, not the size of | ||||||||||||
| // the pointer symbol itself (which is implicitly the size of a pointer). | ||||||||||||
| OMPBuilder.OffloadInfoManager.registerDeviceGlobalVarEntryInfo( | ||||||||||||
| Name, Addr, CGM.GetTargetTypeStoreSize(CGM.VoidPtrTy).getQuantity(), | ||||||||||||
| llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect, | ||||||||||||
| llvm::GlobalValue::WeakODRLinkage); | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| void CGOpenMPRuntime::registerVTableOffloadEntry(llvm::GlobalVariable *VTable, | ||||||||||||
| const VarDecl *VD) { | ||||||||||||
| // TODO: add logic to avoid duplicate vtable registrations per | ||||||||||||
| // translation unit; though for external linkage, this should no | ||||||||||||
| // longer be an issue - or at least we can avoid the issue by | ||||||||||||
| // checking for an existing offloading entry. But, perhaps the | ||||||||||||
| // better approach is to defer emission of the vtables and offload | ||||||||||||
| // entries until later (by tracking a list of items that need to be | ||||||||||||
| // emitted). | ||||||||||||
|
|
||||||||||||
| llvm::OpenMPIRBuilder &OMPBuilder = CGM.getOpenMPRuntime().getOMPBuilder(); | ||||||||||||
|
|
||||||||||||
| // Generate a new externally visible global to point to the | ||||||||||||
| // internally visible vtable. Doing this allows us to keep the | ||||||||||||
| // visibility and linkage of the associated vtable unchanged while | ||||||||||||
| // allowing the runtime to access its value. The externally | ||||||||||||
| // visible global var needs to be emitted with a unique mangled | ||||||||||||
| // name that won't conflict with similarly named (internal) | ||||||||||||
| // vtables in other translation units. | ||||||||||||
|
|
||||||||||||
| // Register vtable with source location of dynamic object in map | ||||||||||||
| // clause. | ||||||||||||
| llvm::TargetRegionEntryInfo EntryInfo = getEntryInfoFromPresumedLoc( | ||||||||||||
| CGM, OMPBuilder, VD->getCanonicalDecl()->getBeginLoc(), | ||||||||||||
| VTable->getName()); | ||||||||||||
|
|
||||||||||||
| llvm::GlobalVariable *Addr = VTable; | ||||||||||||
| size_t PointerSize = CGM.getDataLayout().getPointerSize(); | ||||||||||||
| SmallString<128> AddrName; | ||||||||||||
| OMPBuilder.OffloadInfoManager.getTargetRegionEntryFnName(AddrName, EntryInfo); | ||||||||||||
| AddrName.append("addr"); | ||||||||||||
|
|
||||||||||||
| if (CGM.getLangOpts().OpenMPIsTargetDevice) { | ||||||||||||
| Addr = new llvm::GlobalVariable( | ||||||||||||
| CGM.getModule(), VTable->getType(), | ||||||||||||
| /*isConstant=*/true, llvm::GlobalValue::ExternalLinkage, VTable, | ||||||||||||
| AddrName, | ||||||||||||
| /*InsertBefore*/ nullptr, llvm::GlobalValue::NotThreadLocal, | ||||||||||||
| CGM.getModule().getDataLayout().getDefaultGlobalsAddressSpace()); | ||||||||||||
| Addr->setVisibility(llvm::GlobalValue::ProtectedVisibility); | ||||||||||||
| } | ||||||||||||
| OMPBuilder.OffloadInfoManager.registerDeviceGlobalVarEntryInfo( | ||||||||||||
| AddrName, VTable, | ||||||||||||
| CGM.getDataLayout().getTypeAllocSize(VTable->getInitializer()->getType()), | ||||||||||||
| llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirectVTable, | ||||||||||||
| llvm::GlobalValue::WeakODRLinkage); | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| // Register VTable by scanning through the map clause of OpenMP target region. | ||||||||||||
| void CGOpenMPRuntime::registerVTable(const OMPExecutableDirective &D) { | ||||||||||||
| // Get CXXRecordDecl and VarDecl from Expr. | ||||||||||||
| auto getVTableDecl = [](const Expr *E) { | ||||||||||||
| QualType VDTy = E->getType(); | ||||||||||||
| CXXRecordDecl *CXXRecord = nullptr; | ||||||||||||
| if (const auto *RefType = VDTy->getAs<LValueReferenceType>()) | ||||||||||||
| VDTy = RefType->getPointeeType(); | ||||||||||||
| if (VDTy->isPointerType()) | ||||||||||||
| CXXRecord = VDTy->getPointeeType()->getAsCXXRecordDecl(); | ||||||||||||
| else | ||||||||||||
| CXXRecord = VDTy->getAsCXXRecordDecl(); | ||||||||||||
|
|
||||||||||||
| const VarDecl *VD = nullptr; | ||||||||||||
| if (auto *DRE = dyn_cast<DeclRefExpr>(E)) | ||||||||||||
| VD = cast<VarDecl>(DRE->getDecl()); | ||||||||||||
| return std::pair<CXXRecordDecl *, const VarDecl *>(CXXRecord, VD); | ||||||||||||
| }; | ||||||||||||
|
|
||||||||||||
| // Emit VTable and register the VTable to OpenMP offload entry recursively. | ||||||||||||
| std::function<void(CodeGenModule &, CXXRecordDecl *, const VarDecl *)> | ||||||||||||
| emitAndRegisterVTable = [&emitAndRegisterVTable](CodeGenModule &CGM, | ||||||||||||
| CXXRecordDecl *CXXRecord, | ||||||||||||
| const VarDecl *VD) { | ||||||||||||
| // Register C++ VTable to OpenMP Offload Entry if it's a new | ||||||||||||
| // CXXRecordDecl. | ||||||||||||
| if (CXXRecord && CXXRecord->isDynamicClass() && | ||||||||||||
| CGM.getOpenMPRuntime().VTableDeclMap.find(CXXRecord) == | ||||||||||||
| CGM.getOpenMPRuntime().VTableDeclMap.end()) { | ||||||||||||
| CGM.getOpenMPRuntime().VTableDeclMap.try_emplace(CXXRecord, VD); | ||||||||||||
| CGM.EmitVTable(CXXRecord); | ||||||||||||
| auto VTables = CGM.getVTables(); | ||||||||||||
| auto *VTablesAddr = VTables.GetAddrOfVTable(CXXRecord); | ||||||||||||
| if (VTablesAddr) { | ||||||||||||
| CGM.getOpenMPRuntime().registerVTableOffloadEntry(VTablesAddr, VD); | ||||||||||||
| } | ||||||||||||
| // Emit VTable for all the fields containing dynamic CXXRecord | ||||||||||||
| for (const FieldDecl *Field : CXXRecord->fields()) { | ||||||||||||
| if (CXXRecordDecl *RecordDecl = | ||||||||||||
| Field->getType()->getAsCXXRecordDecl()) { | ||||||||||||
| emitAndRegisterVTable(CGM, RecordDecl, VD); | ||||||||||||
| } | ||||||||||||
| } | ||||||||||||
| // Emit VTable for all dynamic parent class | ||||||||||||
| for (CXXBaseSpecifier &Base : CXXRecord->bases()) { | ||||||||||||
| if (CXXRecordDecl *BaseDecl = | ||||||||||||
| Base.getType()->getAsCXXRecordDecl()) { | ||||||||||||
| emitAndRegisterVTable(CGM, BaseDecl, VD); | ||||||||||||
| } | ||||||||||||
| } | ||||||||||||
| } | ||||||||||||
| }; | ||||||||||||
|
|
||||||||||||
| // Collect VTable from OpenMP map clause. | ||||||||||||
| for (const auto *C : D.getClausesOfKind<OMPMapClause>()) { | ||||||||||||
| for (const auto *E : C->varlist()) { | ||||||||||||
| auto DeclPair = getVTableDecl(E); | ||||||||||||
| emitAndRegisterVTable(CGM, DeclPair.first, DeclPair.second); | ||||||||||||
| } | ||||||||||||
| } | ||||||||||||
| } | ||||||||||||
|
|
||||||||||||
| Address CGOpenMPRuntime::getAddrOfArtificialThreadPrivate(CodeGenFunction &CGF, | ||||||||||||
| QualType VarType, | ||||||||||||
| StringRef Name) { | ||||||||||||
|
|
@@ -6221,6 +6336,25 @@ void CGOpenMPRuntime::emitTargetOutlinedFunctionHelper( | |||||||||||
| llvm::Function *&OutlinedFn, llvm::Constant *&OutlinedFnID, | ||||||||||||
| bool IsOffloadEntry, const RegionCodeGenTy &CodeGen) { | ||||||||||||
|
|
||||||||||||
| class OMPTargetCallCollector | ||||||||||||
| : public RecursiveASTVisitor<OMPTargetCallCollector> { | ||||||||||||
| public: | ||||||||||||
| OMPTargetCallCollector(CodeGenFunction &CGF, | ||||||||||||
| llvm::SmallPtrSetImpl<const CallExpr *> &TargetCalls) | ||||||||||||
| : CGF(CGF), TargetCalls(TargetCalls) {} | ||||||||||||
|
|
||||||||||||
| bool VisitCallExpr(CallExpr *CE) { | ||||||||||||
| if (!CE->getDirectCallee()) { | ||||||||||||
| TargetCalls.insert(CE); | ||||||||||||
| } | ||||||||||||
|
||||||||||||
| if (!CE->getDirectCallee()) { | |
| TargetCalls.insert(CE); | |
| } | |
| if (!CE->getDirectCallee()) | |
| TargetCalls.insert(CE); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Changed in 11b1f08
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2261,6 +2261,24 @@ CGCallee ItaniumCXXABI::getVirtualFunctionPointer(CodeGenFunction &CGF, | |
| llvm::Type *PtrTy = CGM.GlobalsInt8PtrTy; | ||
| auto *MethodDecl = cast<CXXMethodDecl>(GD.getDecl()); | ||
| llvm::Value *VTable = CGF.GetVTablePtr(This, PtrTy, MethodDecl->getParent()); | ||
| /* | ||
| * For the translate of virtual functions we need to map the (potential) host vtable | ||
| * to the device vtable. This is done by calling the runtime function | ||
| * __llvm_omp_indirect_call_lookup. | ||
| */ | ||
|
||
| if (CGM.getLangOpts().OpenMPIsTargetDevice) { | ||
| auto *NewPtrTy = CGM.VoidPtrTy; | ||
| llvm::Type *RtlFnArgs[] = {NewPtrTy}; | ||
| llvm::FunctionCallee DeviceRtlFn = CGM.CreateRuntimeFunction( | ||
| llvm::FunctionType::get(NewPtrTy, RtlFnArgs, false), | ||
| "__llvm_omp_indirect_call_lookup"); | ||
| auto *BackupTy = VTable->getType(); | ||
| // Need to convert to generic address space | ||
| VTable = CGF.Builder.CreatePointerBitCastOrAddrSpaceCast(VTable, NewPtrTy); | ||
| VTable = CGF.EmitRuntimeCall(DeviceRtlFn, {VTable}); | ||
| // convert to original address space | ||
| VTable = CGF.Builder.CreatePointerBitCastOrAddrSpaceCast(VTable, BackupTy); | ||
| } | ||
|
|
||
| uint64_t VTableIndex = CGM.getItaniumVTableContext().getMethodVTableIndex(GD); | ||
| llvm::Value *VFunc, *VTableSlotPtr = nullptr; | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Comment could be more straightforward, just
insert functoin pointer lookup if this is a target callor something.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Replaced comment with suggested comment in 11b1f08.