Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions clang/lib/CodeGen/CGExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6583,6 +6583,25 @@ RValue CodeGenFunction::EmitCall(QualType CalleeType,
Address(Handle, Handle->getType(), CGM.getPointerAlign()));
Callee.setFunctionPointer(Stub);
}

// Insert function pointer lookup if this is a target call
//
// This is used for the indirect function case, virtual function case is
// handled in ItaniumCXXABI.cpp
if (getLangOpts().OpenMPIsTargetDevice && CGM.OMPTargetCalls.contains(E)) {
auto *PtrTy = CGM.VoidPtrTy;
llvm::Type *RtlFnArgs[] = {PtrTy};
llvm::FunctionCallee DeviceRtlFn = CGM.CreateRuntimeFunction(
llvm::FunctionType::get(PtrTy, RtlFnArgs, false),
"__kmpc_omp_indirect_call_lookup");
llvm::Value *Func = Callee.getFunctionPointer();
llvm::Type *BackupTy = Func->getType();
Func = Builder.CreatePointerBitCastOrAddrSpaceCast(Func, PtrTy);
Func = EmitRuntimeCall(DeviceRtlFn, {Func});
Func = Builder.CreatePointerBitCastOrAddrSpaceCast(Func, BackupTy);
Callee.setFunctionPointer(Func);
}

llvm::CallBase *LocalCallOrInvoke = nullptr;
RValue Call = EmitCall(FnInfo, Callee, ReturnValue, Args, &LocalCallOrInvoke,
E == MustTailCall, E->getExprLoc());
Expand Down
157 changes: 157 additions & 0 deletions clang/lib/CodeGen/CGOpenMPRuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "clang/AST/OpenMPClause.h"
#include "clang/AST/StmtOpenMP.h"
#include "clang/AST/StmtVisitor.h"
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/Basic/OpenMPKinds.h"
#include "clang/Basic/SourceManager.h"
#include "clang/CodeGen/ConstantInitBuilder.h"
Expand Down Expand Up @@ -1771,12 +1772,126 @@ void CGOpenMPRuntime::emitDeclareTargetFunction(const FunctionDecl *FD,
Addr->setVisibility(llvm::GlobalValue::ProtectedVisibility);
}

// Register the indirect Vtable:
// This is similar to OMPTargetGlobalVarEntryIndirect, except that the
// size field refers to the size of memory pointed to, not the size of
// the pointer symbol itself (which is implicitly the size of a pointer).
OMPBuilder.OffloadInfoManager.registerDeviceGlobalVarEntryInfo(
Name, Addr, CGM.GetTargetTypeStoreSize(CGM.VoidPtrTy).getQuantity(),
llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirect,
llvm::GlobalValue::WeakODRLinkage);
}

void CGOpenMPRuntime::registerVTableOffloadEntry(llvm::GlobalVariable *VTable,
const VarDecl *VD) {
// TODO: add logic to avoid duplicate vtable registrations per
// translation unit; though for external linkage, this should no
// longer be an issue - or at least we can avoid the issue by
// checking for an existing offloading entry. But, perhaps the
// better approach is to defer emission of the vtables and offload
// entries until later (by tracking a list of items that need to be
// emitted).

llvm::OpenMPIRBuilder &OMPBuilder = CGM.getOpenMPRuntime().getOMPBuilder();

// Generate a new externally visible global to point to the
// internally visible vtable. Doing this allows us to keep the
// visibility and linkage of the associated vtable unchanged while
// allowing the runtime to access its value. The externally
// visible global var needs to be emitted with a unique mangled
// name that won't conflict with similarly named (internal)
// vtables in other translation units.

// Register vtable with source location of dynamic object in map
// clause.
llvm::TargetRegionEntryInfo EntryInfo = getEntryInfoFromPresumedLoc(
CGM, OMPBuilder, VD->getCanonicalDecl()->getBeginLoc(),
VTable->getName());

llvm::GlobalVariable *Addr = VTable;
size_t PointerSize = CGM.getDataLayout().getPointerSize();
SmallString<128> AddrName;
OMPBuilder.OffloadInfoManager.getTargetRegionEntryFnName(AddrName, EntryInfo);
AddrName.append("addr");

if (CGM.getLangOpts().OpenMPIsTargetDevice) {
Addr = new llvm::GlobalVariable(
CGM.getModule(), VTable->getType(),
/*isConstant=*/true, llvm::GlobalValue::ExternalLinkage, VTable,
AddrName,
/*InsertBefore*/ nullptr, llvm::GlobalValue::NotThreadLocal,
CGM.getModule().getDataLayout().getDefaultGlobalsAddressSpace());
Addr->setVisibility(llvm::GlobalValue::ProtectedVisibility);
}
OMPBuilder.OffloadInfoManager.registerDeviceGlobalVarEntryInfo(
AddrName, VTable,
CGM.getDataLayout().getTypeAllocSize(VTable->getInitializer()->getType()),
llvm::OffloadEntriesInfoManager::OMPTargetGlobalVarEntryIndirectVTable,
llvm::GlobalValue::WeakODRLinkage);
}

// Register VTable by scanning through the map clause of OpenMP target region.
void CGOpenMPRuntime::registerVTable(const OMPExecutableDirective &D) {
// Get CXXRecordDecl and VarDecl from Expr.
auto getVTableDecl = [](const Expr *E) {
QualType VDTy = E->getType();
CXXRecordDecl *CXXRecord = nullptr;
if (const auto *RefType = VDTy->getAs<LValueReferenceType>())
VDTy = RefType->getPointeeType();
if (VDTy->isPointerType())
CXXRecord = VDTy->getPointeeType()->getAsCXXRecordDecl();
else
CXXRecord = VDTy->getAsCXXRecordDecl();

const VarDecl *VD = nullptr;
if (auto *DRE = dyn_cast<DeclRefExpr>(E))
VD = cast<VarDecl>(DRE->getDecl());
return std::pair<CXXRecordDecl *, const VarDecl *>(CXXRecord, VD);
};

// Emit VTable and register the VTable to OpenMP offload entry recursively.
std::function<void(CodeGenModule &, CXXRecordDecl *, const VarDecl *)>
emitAndRegisterVTable = [&emitAndRegisterVTable](CodeGenModule &CGM,
CXXRecordDecl *CXXRecord,
const VarDecl *VD) {
// Register C++ VTable to OpenMP Offload Entry if it's a new
// CXXRecordDecl.
if (CXXRecord && CXXRecord->isDynamicClass() &&
CGM.getOpenMPRuntime().VTableDeclMap.find(CXXRecord) ==
CGM.getOpenMPRuntime().VTableDeclMap.end()) {
CGM.getOpenMPRuntime().VTableDeclMap.try_emplace(CXXRecord, VD);
CGM.EmitVTable(CXXRecord);
auto VTables = CGM.getVTables();
auto *VTablesAddr = VTables.GetAddrOfVTable(CXXRecord);
if (VTablesAddr) {
CGM.getOpenMPRuntime().registerVTableOffloadEntry(VTablesAddr, VD);
}
// Emit VTable for all the fields containing dynamic CXXRecord
for (const FieldDecl *Field : CXXRecord->fields()) {
if (CXXRecordDecl *RecordDecl =
Field->getType()->getAsCXXRecordDecl()) {
emitAndRegisterVTable(CGM, RecordDecl, VD);
}
}
// Emit VTable for all dynamic parent class
for (CXXBaseSpecifier &Base : CXXRecord->bases()) {
if (CXXRecordDecl *BaseDecl =
Base.getType()->getAsCXXRecordDecl()) {
emitAndRegisterVTable(CGM, BaseDecl, VD);
}
}
}
};

// Collect VTable from OpenMP map clause.
for (const auto *C : D.getClausesOfKind<OMPMapClause>()) {
for (const auto *E : C->varlist()) {
auto DeclPair = getVTableDecl(E);
emitAndRegisterVTable(CGM, DeclPair.first, DeclPair.second);
}
}
}

Address CGOpenMPRuntime::getAddrOfArtificialThreadPrivate(CodeGenFunction &CGF,
QualType VarType,
StringRef Name) {
Expand Down Expand Up @@ -6221,6 +6336,24 @@ void CGOpenMPRuntime::emitTargetOutlinedFunctionHelper(
llvm::Function *&OutlinedFn, llvm::Constant *&OutlinedFnID,
bool IsOffloadEntry, const RegionCodeGenTy &CodeGen) {

class OMPTargetCallCollector
: public RecursiveASTVisitor<OMPTargetCallCollector> {
public:
OMPTargetCallCollector(CodeGenFunction &CGF,
llvm::SmallPtrSetImpl<const CallExpr *> &TargetCalls)
: CGF(CGF), TargetCalls(TargetCalls) {}

bool VisitCallExpr(CallExpr *CE) {
if (!CE->getDirectCallee())
TargetCalls.insert(CE);
return true;
}

private:
CodeGenFunction &CGF;
llvm::SmallPtrSetImpl<const CallExpr *> &TargetCalls;
};

llvm::TargetRegionEntryInfo EntryInfo =
getEntryInfoFromPresumedLoc(CGM, OMPBuilder, D.getBeginLoc(), ParentName);

Expand All @@ -6229,6 +6362,16 @@ void CGOpenMPRuntime::emitTargetOutlinedFunctionHelper(
[&CGF, &D, &CodeGen](StringRef EntryFnName) {
const CapturedStmt &CS = *D.getCapturedStmt(OMPD_target);

// Search Clang AST within "omp target" region for CallExprs.
// Store them in the set OMPTargetCalls (kept by CodeGenModule).
// This is used for the translation of indirect function calls.
const auto &LangOpts = CGF.getLangOpts();
if (LangOpts.OpenMPIsTargetDevice) {
// Search AST for target "CallExpr"s of "OMPTargetAutoLookup".
OMPTargetCallCollector Visitor(CGF, CGF.CGM.OMPTargetCalls);
Visitor.TraverseStmt(const_cast<Stmt*>(CS.getCapturedStmt()));
}

CGOpenMPTargetRegionInfo CGInfo(CS, CodeGen, EntryFnName);
CodeGenFunction::CGCapturedStmtRAII CapInfoRAII(CGF, &CGInfo);
return CGF.GenerateOpenMPCapturedStmtFunction(CS, D);
Expand All @@ -6249,6 +6392,7 @@ void CGOpenMPRuntime::emitTargetOutlinedFunctionHelper(
CGM.handleAMDGPUWavesPerEUAttr(OutlinedFn, Attr);
}
}
registerVTable(D);
}

/// Checks if the expression is constant or does not have non-trivial function
Expand Down Expand Up @@ -9955,6 +10099,19 @@ void CGOpenMPRuntime::scanForTargetRegionsFunctions(const Stmt *S,
if (!S)
return;

// Register vtable from device for target data and target directives.
// Add this block here since scanForTargetRegionsFunctions ignores
// target data by checking if S is a executable directive (target).
if (isa<OMPExecutableDirective>(S) &&
isOpenMPTargetDataManagementDirective(
cast<OMPExecutableDirective>(S)->getDirectiveKind())) {
auto &E = *cast<OMPExecutableDirective>(S);
// Don't need to check if it's device compile
// since scanForTargetRegionsFunctions currently only called
// in device compilation.
registerVTable(E);
}

// Codegen OMP target directives that offload compute to the device.
bool RequiresDeviceCodegen =
isa<OMPExecutableDirective>(S) &&
Expand Down
13 changes: 13 additions & 0 deletions clang/lib/CodeGen/CGOpenMPRuntime.h
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,9 @@ class CGOpenMPRuntime {
LValue PosLVal, const OMPTaskDataTy::DependData &Data,
Address DependenciesArray);

/// Keep track of VTable Declarations so we don't register duplicate VTable.
llvm::DenseMap<CXXRecordDecl*, const VarDecl*> VTableDeclMap;

public:
explicit CGOpenMPRuntime(CodeGenModule &CGM);
virtual ~CGOpenMPRuntime() {}
Expand Down Expand Up @@ -1111,6 +1114,16 @@ class CGOpenMPRuntime {
virtual void emitDeclareTargetFunction(const FunctionDecl *FD,
llvm::GlobalValue *GV);

/// Register VTable to OpenMP offload entry.
/// \param VTable VTable of the C++ class.
/// \param RD C++ class decl.
virtual void registerVTableOffloadEntry(llvm::GlobalVariable *VTable,
const VarDecl *VD);
/// Emit code for registering vtable by scanning through map clause
/// in OpenMP target region.
/// \param D OpenMP target directive.
virtual void registerVTable(const OMPExecutableDirective &D);

/// Creates artificial threadprivate variable with name \p Name and type \p
/// VarType.
/// \param VarType Type of the artificial threadprivate variable.
Expand Down
4 changes: 4 additions & 0 deletions clang/lib/CodeGen/CGStmtOpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7617,6 +7617,10 @@ void CodeGenFunction::EmitOMPUseDeviceAddrClause(
// Generate the instructions for '#pragma omp target data' directive.
void CodeGenFunction::EmitOMPTargetDataDirective(
const OMPTargetDataDirective &S) {
// Emit vtable only from host for target data directive.
if (!CGM.getLangOpts().OpenMPIsTargetDevice) {
CGM.getOpenMPRuntime().registerVTable(S);
}
CGOpenMPRuntime::TargetDataInfo Info(/*RequiresDevicePointerInfo=*/true,
/*SeparateBeginEndCalls=*/true);

Expand Down
6 changes: 6 additions & 0 deletions clang/lib/CodeGen/CGVTables.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ llvm::Constant *CodeGenModule::GetAddrOfThunk(StringRef Name, llvm::Type *FnTy,
/*DontDefer=*/true, /*IsThunk=*/true);
}

llvm::GlobalVariable *CodeGenVTables::GetAddrOfVTable(const CXXRecordDecl *RD) {
llvm::GlobalVariable *VTable =
CGM.getCXXABI().getAddrOfVTable(RD, CharUnits());
return VTable;
}

static void setThunkProperties(CodeGenModule &CGM, const ThunkInfo &Thunk,
llvm::Function *ThunkFn, bool ForVTable,
GlobalDecl GD) {
Expand Down
4 changes: 4 additions & 0 deletions clang/lib/CodeGen/CGVTables.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,10 @@ class CodeGenVTables {
llvm::GlobalVariable::LinkageTypes Linkage,
const CXXRecordDecl *RD);

/// GetAddrOfVTable - Get the address of the VTable for the given record
/// decl.
llvm::GlobalVariable *GetAddrOfVTable(const CXXRecordDecl *RD);

/// EmitThunks - Emit the associated thunks for the given global decl.
void EmitThunks(GlobalDecl GD);

Expand Down
3 changes: 3 additions & 0 deletions clang/lib/CodeGen/CodeGenModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -754,6 +754,9 @@ class CodeGenModule : public CodeGenTypeCache {
// i32 @__isPlatformVersionAtLeast(i32, i32, i32, i32)
llvm::FunctionCallee IsPlatformVersionAtLeastFn = nullptr;

// Store indirect CallExprs that are within an omp target region
llvm::SmallPtrSet<const CallExpr *, 16> OMPTargetCalls;

InstrProfStats &getPGOStats() { return PGOStats; }
llvm::IndexedInstrProfReader *getPGOReader() const { return PGOReader.get(); }

Expand Down
17 changes: 17 additions & 0 deletions clang/lib/CodeGen/ItaniumCXXABI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2261,6 +2261,23 @@ CGCallee ItaniumCXXABI::getVirtualFunctionPointer(CodeGenFunction &CGF,
llvm::Type *PtrTy = CGM.GlobalsInt8PtrTy;
auto *MethodDecl = cast<CXXMethodDecl>(GD.getDecl());
llvm::Value *VTable = CGF.GetVTablePtr(This, PtrTy, MethodDecl->getParent());

// For the translation of virtual functions, we need to map the (potential) host
// vtable to the device vtable. This is done by calling the runtime function
// __kmpc_omp_indirect_call_lookup.
if (CGM.getLangOpts().OpenMPIsTargetDevice) {
auto *NewPtrTy = CGM.VoidPtrTy;
llvm::Type *RtlFnArgs[] = {NewPtrTy};
llvm::FunctionCallee DeviceRtlFn = CGM.CreateRuntimeFunction(
llvm::FunctionType::get(NewPtrTy, RtlFnArgs, false),
"__kmpc_omp_indirect_call_lookup");
auto *BackupTy = VTable->getType();
// Need to convert to generic address space
VTable = CGF.Builder.CreatePointerBitCastOrAddrSpaceCast(VTable, NewPtrTy);
VTable = CGF.EmitRuntimeCall(DeviceRtlFn, {VTable});
// convert to original address space
VTable = CGF.Builder.CreatePointerBitCastOrAddrSpaceCast(VTable, BackupTy);
}

uint64_t VTableIndex = CGM.getItaniumVTableContext().getMethodVTableIndex(GD);
llvm::Value *VFunc, *VTableSlotPtr = nullptr;
Expand Down
Loading