diff --git a/offload/liboffload/include/OffloadImpl.hpp b/offload/liboffload/include/OffloadImpl.hpp index 9b0a21cb9ae12..a12d8c47a180b 100644 --- a/offload/liboffload/include/OffloadImpl.hpp +++ b/offload/liboffload/include/OffloadImpl.hpp @@ -22,12 +22,12 @@ #include "llvm/ADT/StringSet.h" #include "llvm/Support/Error.h" -struct OffloadConfig { - bool TracingEnabled = false; - bool ValidationEnabled = true; -}; - -OffloadConfig &offloadConfig(); +namespace llvm { +namespace offload { +bool isTracingEnabled(); +bool isValidationEnabled(); +} // namespace offload +} // namespace llvm // Use the StringSet container to efficiently deduplicate repeated error // strings (e.g. if the same error is hit constantly in a long running program) diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp index 770c212d804d2..f02497c0a6331 100644 --- a/offload/liboffload/src/OffloadImpl.cpp +++ b/offload/liboffload/src/OffloadImpl.cpp @@ -93,22 +93,36 @@ struct AllocInfo { ol_alloc_type_t Type; }; -using AllocInfoMapT = DenseMap; -AllocInfoMapT &allocInfoMap() { - static AllocInfoMapT AllocInfoMap{}; - return AllocInfoMap; -} +// Global shared state for liboffload +struct OffloadContext; +static OffloadContext *OffloadContextVal; +struct OffloadContext { + OffloadContext(OffloadContext &) = delete; + OffloadContext(OffloadContext &&) = delete; + OffloadContext &operator=(OffloadContext &) = delete; + OffloadContext &operator=(OffloadContext &&) = delete; + + bool TracingEnabled = false; + bool ValidationEnabled = true; + DenseMap AllocInfoMap{}; + SmallVector Platforms{}; + + ol_device_handle_t HostDevice() { + // The host platform is always inserted last + return &Platforms.back().Devices[0]; + } -using PlatformVecT = SmallVector; -PlatformVecT &Platforms() { - static PlatformVecT Platforms; - return Platforms; -} + static OffloadContext &get() { + assert(OffloadContextVal); + return *OffloadContextVal; + } +}; -ol_device_handle_t HostDevice() { - // The host platform is always inserted last - return &Platforms().back().Devices[0]; +// If the context is uninited, then we assume tracing is disabled +bool isTracingEnabled() { + return OffloadContextVal && OffloadContext::get().TracingEnabled; } +bool isValidationEnabled() { return OffloadContext::get().ValidationEnabled; } template Error olDestroy(HandleT Handle) { delete Handle; @@ -130,10 +144,12 @@ constexpr ol_platform_backend_t pluginNameToBackend(StringRef Name) { #include "Shared/Targets.def" void initPlugins() { + auto *Context = new OffloadContext{}; + // Attempt to create an instance of each supported plugin. #define PLUGIN_TARGET(Name) \ do { \ - Platforms().emplace_back(ol_platform_impl_t{ \ + Context->Platforms.emplace_back(ol_platform_impl_t{ \ std::unique_ptr(createPlugin_##Name()), \ {}, \ pluginNameToBackend(#Name)}); \ @@ -141,7 +157,7 @@ void initPlugins() { #include "Shared/Targets.def" // Preemptively initialize all devices in the plugin - for (auto &Platform : Platforms()) { + for (auto &Platform : Context->Platforms) { // Do not use the host plugin - it isn't supported. if (Platform.BackendType == OL_PLATFORM_BACKEND_UNKNOWN) continue; @@ -157,15 +173,16 @@ void initPlugins() { } // Add the special host device - auto &HostPlatform = Platforms().emplace_back( + auto &HostPlatform = Context->Platforms.emplace_back( ol_platform_impl_t{nullptr, {ol_device_impl_t{-1, nullptr, nullptr}}, OL_PLATFORM_BACKEND_HOST}); - HostDevice()->Platform = &HostPlatform; + Context->HostDevice()->Platform = &HostPlatform; + + Context->TracingEnabled = std::getenv("OFFLOAD_TRACE"); + Context->ValidationEnabled = !std::getenv("OFFLOAD_DISABLE_VALIDATION"); - offloadConfig().TracingEnabled = std::getenv("OFFLOAD_TRACE"); - offloadConfig().ValidationEnabled = - !std::getenv("OFFLOAD_DISABLE_VALIDATION"); + OffloadContextVal = Context; } // TODO: We can properly reference count here and manage the resources in a more @@ -229,7 +246,7 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device, // Find the info if it exists under any of the given names auto GetInfo = [&](std::vector Names) { - if (Device == HostDevice()) + if (Device == OffloadContext::get().HostDevice()) return std::string("Host"); if (!Device->Device) @@ -251,8 +268,9 @@ Error olGetDeviceInfoImplDetail(ol_device_handle_t Device, case OL_DEVICE_INFO_PLATFORM: return ReturnValue(Device->Platform); case OL_DEVICE_INFO_TYPE: - return Device == HostDevice() ? ReturnValue(OL_DEVICE_TYPE_HOST) - : ReturnValue(OL_DEVICE_TYPE_GPU); + return Device == OffloadContext::get().HostDevice() + ? ReturnValue(OL_DEVICE_TYPE_HOST) + : ReturnValue(OL_DEVICE_TYPE_GPU); case OL_DEVICE_INFO_NAME: return ReturnValue(GetInfo({"Device Name"}).c_str()); case OL_DEVICE_INFO_VENDOR: @@ -280,7 +298,7 @@ Error olGetDeviceInfoSize_impl(ol_device_handle_t Device, } Error olIterateDevices_impl(ol_device_iterate_cb_t Callback, void *UserData) { - for (auto &Platform : Platforms()) { + for (auto &Platform : OffloadContext::get().Platforms) { for (auto &Device : Platform.Devices) { if (!Callback(&Device, UserData)) { break; @@ -311,16 +329,17 @@ Error olMemAlloc_impl(ol_device_handle_t Device, ol_alloc_type_t Type, return Alloc.takeError(); *AllocationOut = *Alloc; - allocInfoMap().insert_or_assign(*Alloc, AllocInfo{Device, Type}); + OffloadContext::get().AllocInfoMap.insert_or_assign(*Alloc, + AllocInfo{Device, Type}); return Error::success(); } Error olMemFree_impl(void *Address) { - if (!allocInfoMap().contains(Address)) + if (!OffloadContext::get().AllocInfoMap.contains(Address)) return createOffloadError(ErrorCode::INVALID_ARGUMENT, "address is not a known allocation"); - auto AllocInfo = allocInfoMap().at(Address); + auto AllocInfo = OffloadContext::get().AllocInfoMap.at(Address); auto Device = AllocInfo.Device; auto Type = AllocInfo.Type; @@ -328,7 +347,7 @@ Error olMemFree_impl(void *Address) { Device->Device->dataDelete(Address, convertOlToPluginAllocTy(Type))) return Res; - allocInfoMap().erase(Address); + OffloadContext::get().AllocInfoMap.erase(Address); return Error::success(); } @@ -395,7 +414,8 @@ Error olMemcpy_impl(ol_queue_handle_t Queue, void *DstPtr, ol_device_handle_t DstDevice, const void *SrcPtr, ol_device_handle_t SrcDevice, size_t Size, ol_event_handle_t *EventOut) { - if (DstDevice == HostDevice() && SrcDevice == HostDevice()) { + auto Host = OffloadContext::get().HostDevice(); + if (DstDevice == Host && SrcDevice == Host) { if (!Queue) { std::memcpy(DstPtr, SrcPtr, Size); return Error::success(); @@ -410,11 +430,11 @@ Error olMemcpy_impl(ol_queue_handle_t Queue, void *DstPtr, // If no queue is given the memcpy will be synchronous auto QueueImpl = Queue ? Queue->AsyncInfo : nullptr; - if (DstDevice == HostDevice()) { + if (DstDevice == Host) { if (auto Res = SrcDevice->Device->dataRetrieve(DstPtr, SrcPtr, Size, QueueImpl)) return Res; - } else if (SrcDevice == HostDevice()) { + } else if (SrcDevice == Host) { if (auto Res = DstDevice->Device->dataSubmit(DstPtr, SrcPtr, Size, QueueImpl)) return Res; diff --git a/offload/liboffload/src/OffloadLib.cpp b/offload/liboffload/src/OffloadLib.cpp index 8662d3a44124b..0a65815e59698 100644 --- a/offload/liboffload/src/OffloadLib.cpp +++ b/offload/liboffload/src/OffloadLib.cpp @@ -30,11 +30,6 @@ ol_code_location_t *¤tCodeLocation() { return CodeLoc; } -OffloadConfig &offloadConfig() { - static OffloadConfig Config{}; - return Config; -} - namespace llvm { namespace offload { // Pull in the declarations for the implementation functions. The actual entry diff --git a/offload/tools/offload-tblgen/EntryPointGen.cpp b/offload/tools/offload-tblgen/EntryPointGen.cpp index 85c5c50bf2f20..13aa0d1f63187 100644 --- a/offload/tools/offload-tblgen/EntryPointGen.cpp +++ b/offload/tools/offload-tblgen/EntryPointGen.cpp @@ -35,21 +35,30 @@ static void EmitValidationFunc(const FunctionRec &F, raw_ostream &OS) { } OS << ") {\n"; - OS << TAB_1 "if (offloadConfig().ValidationEnabled) {\n"; - // Emit validation checks - for (const auto &Return : F.getReturns()) { - for (auto &Condition : Return.getConditions()) { - if (Condition.starts_with("`") && Condition.ends_with("`")) { - auto ConditionString = Condition.substr(1, Condition.size() - 2); - OS << formatv(TAB_2 "if ({0}) {{\n", ConditionString); - OS << formatv(TAB_3 "return createOffloadError(error::ErrorCode::{0}, " - "\"validation failure: {1}\");\n", - Return.getUnprefixedValue(), ConditionString); - OS << TAB_2 "}\n\n"; + bool HasValidation = llvm::any_of(F.getReturns(), [](auto &R) { + return llvm::any_of(R.getConditions(), [](auto &C) { + return C.starts_with("`") && C.ends_with("`"); + }); + }); + + if (HasValidation) { + OS << TAB_1 "if (llvm::offload::isValidationEnabled()) {\n"; + // Emit validation checks + for (const auto &Return : F.getReturns()) { + for (auto &Condition : Return.getConditions()) { + if (Condition.starts_with("`") && Condition.ends_with("`")) { + auto ConditionString = Condition.substr(1, Condition.size() - 2); + OS << formatv(TAB_2 "if ({0}) {{\n", ConditionString); + OS << formatv(TAB_3 + "return createOffloadError(error::ErrorCode::{0}, " + "\"validation failure: {1}\");\n", + Return.getUnprefixedValue(), ConditionString); + OS << TAB_2 "}\n\n"; + } } } + OS << TAB_1 "}\n\n"; } - OS << TAB_1 "}\n\n"; // Perform actual function call to the implementation ParamNameList = ParamNameList.substr(0, ParamNameList.size() - 2); @@ -74,7 +83,7 @@ static void EmitEntryPointFunc(const FunctionRec &F, raw_ostream &OS) { OS << ") {\n"; // Emit pre-call prints - OS << TAB_1 "if (offloadConfig().TracingEnabled) {\n"; + OS << TAB_1 "if (llvm::offload::isTracingEnabled()) {\n"; OS << formatv(TAB_2 "llvm::errs() << \"---> {0}\";\n", F.getName()); OS << TAB_1 "}\n\n"; @@ -85,7 +94,7 @@ static void EmitEntryPointFunc(const FunctionRec &F, raw_ostream &OS) { PrefixLower, F.getName(), ParamNameList); // Emit post-call prints - OS << TAB_1 "if (offloadConfig().TracingEnabled) {\n"; + OS << TAB_1 "if (llvm::offload::isTracingEnabled()) {\n"; if (F.getParams().size() > 0) { OS << formatv(TAB_2 "{0} Params = {{", F.getParamStructName()); for (const auto &Param : F.getParams()) {