Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions sycl/test-e2e/Regression/static-buffer-dtor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@
// UNSUPPORTED: windows && arch-intel_gpu_bmg_g21
// UNSUPPORTED-TRACKER: https://github.com/intel/llvm/issues/17255

// UNSUPPORTED: cuda
// UNSUPPORTED-TRACKER: https://github.com/intel/llvm/issues/17450

#include <sycl/detail/core.hpp>

int main() {
Expand Down
56 changes: 27 additions & 29 deletions unified-runtime/source/adapters/cuda/adapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,16 @@
//
//===----------------------------------------------------------------------===//

#include <ur_api.h>

#include "adapter.hpp"
#include "common.hpp"
#include "logger/ur_logger.hpp"
#include "platform.hpp"
#include "tracing.hpp"

struct ur_adapter_handle_t_ {
std::atomic<uint32_t> RefCount = 0;
std::mutex Mutex;
struct cuda_tracing_context_t_ *TracingCtx = nullptr;
logger::Logger &logger;
ur_adapter_handle_t_();
};
#include <memory>

namespace ur::cuda {
std::shared_ptr<ur_adapter_handle_t_> adapter;
} // namespace ur::cuda

class ur_legacy_sink : public logger::Sink {
public:
Expand All @@ -43,28 +40,29 @@ class ur_legacy_sink : public logger::Sink {
ur_adapter_handle_t_::ur_adapter_handle_t_()
: logger(logger::get_logger("cuda",
/*default_log_level*/ logger::Level::ERR)) {
Platform = std::make_unique<ur_platform_handle_t_>();

if (std::getenv("UR_LOG_CUDA") != nullptr)
return;

if (std::getenv("SYCL_PI_SUPPRESS_ERROR_MESSAGE") != nullptr ||
std::getenv("UR_SUPPRESS_ERROR_MESSAGE") != nullptr) {
if (std::getenv("UR_LOG_CUDA") == nullptr &&
(std::getenv("SYCL_PI_SUPPRESS_ERROR_MESSAGE") != nullptr ||
std::getenv("UR_SUPPRESS_ERROR_MESSAGE") != nullptr)) {
logger.setLegacySink(std::make_unique<ur_legacy_sink>());
}
}
ur_adapter_handle_t_ adapter{};

UR_APIEXPORT ur_result_t UR_APICALL
urAdapterGet(uint32_t NumEntries, ur_adapter_handle_t *phAdapters,
uint32_t *pNumAdapters) {
if (NumEntries > 0 && phAdapters) {
std::lock_guard<std::mutex> Lock{adapter.Mutex};
if (adapter.RefCount++ == 0) {
adapter.TracingCtx = createCUDATracingContext();
enableCUDATracing(adapter.TracingCtx);
}
static std::once_flag InitFlag;
std::call_once(InitFlag, [=]() {
ur::cuda::adapter = std::make_shared<ur_adapter_handle_t_>();
});

std::lock_guard<std::mutex> Lock{ur::cuda::adapter->Mutex};
ur::cuda::adapter->TracingCtx = createCUDATracingContext();
enableCUDATracing(ur::cuda::adapter->TracingCtx);

*phAdapters = &adapter;
*phAdapters = ur::cuda::adapter.get();
}

if (pNumAdapters) {
Expand All @@ -75,17 +73,17 @@ urAdapterGet(uint32_t NumEntries, ur_adapter_handle_t *phAdapters,
}

UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) {
adapter.RefCount++;
ur::cuda::adapter->RefCount++;

return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) {
std::lock_guard<std::mutex> Lock{adapter.Mutex};
if (--adapter.RefCount == 0) {
disableCUDATracing(adapter.TracingCtx);
freeCUDATracingContext(adapter.TracingCtx);
adapter.TracingCtx = nullptr;
std::lock_guard<std::mutex> Lock{ur::cuda::adapter->Mutex};
if (--ur::cuda::adapter->RefCount == 0) {
disableCUDATracing(ur::cuda::adapter->TracingCtx);
freeCUDATracingContext(ur::cuda::adapter->TracingCtx);
ur::cuda::adapter.reset();
}
return UR_RESULT_SUCCESS;
}
Expand All @@ -108,7 +106,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t,
case UR_ADAPTER_INFO_BACKEND:
return ReturnValue(UR_ADAPTER_BACKEND_CUDA);
case UR_ADAPTER_INFO_REFERENCE_COUNT:
return ReturnValue(adapter.RefCount.load());
return ReturnValue(ur::cuda::adapter->RefCount.load());
case UR_ADAPTER_INFO_VERSION:
return ReturnValue(uint32_t{1});
default:
Expand Down
28 changes: 26 additions & 2 deletions unified-runtime/source/adapters/cuda/adapter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,30 @@
//
//===----------------------------------------------------------------------===//

struct ur_adapter_handle_t_;
#ifndef UR_CUDA_ADAPTER_HPP_INCLUDED
#define UR_CUDA_ADAPTER_HPP_INCLUDED

extern ur_adapter_handle_t_ adapter;
#include "logger/ur_logger.hpp"
#include "platform.hpp"
#include "tracing.hpp"
#include <ur_api.h>

#include <atomic>
#include <memory>
#include <mutex>

struct ur_adapter_handle_t_ {
std::atomic<uint32_t> RefCount = 1;
std::mutex Mutex;
struct cuda_tracing_context_t_ *TracingCtx = nullptr;
logger::Logger &logger;
std::unique_ptr<ur_platform_handle_t_> Platform;
ur_adapter_handle_t_();
};

// Keep the global namespace'd
namespace ur::cuda {
extern std::shared_ptr<ur_adapter_handle_t_> adapter;
} // namespace ur::cuda

#endif // UR_CUDA_ADAPTER_HPP_INCLUDED
1 change: 1 addition & 0 deletions unified-runtime/source/adapters/cuda/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
//===----------------------------------------------------------------------===//

#include "context.hpp"
#include "platform.hpp"
#include "usm.hpp"

#include <cassert>
Expand Down
16 changes: 8 additions & 8 deletions unified-runtime/source/adapters/cuda/context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@
#pragma once

#include <cuda.h>
#include <memory>
#include <ur_api.h>

#include <atomic>
#include <mutex>
#include <set>
#include <vector>

#include "adapter.hpp"
#include "common.hpp"
#include "device.hpp"
#include "umf_helpers.hpp"
Expand Down Expand Up @@ -117,6 +119,10 @@ struct ur_context_handle_t_ {
void operator()() { Function(UserData); }
};

// Retain an additional reference to the adapter as it keeps the devices
// alive, which end up being used (indirectly) by our destructor.
std::shared_ptr<ur_adapter_handle_t_> Adapter;

std::vector<ur_device_handle_t> Devices;
std::atomic_uint32_t RefCount;

Expand All @@ -126,11 +132,8 @@ struct ur_context_handle_t_ {
umf_memory_pool_handle_t MemoryPoolHost = nullptr;

ur_context_handle_t_(const ur_device_handle_t *Devs, uint32_t NumDevices)
: Devices{Devs, Devs + NumDevices}, RefCount{1} {
for (auto &Dev : Devices) {
urDeviceRetain(Dev);
}

: Adapter(ur::cuda::adapter), Devices{Devs, Devs + NumDevices},
RefCount{1} {
// Create UMF CUDA memory provider for the host memory
// (UMF_MEMORY_TYPE_HOST) from any device (Devices[0] is used here, because
// it is guaranteed to exist).
Expand All @@ -145,9 +148,6 @@ struct ur_context_handle_t_ {
if (MemoryProviderHost) {
umfMemoryProviderDestroy(MemoryProviderHost);
}
for (auto &Dev : Devices) {
urDeviceRelease(Dev);
}
}

void invokeExtendedDeleters() {
Expand Down
2 changes: 1 addition & 1 deletion unified-runtime/source/adapters/cuda/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1216,7 +1216,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceCreateWithNativeHandle(

// Get list of platforms
uint32_t NumPlatforms = 0;
ur_adapter_handle_t AdapterHandle = &adapter;
ur_adapter_handle_t AdapterHandle = ur::cuda::adapter.get();
ur_result_t Result =
urPlatformGet(&AdapterHandle, 1, 0, nullptr, &NumPlatforms);
if (Result != UR_RESULT_SUCCESS)
Expand Down
3 changes: 3 additions & 0 deletions unified-runtime/source/adapters/cuda/memory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -393,6 +393,9 @@ struct ur_mem_handle_t_ {
urMemRelease(std::get<BufferMem>(Mem).Parent);
return;
}
if (LastQueueWritingToMemObj != nullptr) {
urQueueRelease(LastQueueWritingToMemObj);
}
urContextRelease(Context);
}

Expand Down
18 changes: 9 additions & 9 deletions unified-runtime/source/adapters/cuda/platform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urPlatformGetInfo(
return ReturnValue(UR_PLATFORM_BACKEND_CUDA);
}
case UR_PLATFORM_INFO_ADAPTER: {
return ReturnValue(&adapter);
return ReturnValue(ur::cuda::adapter.get());
}
default:
return UR_RESULT_ERROR_INVALID_ENUMERATION;
Expand All @@ -116,11 +116,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urPlatformGetInfo(
UR_APIEXPORT ur_result_t UR_APICALL
urPlatformGet(ur_adapter_handle_t *, uint32_t, uint32_t NumEntries,
ur_platform_handle_t *phPlatforms, uint32_t *pNumPlatforms) {

try {
static std::once_flag InitFlag;
static uint32_t NumPlatforms = 1;
static ur_platform_handle_t_ Platform;

UR_ASSERT(phPlatforms || pNumPlatforms, UR_RESULT_ERROR_INVALID_VALUE);
UR_ASSERT(!phPlatforms || NumEntries > 0, UR_RESULT_ERROR_INVALID_SIZE);
Expand Down Expand Up @@ -151,22 +149,24 @@ urPlatformGet(ur_adapter_handle_t *, uint32_t, uint32_t NumEntries,
// Use default stream to record base event counter
UR_CHECK_ERROR(cuEventRecord(EvBase, 0));

Platform.Devices.emplace_back(
new ur_device_handle_t_{Device, Context, EvBase, &Platform,
ur::cuda::adapter->Platform->Devices.emplace_back(
new ur_device_handle_t_{Device, Context, EvBase,
ur::cuda::adapter->Platform.get(),
static_cast<uint32_t>(i)});
}

UR_CHECK_ERROR(CreateDeviceMemoryProvidersPools(&Platform));
UR_CHECK_ERROR(CreateDeviceMemoryProvidersPools(
ur::cuda::adapter->Platform.get()));
} catch (const std::bad_alloc &) {
// Signal out-of-memory situation
for (int i = 0; i < NumDevices; ++i) {
Platform.Devices.clear();
ur::cuda::adapter->Platform->Devices.clear();
}
Result = UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
} catch (ur_result_t Err) {
// Clear and rethrow to allow retry
for (int i = 0; i < NumDevices; ++i) {
Platform.Devices.clear();
ur::cuda::adapter->Platform->Devices.clear();
}
Result = Err;
throw Err;
Expand All @@ -182,7 +182,7 @@ urPlatformGet(ur_adapter_handle_t *, uint32_t, uint32_t NumEntries,
}

if (phPlatforms != nullptr) {
*phPlatforms = &Platform;
*phPlatforms = ur::cuda::adapter->Platform.get();
}

return Result;
Expand Down
9 changes: 8 additions & 1 deletion unified-runtime/source/adapters/cuda/platform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,18 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#pragma once

#ifndef UR_CUDA_PLATFORM_HPP_INCLUDED
#define UR_CUDA_PLATFORM_HPP_INCLUDED

#include "device.hpp"
#include <ur/ur.hpp>

#include <memory>
#include <vector>

struct ur_platform_handle_t_ {
std::vector<std::unique_ptr<ur_device_handle_t_>> Devices;
};

#endif // UR_CUDA_PLATFORM_HPP_INCLUDED
2 changes: 1 addition & 1 deletion unified-runtime/source/adapters/cuda/usm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ urUSMGetMemAllocInfo(ur_context_handle_t hContext, const void *pMem,

// cuda backend has only one platform containing all devices
ur_platform_handle_t platform;
ur_adapter_handle_t AdapterHandle = &adapter;
ur_adapter_handle_t AdapterHandle = ur::cuda::adapter.get();
Result = urPlatformGet(&AdapterHandle, 1, 1, &platform, nullptr);

// get the device from the platform
Expand Down
31 changes: 16 additions & 15 deletions unified-runtime/source/adapters/hip/adapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,14 @@

#include "adapter.hpp"
#include "common.hpp"
#include "logger/ur_logger.hpp"

#include <atomic>
#include <ur_api.h>

struct ur_adapter_handle_t_ {
std::atomic<uint32_t> RefCount = 0;
logger::Logger &logger;
ur_adapter_handle_t_();
};
#include <memory>

namespace ur::hip {
std::shared_ptr<ur_adapter_handle_t_> adapter;
}

class ur_legacy_sink : public logger::Sink {
public:
Expand All @@ -42,7 +40,7 @@ class ur_legacy_sink : public logger::Sink {
ur_adapter_handle_t_::ur_adapter_handle_t_()
: logger(
logger::get_logger("hip", /*default_log_level*/ logger::Level::ERR)) {

Platform = std::make_unique<ur_platform_handle_t_>();
if (std::getenv("UR_LOG_HIP") != nullptr)
return;

Expand All @@ -52,13 +50,16 @@ ur_adapter_handle_t_::ur_adapter_handle_t_()
}
}

ur_adapter_handle_t_ adapter{};

UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet(
uint32_t, ur_adapter_handle_t *phAdapters, uint32_t *pNumAdapters) {
if (phAdapters) {
adapter.RefCount++;
*phAdapters = &adapter;
static std::once_flag InitFlag;
std::call_once(InitFlag, [=]() {
ur::hip::adapter = std::make_shared<ur_adapter_handle_t_>();
});

ur::hip::adapter->RefCount++;
*phAdapters = ur::hip::adapter.get();
}
if (pNumAdapters) {
*pNumAdapters = 1;
Expand All @@ -69,12 +70,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGet(

UR_APIEXPORT ur_result_t UR_APICALL urAdapterRelease(ur_adapter_handle_t) {
// No state to clean up so we don't need to check for 0 references
adapter.RefCount--;
ur::hip::adapter->RefCount--;
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urAdapterRetain(ur_adapter_handle_t) {
adapter.RefCount++;
ur::hip::adapter->RefCount++;
return UR_RESULT_SUCCESS;
}

Expand All @@ -96,7 +97,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urAdapterGetInfo(ur_adapter_handle_t,
case UR_ADAPTER_INFO_BACKEND:
return ReturnValue(UR_ADAPTER_BACKEND_HIP);
case UR_ADAPTER_INFO_REFERENCE_COUNT:
return ReturnValue(adapter.RefCount.load());
return ReturnValue(ur::hip::adapter->RefCount.load());
case UR_ADAPTER_INFO_VERSION:
return ReturnValue(uint32_t{1});
default:
Expand Down
Loading
Loading