Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
MLIRIR
MLIRLinalgDialect
MLIRMemRefDialect
MLIRMemRefTransforms
MLIRMemRefUtils
MLIRSCFDialect
MLIRSideEffectInterfaces
Expand Down
2 changes: 0 additions & 2 deletions offload/include/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,7 @@
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"

#include "GlobalHandler.h"
#include "PluginInterface.h"

using GenericPluginTy = llvm::omp::target::plugin::GenericPluginTy;

// Forward declarations.
Expand Down
84 changes: 1 addition & 83 deletions offload/libomptarget/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@
using namespace llvm::omp::target::ompt;
#endif

using namespace llvm::omp::target::plugin;

int HostDataToTargetTy::addEventIfNecessary(DeviceTy &Device,
AsyncInfoTy &AsyncInfo) const {
// First, check if the user disabled atomic map transfer/malloc/dealloc.
Expand Down Expand Up @@ -99,94 +97,14 @@ llvm::Error DeviceTy::init() {
return llvm::Error::success();
}

// Extract the mapping of host function pointers to device function pointers
// from the entry table. Functions marked as 'indirect' in OpenMP will have
// offloading entries generated for them which map the host's function pointer
// to a global containing the corresponding function pointer on the device.
static llvm::Expected<std::pair<void *, uint64_t>>
setupIndirectCallTable(DeviceTy &Device, __tgt_device_image *Image,
__tgt_device_binary Binary) {
AsyncInfoTy AsyncInfo(Device);
llvm::ArrayRef<llvm::offloading::EntryTy> Entries(Image->EntriesBegin,
Image->EntriesEnd);
llvm::SmallVector<std::pair<void *, void *>> IndirectCallTable;
for (const auto &Entry : Entries) {
if (Entry.Kind != llvm::object::OffloadKind::OFK_OpenMP ||
Entry.Size == 0 || !(Entry.Flags & OMP_DECLARE_TARGET_INDIRECT))
continue;

assert(Entry.Size == sizeof(void *) && "Global not a function pointer?");
auto &[HstPtr, DevPtr] = IndirectCallTable.emplace_back();

void *Ptr;
if (Device.RTL->get_global(Binary, Entry.Size, Entry.SymbolName, &Ptr))
return error::createOffloadError(error::ErrorCode::INVALID_BINARY,
"failed to load %s", Entry.SymbolName);

HstPtr = Entry.Address;
if (Device.retrieveData(&DevPtr, Ptr, Entry.Size, AsyncInfo))
return error::createOffloadError(error::ErrorCode::INVALID_BINARY,
"failed to load %s", Entry.SymbolName);
}

// If we do not have any indirect globals we exit early.
if (IndirectCallTable.empty())
return std::pair{nullptr, 0};

// Sort the array to allow for more efficient lookup of device pointers.
llvm::sort(IndirectCallTable,
[](const auto &x, const auto &y) { return x.first < y.first; });

uint64_t TableSize =
IndirectCallTable.size() * sizeof(std::pair<void *, void *>);
void *DevicePtr = Device.allocData(TableSize, nullptr, TARGET_ALLOC_DEVICE);
if (Device.submitData(DevicePtr, IndirectCallTable.data(), TableSize,
AsyncInfo))
return error::createOffloadError(error::ErrorCode::INVALID_BINARY,
"failed to copy data");
return std::pair<void *, uint64_t>(DevicePtr, IndirectCallTable.size());
}

// Load binary to device and perform global initialization if needed.
// Load binary to device.
llvm::Expected<__tgt_device_binary>
DeviceTy::loadBinary(__tgt_device_image *Img) {
__tgt_device_binary Binary;

if (RTL->load_binary(RTLDeviceID, Img, &Binary) != OFFLOAD_SUCCESS)
return error::createOffloadError(error::ErrorCode::INVALID_BINARY,
"failed to load binary %p", Img);

// This symbol is optional.
void *DeviceEnvironmentPtr;
if (RTL->get_global(Binary, sizeof(DeviceEnvironmentTy),
"__omp_rtl_device_environment", &DeviceEnvironmentPtr))
return Binary;

// Obtain a table mapping host function pointers to device function pointers.
auto CallTablePairOrErr = setupIndirectCallTable(*this, Img, Binary);
if (!CallTablePairOrErr)
return CallTablePairOrErr.takeError();

GenericDeviceTy &GenericDevice = RTL->getDevice(RTLDeviceID);
DeviceEnvironmentTy DeviceEnvironment;
DeviceEnvironment.DeviceDebugKind = GenericDevice.getDebugKind();
DeviceEnvironment.NumDevices = RTL->getNumDevices();
// TODO: The device ID used here is not the real device ID used by OpenMP.
DeviceEnvironment.DeviceNum = RTLDeviceID;
DeviceEnvironment.DynamicMemSize = GenericDevice.getDynamicMemorySize();
DeviceEnvironment.ClockFrequency = GenericDevice.getClockFrequency();
DeviceEnvironment.IndirectCallTable =
reinterpret_cast<uintptr_t>(CallTablePairOrErr->first);
DeviceEnvironment.IndirectCallTableSize = CallTablePairOrErr->second;
DeviceEnvironment.HardwareParallelism =
GenericDevice.getHardwareParallelism();

AsyncInfoTy AsyncInfo(*this);
if (submitData(DeviceEnvironmentPtr, &DeviceEnvironment,
sizeof(DeviceEnvironment), AsyncInfo))
return error::createOffloadError(error::ErrorCode::INVALID_BINARY,
"failed to copy data");

return Binary;
}

Expand Down
13 changes: 11 additions & 2 deletions offload/plugins-nextgen/common/include/PluginInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -815,6 +815,11 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
Error unloadBinary(DeviceImageTy *Image);
virtual Error unloadBinaryImpl(DeviceImageTy *Image) = 0;

/// Setup the device environment if needed. Notice this setup may not be run
/// on some plugins. By default, it will be executed, but plugins can change
/// this behavior by overriding the shouldSetupDeviceEnvironment function.
Error setupDeviceEnvironment(GenericPluginTy &Plugin, DeviceImageTy &Image);

/// Setup the global device memory pool, if the plugin requires one.
Error setupDeviceMemoryPool(GenericPluginTy &Plugin, DeviceImageTy &Image,
uint64_t PoolSize);
Expand Down Expand Up @@ -1014,7 +1019,6 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
uint32_t getDefaultNumBlocks() const {
return GridValues.GV_Default_Num_Teams;
}
uint32_t getDebugKind() const { return OMPX_DebugKind; }
uint32_t getDynamicMemorySize() const { return OMPX_SharedMemorySize; }
virtual uint64_t getClockFrequency() const { return CLOCKS_PER_SEC; }

Expand Down Expand Up @@ -1155,6 +1159,11 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
virtual Error getDeviceHeapSize(uint64_t &V) = 0;
virtual Error setDeviceHeapSize(uint64_t V) = 0;

/// Indicate whether the device should setup the device environment. Notice
/// that returning false in this function will change the behavior of the
/// setupDeviceEnvironment() function.
virtual bool shouldSetupDeviceEnvironment() const { return true; }

/// Indicate whether the device should setup the global device memory pool. If
/// false is return the value on the device will be uninitialized.
virtual bool shouldSetupDeviceMemoryPool() const { return true; }
Expand Down Expand Up @@ -1210,7 +1219,7 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
enum class PeerAccessState : uint8_t { AVAILABLE, UNAVAILABLE, PENDING };

/// Array of peer access states with the rest of devices. This means that if
/// the device I has a matrix PeerAccesses with PeerAccesses == AVAILABLE,
/// the device I has a matrix PeerAccesses with PeerAccesses[J] == AVAILABLE,
/// the device I can access device J's memory directly. However, notice this
/// does not mean that device J can access device I's memory directly.
llvm::SmallVector<PeerAccessState> PeerAccesses;
Expand Down
92 changes: 91 additions & 1 deletion offload/plugins-nextgen/common/src/PluginInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,54 @@ struct RecordReplayTy {
};
} // namespace llvm::omp::target::plugin

// Extract the mapping of host function pointers to device function pointers
// from the entry table. Functions marked as 'indirect' in OpenMP will have
// offloading entries generated for them which map the host's function pointer
// to a global containing the corresponding function pointer on the device.
static Expected<std::pair<void *, uint64_t>>
setupIndirectCallTable(GenericPluginTy &Plugin, GenericDeviceTy &Device,
DeviceImageTy &Image) {
GenericGlobalHandlerTy &Handler = Plugin.getGlobalHandler();

llvm::ArrayRef<llvm::offloading::EntryTy> Entries(
Image.getTgtImage()->EntriesBegin, Image.getTgtImage()->EntriesEnd);
llvm::SmallVector<std::pair<void *, void *>> IndirectCallTable;
for (const auto &Entry : Entries) {
if (Entry.Kind != object::OffloadKind::OFK_OpenMP || Entry.Size == 0 ||
!(Entry.Flags & OMP_DECLARE_TARGET_INDIRECT))
continue;

assert(Entry.Size == sizeof(void *) && "Global not a function pointer?");
auto &[HstPtr, DevPtr] = IndirectCallTable.emplace_back();

GlobalTy DeviceGlobal(Entry.SymbolName, Entry.Size);
if (auto Err =
Handler.getGlobalMetadataFromDevice(Device, Image, DeviceGlobal))
return std::move(Err);

HstPtr = Entry.Address;
if (auto Err = Device.dataRetrieve(&DevPtr, DeviceGlobal.getPtr(),
Entry.Size, nullptr))
return std::move(Err);
}

// If we do not have any indirect globals we exit early.
if (IndirectCallTable.empty())
return std::pair{nullptr, 0};

// Sort the array to allow for more efficient lookup of device pointers.
llvm::sort(IndirectCallTable,
[](const auto &x, const auto &y) { return x.first < y.first; });

uint64_t TableSize =
IndirectCallTable.size() * sizeof(std::pair<void *, void *>);
void *DevicePtr = Device.allocate(TableSize, nullptr, TARGET_ALLOC_DEVICE);
if (auto Err = Device.dataSubmit(DevicePtr, IndirectCallTable.data(),
TableSize, nullptr))
return std::move(Err);
return std::pair<void *, uint64_t>(DevicePtr, IndirectCallTable.size());
}

AsyncInfoWrapperTy::AsyncInfoWrapperTy(GenericDeviceTy &Device,
__tgt_async_info *AsyncInfoPtr)
: Device(Device),
Expand Down Expand Up @@ -881,6 +929,10 @@ Expected<DeviceImageTy *> GenericDeviceTy::loadBinary(GenericPluginTy &Plugin,
// Add the image to list.
LoadedImages.push_back(Image);

// Setup the device environment if needed.
if (auto Err = setupDeviceEnvironment(Plugin, *Image))
return std::move(Err);

// Setup the global device memory pool if needed.
if (!Plugin.getRecordReplay().isReplaying() &&
shouldSetupDeviceMemoryPool()) {
Expand Down Expand Up @@ -916,6 +968,43 @@ Expected<DeviceImageTy *> GenericDeviceTy::loadBinary(GenericPluginTy &Plugin,
return Image;
}

Error GenericDeviceTy::setupDeviceEnvironment(GenericPluginTy &Plugin,
DeviceImageTy &Image) {
// There are some plugins that do not need this step.
if (!shouldSetupDeviceEnvironment())
return Plugin::success();

// Obtain a table mapping host function pointers to device function pointers.
auto CallTablePairOrErr = setupIndirectCallTable(Plugin, *this, Image);
if (!CallTablePairOrErr)
return CallTablePairOrErr.takeError();

DeviceEnvironmentTy DeviceEnvironment;
DeviceEnvironment.DeviceDebugKind = OMPX_DebugKind;
DeviceEnvironment.NumDevices = Plugin.getNumDevices();
// TODO: The device ID used here is not the real device ID used by OpenMP.
DeviceEnvironment.DeviceNum = DeviceId;
DeviceEnvironment.DynamicMemSize = OMPX_SharedMemorySize;
DeviceEnvironment.ClockFrequency = getClockFrequency();
DeviceEnvironment.IndirectCallTable =
reinterpret_cast<uintptr_t>(CallTablePairOrErr->first);
DeviceEnvironment.IndirectCallTableSize = CallTablePairOrErr->second;
DeviceEnvironment.HardwareParallelism = getHardwareParallelism();

// Create the metainfo of the device environment global.
GlobalTy DevEnvGlobal("__omp_rtl_device_environment",
sizeof(DeviceEnvironmentTy), &DeviceEnvironment);

// Write device environment values to the device.
GenericGlobalHandlerTy &GHandler = Plugin.getGlobalHandler();
if (auto Err = GHandler.writeGlobalToDevice(*this, Image, DevEnvGlobal)) {
DP("Missing symbol %s, continue execution anyway.\n",
DevEnvGlobal.getName().data());
consumeError(std::move(Err));
}
return Plugin::success();
}

Error GenericDeviceTy::setupDeviceMemoryPool(GenericPluginTy &Plugin,
DeviceImageTy &Image,
uint64_t PoolSize) {
Expand Down Expand Up @@ -2158,7 +2247,8 @@ int32_t GenericPluginTy::get_global(__tgt_device_binary Binary, uint64_t Size,
GenericGlobalHandlerTy &GHandler = getGlobalHandler();
if (auto Err =
GHandler.getGlobalMetadataFromDevice(Device, Image, DeviceGlobal)) {
consumeError(std::move(Err));
REPORT("Failure to look up global address: %s\n",
toString(std::move(Err)).data());
return OFFLOAD_FAIL;
}

Expand Down
1 change: 1 addition & 0 deletions offload/plugins-nextgen/host/src/rtl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,7 @@ struct GenELF64DeviceTy : public GenericDeviceTy {
}

/// This plugin should not setup the device environment or memory pool.
virtual bool shouldSetupDeviceEnvironment() const override { return false; };
virtual bool shouldSetupDeviceMemoryPool() const override { return false; };

/// Getters and setters for stack size and heap size not relevant.
Expand Down
Loading