Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion offload/liboffload/API/Common.td
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def : Function {
let desc = "Release the resources in use by Offload";
let details = [
"This decrements an internal reference count. When this reaches 0, all resources will be released",
"Subsequent API calls made after this are not valid"
"Subsequent API calls to methods other than `olInit` made after resources are released will return OL_ERRC_UNINITIALIZED"
];
let params = [];
let returns = [];
Expand Down
69 changes: 48 additions & 21 deletions offload/liboffload/src/OffloadImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,10 @@ struct AllocInfo {

// Global shared state for liboffload
struct OffloadContext;
static OffloadContext *OffloadContextVal;
// This pointer is non-null if and only if the context is valid and fully
// initialized
static std::atomic<OffloadContext *> OffloadContextVal;
std::mutex OffloadContextValMutex;
struct OffloadContext {
OffloadContext(OffloadContext &) = delete;
OffloadContext(OffloadContext &&) = delete;
Expand All @@ -107,6 +110,7 @@ struct OffloadContext {
bool ValidationEnabled = true;
DenseMap<void *, AllocInfo> AllocInfoMap{};
SmallVector<ol_platform_impl_t, 4> Platforms{};
size_t RefCount;

ol_device_handle_t HostDevice() {
// The host platform is always inserted last
Expand Down Expand Up @@ -145,20 +149,18 @@ constexpr ol_platform_backend_t pluginNameToBackend(StringRef Name) {
#define PLUGIN_TARGET(Name) extern "C" GenericPluginTy *createPlugin_##Name();
#include "Shared/Targets.def"

Error initPlugins() {
auto *Context = new OffloadContext{};

Error initPlugins(OffloadContext &Context) {
// Attempt to create an instance of each supported plugin.
#define PLUGIN_TARGET(Name) \
do { \
Context->Platforms.emplace_back(ol_platform_impl_t{ \
Context.Platforms.emplace_back(ol_platform_impl_t{ \
std::unique_ptr<GenericPluginTy>(createPlugin_##Name()), \
pluginNameToBackend(#Name)}); \
} while (false);
#include "Shared/Targets.def"

// Preemptively initialize all devices in the plugin
for (auto &Platform : Context->Platforms) {
for (auto &Platform : Context.Platforms) {
// Do not use the host plugin - it isn't supported.
if (Platform.BackendType == OL_PLATFORM_BACKEND_UNKNOWN)
continue;
Expand All @@ -178,31 +180,56 @@ Error initPlugins() {
}

// Add the special host device
auto &HostPlatform = Context->Platforms.emplace_back(
auto &HostPlatform = Context.Platforms.emplace_back(
ol_platform_impl_t{nullptr, OL_PLATFORM_BACKEND_HOST});
HostPlatform.Devices.emplace_back(-1, nullptr, nullptr, InfoTreeNode{});
Context->HostDevice()->Platform = &HostPlatform;

Context->TracingEnabled = std::getenv("OFFLOAD_TRACE");
Context->ValidationEnabled = !std::getenv("OFFLOAD_DISABLE_VALIDATION");
Context.HostDevice()->Platform = &HostPlatform;

OffloadContextVal = Context;
Context.TracingEnabled = std::getenv("OFFLOAD_TRACE");
Context.ValidationEnabled = !std::getenv("OFFLOAD_DISABLE_VALIDATION");

return Plugin::success();
}

// TODO: We can properly reference count here and manage the resources in a more
// clever way
Error olInit_impl() {
static std::once_flag InitFlag;
std::optional<Error> InitResult{};
std::call_once(InitFlag, [&] { InitResult = initPlugins(); });
std::lock_guard<std::mutex> Lock{OffloadContextValMutex};

if (InitResult)
return std::move(*InitResult);
return Error::success();
if (isOffloadInitialized()) {
OffloadContext::get().RefCount++;
return Plugin::success();
}

// Use a temporary to ensure that entry points querying OffloadContextVal do
// not get a partially initialized context
auto *NewContext = new OffloadContext{};
Error InitResult = initPlugins(*NewContext);
OffloadContextVal.store(NewContext);
OffloadContext::get().RefCount++;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we be calling new and delete on this if it's null?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The context is new'd in initPlugins(), but perhaps it makes sense doing it in olInit_impl, so I'll move it here.


return InitResult;
}

Error olShutDown_impl() {
std::lock_guard<std::mutex> Lock{OffloadContextValMutex};

if (--OffloadContext::get().RefCount != 0)
return Error::success();

llvm::Error Result = Error::success();
auto *OldContext = OffloadContextVal.exchange(nullptr);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, forgot to mention that this needed to be atomic since we're using it outside the mutex as a guard.


for (auto &P : OldContext->Platforms) {
// Host plugin is nullptr and has no deinit
if (!P.Plugin)
continue;
Comment on lines +222 to +224
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to handle this more cleanly in the future.


if (auto Res = P.Plugin->deinit())
Result = llvm::joinErrors(std::move(Result), std::move(Res));
Comment on lines +226 to +227
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we have multiple plugins active this will potentially drop a previous error and hit an assertion.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should collect all the errors from each plugin into a list. But I don't think that is actually handled correctly when returning it from the C api.

}

delete OldContext;
return Result;
}
Error olShutDown_impl() { return Error::success(); }

Error olGetPlatformInfoImplDetail(ol_platform_handle_t Platform,
ol_platform_info_t PropName, size_t PropSize,
Expand Down
12 changes: 12 additions & 0 deletions offload/unittests/OffloadAPI/init/olInit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,20 @@

struct olInitTest : ::testing::Test {};

TEST_F(olInitTest, Success) {
ASSERT_SUCCESS(olInit());
ASSERT_SUCCESS(olShutDown());
}

TEST_F(olInitTest, Uninitialized) {
ASSERT_ERROR(OL_ERRC_UNINITIALIZED,
olIterateDevices(
[](ol_device_handle_t, void *) { return false; }, nullptr));
}

TEST_F(olInitTest, RepeatedInit) {
for (size_t I = 0; I < 10; I++) {
ASSERT_SUCCESS(olInit());
ASSERT_SUCCESS(olShutDown());
}
}
Loading