Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 33 additions & 42 deletions unified-runtime/source/adapters/opencl/adapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@
#include <dlfcn.h>
#endif

// There can only be one OpenCL adapter alive at a time.
// If it is alive (more get/retains than releases called), this is a pointer to
// it.
static ur_adapter_handle_t liveAdapter = nullptr;

ur_adapter_handle_t_::ur_adapter_handle_t_() {
#ifdef _MSC_VER

Expand All @@ -42,45 +47,38 @@ ur_adapter_handle_t_::ur_adapter_handle_t_() {
#undef CL_CORE_FUNCTION

#endif // _MSC_VER
assert(!liveAdapter);
liveAdapter = this;
}

static ur_adapter_handle_t adapter = nullptr;
// Destructor: clears the file-level liveAdapter pointer when this adapter
// object is destroyed.
ur_adapter_handle_t_::~ur_adapter_handle_t_() {
// Sanity check (active only when assertions are enabled): the object being
// destroyed is expected to be the one currently recorded in liveAdapter.
assert(liveAdapter == this);
// Reset the pointer so it no longer refers to this soon-to-be-dead object.
liveAdapter = nullptr;
}

ur_adapter_handle_t ur::cl::getAdapter() {
if (!adapter) {
if (!liveAdapter) {
die("OpenCL adapter used before initalization or after destruction");
}
return adapter;
}

static void globalAdapterShutdown() {
if (cl_ext::ExtFuncPtrCache) {
delete cl_ext::ExtFuncPtrCache;
cl_ext::ExtFuncPtrCache = nullptr;
}
if (adapter) {
delete adapter;
adapter = nullptr;
}
return liveAdapter;
}

UR_APIEXPORT ur_result_t UR_APICALL
urAdapterGet(uint32_t NumEntries, ur_adapter_handle_t *phAdapters,
uint32_t *pNumAdapters) {
static std::mutex AdapterConstructionMutex{};

if (NumEntries > 0 && phAdapters) {
// Sometimes urAdapterGet may be called after the library has already been
// torn down, so we also need to create a temporary handle for it.
if (!adapter) {
adapter = new ur_adapter_handle_t_();
atexit(globalAdapterShutdown);
}
std::lock_guard<std::mutex> Lock{AdapterConstructionMutex};

std::lock_guard<std::mutex> Lock{adapter->Mutex};
if (adapter->RefCount++ == 0) {
cl_ext::ExtFuncPtrCache = new cl_ext::ExtFuncPtrCacheT();
if (!liveAdapter) {
*phAdapters = new ur_adapter_handle_t_();
} else {
*phAdapters = liveAdapter;
}

*phAdapters = adapter;
auto &adapter = *phAdapters;
adapter->RefCount++;
}

if (pNumAdapters) {
Expand All @@ -90,21 +88,16 @@ urAdapterGet(uint32_t NumEntries, ur_adapter_handle_t *phAdapters,
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) {
++adapter->RefCount;
UR_APIEXPORT ur_result_t UR_APICALL
urAdapterRetain(ur_adapter_handle_t hAdapter) {
++hAdapter->RefCount;
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) {
// Check first if the adapter is valid pointer
if (adapter) {
std::lock_guard<std::mutex> Lock{adapter->Mutex};
if (--adapter->RefCount == 0) {
if (cl_ext::ExtFuncPtrCache) {
delete cl_ext::ExtFuncPtrCache;
cl_ext::ExtFuncPtrCache = nullptr;
}
}
UR_APIEXPORT ur_result_t UR_APICALL
urAdapterRelease(ur_adapter_handle_t hAdapter) {
if (--hAdapter->RefCount == 0) {
delete hAdapter;
}
return UR_RESULT_SUCCESS;
}
Expand All @@ -117,18 +110,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetLastError(
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t,
ur_adapter_info_t propName,
size_t propSize,
void *pPropValue,
size_t *pPropSizeRet) {
UR_APIEXPORT ur_result_t UR_APICALL
urAdapterGetInfo(ur_adapter_handle_t hAdapter, ur_adapter_info_t propName,
size_t propSize, void *pPropValue, size_t *pPropSizeRet) {
UrReturnHelper ReturnValue(propSize, pPropValue, pPropSizeRet);

switch (propName) {
case UR_ADAPTER_INFO_BACKEND:
return ReturnValue(UR_ADAPTER_BACKEND_OPENCL);
case UR_ADAPTER_INFO_REFERENCE_COUNT:
return ReturnValue(adapter->RefCount.load());
return ReturnValue(hAdapter->RefCount.load());
case UR_ADAPTER_INFO_VERSION:
return ReturnValue(uint32_t{1});
default:
Expand Down
8 changes: 7 additions & 1 deletion unified-runtime/source/adapters/opencl/adapter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,25 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#pragma once

#include "device.hpp"
#include "logger/ur_logger.hpp"
#include "platform.hpp"

#include "CL/cl.h"
#include "common.hpp"
#include "logger/ur_logger.hpp"

struct ur_adapter_handle_t_ {
ur_adapter_handle_t_();
~ur_adapter_handle_t_();

ur_adapter_handle_t_(ur_adapter_handle_t_ &) = delete;

std::atomic<uint32_t> RefCount = 0;
std::mutex Mutex;
logger::Logger &log = logger::get_logger("opencl");
cl_ext::ExtFuncPtrCacheT fnCache{};

std::vector<std::unique_ptr<ur_platform_handle_t_>> URPlatforms;
uint32_t NumPlatforms = 0;
Expand Down
28 changes: 18 additions & 10 deletions unified-runtime/source/adapters/opencl/command_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
//===----------------------------------------------------------------------===//

#include "command_buffer.hpp"
#include "adapter.hpp"
#include "common.hpp"
#include "context.hpp"
#include "event.hpp"
Expand All @@ -25,7 +26,8 @@ ur_exp_command_buffer_handle_t_::~ur_exp_command_buffer_handle_t_() {
cl_ext::clReleaseCommandBufferKHR_fn clReleaseCommandBufferKHR = nullptr;
cl_int Res =
cl_ext::getExtFuncFromContext<decltype(clReleaseCommandBufferKHR)>(
CLContext, cl_ext::ExtFuncPtrCache->clReleaseCommandBufferKHRCache,
CLContext,
ur::cl::getAdapter()->fnCache.clReleaseCommandBufferKHRCache,
cl_ext::ReleaseCommandBufferName, &clReleaseCommandBufferKHR);
assert(Res == CL_SUCCESS);
(void)Res;
Expand All @@ -42,7 +44,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferCreateExp(
cl_ext::clCreateCommandBufferKHR_fn clCreateCommandBufferKHR = nullptr;
UR_RETURN_ON_FAILURE(
cl_ext::getExtFuncFromContext<decltype(clCreateCommandBufferKHR)>(
CLContext, cl_ext::ExtFuncPtrCache->clCreateCommandBufferKHRCache,
CLContext,
ur::cl::getAdapter()->fnCache.clCreateCommandBufferKHRCache,
cl_ext::CreateCommandBufferName, &clCreateCommandBufferKHR));

const bool IsUpdatable = pCommandBufferDesc->isUpdatable;
Expand Down Expand Up @@ -116,7 +119,8 @@ urCommandBufferFinalizeExp(ur_exp_command_buffer_handle_t hCommandBuffer) {
cl_ext::clFinalizeCommandBufferKHR_fn clFinalizeCommandBufferKHR = nullptr;
UR_RETURN_ON_FAILURE(
cl_ext::getExtFuncFromContext<decltype(clFinalizeCommandBufferKHR)>(
CLContext, cl_ext::ExtFuncPtrCache->clFinalizeCommandBufferKHRCache,
CLContext,
ur::cl::getAdapter()->fnCache.clFinalizeCommandBufferKHRCache,
cl_ext::FinalizeCommandBufferName, &clFinalizeCommandBufferKHR));

CL_RETURN_ON_FAILURE(
Expand Down Expand Up @@ -148,7 +152,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
cl_ext::clCommandNDRangeKernelKHR_fn clCommandNDRangeKernelKHR = nullptr;
UR_RETURN_ON_FAILURE(
cl_ext::getExtFuncFromContext<decltype(clCommandNDRangeKernelKHR)>(
CLContext, cl_ext::ExtFuncPtrCache->clCommandNDRangeKernelKHRCache,
CLContext,
ur::cl::getAdapter()->fnCache.clCommandNDRangeKernelKHRCache,
cl_ext::CommandNRRangeKernelName, &clCommandNDRangeKernelKHR));

cl_mutable_command_khr CommandHandle = nullptr;
Expand Down Expand Up @@ -238,7 +243,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyExp(
cl_ext::clCommandCopyBufferKHR_fn clCommandCopyBufferKHR = nullptr;
UR_RETURN_ON_FAILURE(
cl_ext::getExtFuncFromContext<decltype(clCommandCopyBufferKHR)>(
CLContext, cl_ext::ExtFuncPtrCache->clCommandCopyBufferKHRCache,
CLContext, ur::cl::getAdapter()->fnCache.clCommandCopyBufferKHRCache,
cl_ext::CommandCopyBufferName, &clCommandCopyBufferKHR));

const bool IsInOrder = hCommandBuffer->IsInOrder;
Expand Down Expand Up @@ -280,7 +285,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferCopyRectExp(
cl_ext::clCommandCopyBufferRectKHR_fn clCommandCopyBufferRectKHR = nullptr;
UR_RETURN_ON_FAILURE(
cl_ext::getExtFuncFromContext<decltype(clCommandCopyBufferRectKHR)>(
CLContext, cl_ext::ExtFuncPtrCache->clCommandCopyBufferRectKHRCache,
CLContext,
ur::cl::getAdapter()->fnCache.clCommandCopyBufferRectKHRCache,
cl_ext::CommandCopyBufferRectName, &clCommandCopyBufferRectKHR));

const bool IsInOrder = hCommandBuffer->IsInOrder;
Expand Down Expand Up @@ -388,7 +394,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendMemBufferFillExp(
cl_ext::clCommandFillBufferKHR_fn clCommandFillBufferKHR = nullptr;
UR_RETURN_ON_FAILURE(
cl_ext::getExtFuncFromContext<decltype(clCommandFillBufferKHR)>(
CLContext, cl_ext::ExtFuncPtrCache->clCommandFillBufferKHRCache,
CLContext, ur::cl::getAdapter()->fnCache.clCommandFillBufferKHRCache,
cl_ext::CommandFillBufferName, &clCommandFillBufferKHR));

const bool IsInOrder = hCommandBuffer->IsInOrder;
Expand Down Expand Up @@ -459,7 +465,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueCommandBufferExp(
cl_ext::clEnqueueCommandBufferKHR_fn clEnqueueCommandBufferKHR = nullptr;
UR_RETURN_ON_FAILURE(
cl_ext::getExtFuncFromContext<decltype(clEnqueueCommandBufferKHR)>(
CLContext, cl_ext::ExtFuncPtrCache->clEnqueueCommandBufferKHRCache,
CLContext,
ur::cl::getAdapter()->fnCache.clEnqueueCommandBufferKHRCache,
cl_ext::EnqueueCommandBufferName, &clEnqueueCommandBufferKHR));

const uint32_t NumberOfQueues = 1;
Expand Down Expand Up @@ -618,7 +625,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
cl_ext::clUpdateMutableCommandsKHR_fn clUpdateMutableCommandsKHR = nullptr;
UR_RETURN_ON_FAILURE(
cl_ext::getExtFuncFromContext<decltype(clUpdateMutableCommandsKHR)>(
CLContext, cl_ext::ExtFuncPtrCache->clUpdateMutableCommandsKHRCache,
CLContext,
ur::cl::getAdapter()->fnCache.clUpdateMutableCommandsKHRCache,
cl_ext::UpdateMutableCommandsName, &clUpdateMutableCommandsKHR));

std::vector<cl_mutable_dispatch_config_khr> ConfigList(numKernelUpdates);
Expand Down Expand Up @@ -754,7 +762,7 @@ ur_result_t UR_APICALL urCommandBufferAppendNativeCommandExp(
UR_RETURN_ON_FAILURE(
cl_ext::getExtFuncFromContext<decltype(clCommandBarrierWithWaitListKHR)>(
CLContext,
cl_ext::ExtFuncPtrCache->clCommandBarrierWithWaitListKHRCache,
ur::cl::getAdapter()->fnCache.clCommandBarrierWithWaitListKHRCache,
cl_ext::CommandBarrierWithWaitListName,
&clCommandBarrierWithWaitListKHR));

Expand Down
5 changes: 0 additions & 5 deletions unified-runtime/source/adapters/opencl/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -349,11 +349,6 @@ struct ExtFuncPtrCacheT {
#undef CL_EXTENSION_FUNC
}
};
// A raw pointer is used here since the lifetime of this map has to be tied to
// piTeardown to avoid issues with static destruction order (a user application
// might have static objects that indirectly access this cache in their
// destructor).
inline ExtFuncPtrCacheT *ExtFuncPtrCache;

// USM helper function to get an extension function pointer
template <typename T>
Expand Down
10 changes: 0 additions & 10 deletions unified-runtime/source/adapters/opencl/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,20 +117,10 @@ urContextGetInfo(ur_context_handle_t hContext, ur_context_info_t propName,

UR_APIEXPORT ur_result_t UR_APICALL
urContextRelease(ur_context_handle_t hContext) {
// If we're reasonably sure this context is about to be destroyed we should
// clear the ext function pointer cache. This isn't foolproof sadly but it
// should drastically reduce the chances of the pathological case described
// in the comments in common.hpp.
static std::mutex contextReleaseMutex;
auto clContext = hContext->CLContext;

std::lock_guard<std::mutex> lock(contextReleaseMutex);
if (hContext->decrementReferenceCount() == 0) {
// ExtFuncPtrCache is destroyed in an atexit() callback, so it doesn't
// necessarily outlive the adapter (or all the contexts).
if (cl_ext::ExtFuncPtrCache) {
cl_ext::ExtFuncPtrCache->clearCache(clContext);
}
delete hContext;
}

Expand Down
11 changes: 11 additions & 0 deletions unified-runtime/source/adapters/opencl/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
//===----------------------------------------------------------------------===//
#pragma once

#include "adapter.hpp"
#include "common.hpp"
#include "device.hpp"

Expand All @@ -29,6 +30,9 @@ struct ur_context_handle_t_ {
Devices.emplace_back(phDevices[i]);
urDeviceRetain(phDevices[i]);
}
// The context retains a reference to the adapter so it can clear the
// function ptr cache on destruction
urAdapterRetain(ur::cl::getAdapter());
RefCount = 1;
}

Expand All @@ -42,6 +46,13 @@ struct ur_context_handle_t_ {
const ur_device_handle_t *phDevices,
ur_context_handle_t &Context);
~ur_context_handle_t_() {
// If we're reasonably sure this context is about to be destroyed we should
// clear the ext function pointer cache. This isn't foolproof sadly but it
// should drastically reduce the chances of the pathological case described
// in the comments in common.hpp.
ur::cl::getAdapter()->fnCache.clearCache(CLContext);
urAdapterRelease(ur::cl::getAdapter());

for (uint32_t i = 0; i < DeviceCount; i++) {
urDeviceRelease(Devices[i]);
}
Expand Down
11 changes: 7 additions & 4 deletions unified-runtime/source/adapters/opencl/enqueue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
//
//===----------------------------------------------------------------------===//

#include "adapter.hpp"
#include "common.hpp"
#include "context.hpp"
#include "event.hpp"
Expand Down Expand Up @@ -400,7 +401,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableWrite(
MapUREventsToCL(numEventsInWaitList, phEventWaitList, CLWaitEvents);
cl_ext::clEnqueueWriteGlobalVariable_fn F = nullptr;
UR_RETURN_ON_FAILURE(cl_ext::getExtFuncFromContext<decltype(F)>(
Ctx, cl_ext::ExtFuncPtrCache->clEnqueueWriteGlobalVariableCache,
Ctx, ur::cl::getAdapter()->fnCache.clEnqueueWriteGlobalVariableCache,
cl_ext::EnqueueWriteGlobalVariableName, &F));

cl_int Res = F(hQueue->CLQueue, hProgram->CLProgram, name, blockingWrite,
Expand All @@ -422,7 +423,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueDeviceGlobalVariableRead(
MapUREventsToCL(numEventsInWaitList, phEventWaitList, CLWaitEvents);
cl_ext::clEnqueueReadGlobalVariable_fn F = nullptr;
UR_RETURN_ON_FAILURE(cl_ext::getExtFuncFromContext<decltype(F)>(
Ctx, cl_ext::ExtFuncPtrCache->clEnqueueReadGlobalVariableCache,
Ctx, ur::cl::getAdapter()->fnCache.clEnqueueReadGlobalVariableCache,
cl_ext::EnqueueReadGlobalVariableName, &F));

cl_int Res = F(hQueue->CLQueue, hProgram->CLProgram, name, blockingRead,
Expand All @@ -446,7 +447,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueReadHostPipe(
cl_ext::clEnqueueReadHostPipeINTEL_fn FuncPtr = nullptr;
UR_RETURN_ON_FAILURE(
cl_ext::getExtFuncFromContext<cl_ext::clEnqueueReadHostPipeINTEL_fn>(
CLContext, cl_ext::ExtFuncPtrCache->clEnqueueReadHostPipeINTELCache,
CLContext,
ur::cl::getAdapter()->fnCache.clEnqueueReadHostPipeINTELCache,
cl_ext::EnqueueReadHostPipeName, &FuncPtr));

if (FuncPtr) {
Expand Down Expand Up @@ -474,7 +476,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueWriteHostPipe(
cl_ext::clEnqueueWriteHostPipeINTEL_fn FuncPtr = nullptr;
UR_RETURN_ON_FAILURE(
cl_ext::getExtFuncFromContext<cl_ext::clEnqueueWriteHostPipeINTEL_fn>(
CLContext, cl_ext::ExtFuncPtrCache->clEnqueueWriteHostPipeINTELCache,
CLContext,
ur::cl::getAdapter()->fnCache.clEnqueueWriteHostPipeINTELCache,
cl_ext::EnqueueWriteHostPipeName, &FuncPtr));

if (FuncPtr) {
Expand Down
Loading