diff --git a/offload/liboffload/API/APIDefs.td b/offload/liboffload/API/APIDefs.td index 640932dcf8464..bd4cbbaa546b2 100644 --- a/offload/liboffload/API/APIDefs.td +++ b/offload/liboffload/API/APIDefs.td @@ -31,6 +31,13 @@ class IsHandleType { !ne(!find(Type, "_handle_t", !sub(!size(Type), 9)), -1)); } +// Does the type end with '_cb_t'? +class IsCallbackType { + // size("_cb_t") == 5 + bit ret = !if(!lt(!size(Type), 5), 0, + !ne(!find(Type, "_cb_t", !sub(!size(Type), 5)), -1)); +} + // Does the type end with '*'? class IsPointerType { bit ret = !ne(!find(Type, "*", !sub(!size(Type), 1)), -1); @@ -58,6 +65,7 @@ class Param Flags = 0> { TypeInfo type_info = TypeInfo<"", "">; bit IsHandle = IsHandleType.ret; bit IsPointer = IsPointerType.ret; + bit IsCallback = IsCallbackType.ret; } // A parameter whose range is described by other parameters in the function. @@ -81,7 +89,7 @@ class ShouldCheckHandle { } class ShouldCheckPointer { - bit ret = !and(P.IsPointer, !eq(!and(PARAM_OPTIONAL, P.flags), 0)); + bit ret = !and(!or(P.IsPointer, P.IsCallback), !eq(!and(PARAM_OPTIONAL, P.flags), 0)); } // For a list of returns that contains a specific return code, find and append diff --git a/offload/liboffload/API/Queue.td b/offload/liboffload/API/Queue.td index 1d9f6f2d11c9b..0e20e23999d5e 100644 --- a/offload/liboffload/API/Queue.td +++ b/offload/liboffload/API/Queue.td @@ -108,3 +108,29 @@ def : Function { Return<"OL_ERRC_INVALID_QUEUE"> ]; } + +def : FptrTypedef { + let name = "ol_host_function_cb_t"; + let desc = "Host function for use by `olLaunchHostFunction`."; + let params = [ + Param<"void *", "UserData", "user specified data passed into `olLaunchHostFunction`.", PARAM_IN>, + ]; + let return = "void"; +} + +def : Function { + let name = "olLaunchHostFunction"; + let desc = "Enqueue a callback function on the host."; + let details = [ + "The provided function will be called from the same process as the one that called `olLaunchHostFunction`.", + "The callback will not run until all previous work submitted to the queue has completed.", + "The callback must return before any work submitted to the queue after it is started.", + "The callback must not call any liboffload API functions or any backend specific functions (such as Cuda or HSA library functions).", + ]; + let params = [ + Param<"ol_queue_handle_t", "Queue", "handle of the queue", PARAM_IN>, + Param<"ol_host_function_cb_t", "Callback", "the callback function to call on the host", PARAM_IN>, + Param<"void *", "UserData", "a pointer that will be passed verbatim to the callback function", PARAM_IN_OPTIONAL>, + ]; + let returns = []; +} diff --git a/offload/liboffload/src/OffloadImpl.cpp b/offload/liboffload/src/OffloadImpl.cpp index f5365ca274308..ccabf5fc0e799 100644 --- a/offload/liboffload/src/OffloadImpl.cpp +++ b/offload/liboffload/src/OffloadImpl.cpp @@ -833,5 +833,12 @@ Error olGetSymbolInfoSize_impl(ol_symbol_handle_t Symbol, return olGetSymbolInfoImplDetail(Symbol, PropName, 0, nullptr, PropSizeRet); } +Error olLaunchHostFunction_impl(ol_queue_handle_t Queue, + ol_host_function_cb_t Callback, + void *UserData) { + return Queue->Device->Device->enqueueHostCall(Callback, UserData, + Queue->AsyncInfo); +} + } // namespace offload } // namespace llvm diff --git a/offload/plugins-nextgen/amdgpu/src/rtl.cpp b/offload/plugins-nextgen/amdgpu/src/rtl.cpp index 796182075ff3d..536c662451dfd 100644 --- a/offload/plugins-nextgen/amdgpu/src/rtl.cpp +++ b/offload/plugins-nextgen/amdgpu/src/rtl.cpp @@ -1063,6 +1063,20 @@ struct AMDGPUStreamTy { /// Indicate to spread data transfers across all available SDMAs bool UseMultipleSdmaEngines; + /// Wrapper function for implementing host callbacks + static void CallbackWrapper(AMDGPUSignalTy *InputSignal, + AMDGPUSignalTy *OutputSignal, + void (*Callback)(void *), void *UserData) { + // The wait call will not error in this context. + if (InputSignal) + if (auto Err = InputSignal->wait()) + reportFatalInternalError(std::move(Err)); + + Callback(UserData); + + OutputSignal->signal(); + } + /// Return the current number of asynchronous operations on the stream. uint32_t size() const { return NextSlot; } @@ -1495,6 +1509,31 @@ struct AMDGPUStreamTy { OutputSignal->get()); } + Error pushHostCallback(void (*Callback)(void *), void *UserData) { + // Retrieve an available signal for the operation's output. + AMDGPUSignalTy *OutputSignal = nullptr; + if (auto Err = SignalManager.getResource(OutputSignal)) + return Err; + OutputSignal->reset(); + OutputSignal->increaseUseCount(); + + AMDGPUSignalTy *InputSignal; + { + std::lock_guard Lock(Mutex); + + // Consume stream slot and compute dependencies. + InputSignal = consume(OutputSignal).second; + } + + // "Leaking" the thread here is consistent with other work added to the + // queue. The input and output signals will remain valid until the output is + // signaled. + std::thread(CallbackWrapper, InputSignal, OutputSignal, Callback, UserData) + .detach(); + + return Plugin::success(); + } + /// Synchronize with the stream. The current thread waits until all operations /// are finalized and it performs the pending post actions (i.e., releasing /// intermediate buffers). @@ -2553,6 +2592,15 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy { return Plugin::success(); } + Error enqueueHostCallImpl(void (*Callback)(void *), void *UserData, + AsyncInfoWrapperTy &AsyncInfo) override { + AMDGPUStreamTy *Stream = nullptr; + if (auto Err = getStream(AsyncInfo, Stream)) + return Err; + + return Stream->pushHostCallback(Callback, UserData); + }; + /// Create an event. Error createEventImpl(void **EventPtrStorage) override { AMDGPUEventTy **Event = reinterpret_cast(EventPtrStorage); diff --git a/offload/plugins-nextgen/common/include/PluginInterface.h b/offload/plugins-nextgen/common/include/PluginInterface.h index c9ab34b024b77..5e32a1a76d966 100644 --- a/offload/plugins-nextgen/common/include/PluginInterface.h +++ b/offload/plugins-nextgen/common/include/PluginInterface.h @@ -965,6 +965,12 @@ struct GenericDeviceTy : public DeviceAllocatorTy { Error initDeviceInfo(__tgt_device_info *DeviceInfo); virtual Error initDeviceInfoImpl(__tgt_device_info *DeviceInfo) = 0; + /// Enqueue a host call to AsyncInfo + Error enqueueHostCall(void (*Callback)(void *), void *UserData, + __tgt_async_info *AsyncInfo); + virtual Error enqueueHostCallImpl(void (*Callback)(void *), void *UserData, + AsyncInfoWrapperTy &AsyncInfo) = 0; + /// Create an event. Error createEvent(void **EventPtrStorage); virtual Error createEventImpl(void **EventPtrStorage) = 0; diff --git a/offload/plugins-nextgen/common/src/PluginInterface.cpp b/offload/plugins-nextgen/common/src/PluginInterface.cpp index 083d41659a469..f177c5bc9f487 100644 --- a/offload/plugins-nextgen/common/src/PluginInterface.cpp +++ b/offload/plugins-nextgen/common/src/PluginInterface.cpp @@ -1589,6 +1589,15 @@ Error GenericDeviceTy::initAsyncInfo(__tgt_async_info **AsyncInfoPtr) { return Err; } +Error GenericDeviceTy::enqueueHostCall(void (*Callback)(void *), void *UserData, + __tgt_async_info *AsyncInfo) { + AsyncInfoWrapperTy AsyncInfoWrapper(*this, AsyncInfo); + + auto Err = enqueueHostCallImpl(Callback, UserData, AsyncInfoWrapper); + AsyncInfoWrapper.finalize(Err); + return Err; +} + Error GenericDeviceTy::initDeviceInfo(__tgt_device_info *DeviceInfo) { assert(DeviceInfo && "Invalid device info"); diff --git a/offload/plugins-nextgen/cuda/src/rtl.cpp b/offload/plugins-nextgen/cuda/src/rtl.cpp index e94f3f6af7dd4..5e1843c045534 100644 --- a/offload/plugins-nextgen/cuda/src/rtl.cpp +++ b/offload/plugins-nextgen/cuda/src/rtl.cpp @@ -873,6 +873,19 @@ struct CUDADeviceTy : public GenericDeviceTy { return Plugin::success(); } + Error enqueueHostCallImpl(void (*Callback)(void *), void *UserData, + AsyncInfoWrapperTy &AsyncInfo) override { + if (auto Err = setContext()) + return Err; + + CUstream Stream; + if (auto Err = getStream(AsyncInfo, Stream)) + return Err; + + CUresult Res = cuLaunchHostFunc(Stream, Callback, UserData); + return Plugin::check(Res, "error in cuStreamLaunchHostFunc: %s"); + }; + /// Create an event. Error createEventImpl(void **EventPtrStorage) override { CUevent *Event = reinterpret_cast(EventPtrStorage); diff --git a/offload/plugins-nextgen/host/src/rtl.cpp b/offload/plugins-nextgen/host/src/rtl.cpp index ed5213531999d..f8ddc6713c011 100644 --- a/offload/plugins-nextgen/host/src/rtl.cpp +++ b/offload/plugins-nextgen/host/src/rtl.cpp @@ -320,6 +320,12 @@ struct GenELF64DeviceTy : public GenericDeviceTy { "initDeviceInfoImpl not supported"); } + Error enqueueHostCallImpl(void (*Callback)(void *), void *UserData, + AsyncInfoWrapperTy &AsyncInfo) override { + Callback(UserData); + return Plugin::success(); + }; + /// This plugin does not support the event API. Do nothing without failing. Error createEventImpl(void **EventPtrStorage) override { *EventPtrStorage = nullptr; diff --git a/offload/unittests/OffloadAPI/CMakeLists.txt b/offload/unittests/OffloadAPI/CMakeLists.txt index 8f0267eb39bdf..b25db7022e9d7 100644 --- a/offload/unittests/OffloadAPI/CMakeLists.txt +++ b/offload/unittests/OffloadAPI/CMakeLists.txt @@ -41,7 +41,8 @@ add_offload_unittest("queue" queue/olDestroyQueue.cpp queue/olGetQueueInfo.cpp queue/olGetQueueInfoSize.cpp - queue/olWaitEvents.cpp) + queue/olWaitEvents.cpp + queue/olLaunchHostFunction.cpp) add_offload_unittest("symbol" symbol/olGetSymbol.cpp diff --git a/offload/unittests/OffloadAPI/queue/olLaunchHostFunction.cpp b/offload/unittests/OffloadAPI/queue/olLaunchHostFunction.cpp new file mode 100644 index 0000000000000..aa86750f6adf9 --- /dev/null +++ b/offload/unittests/OffloadAPI/queue/olLaunchHostFunction.cpp @@ -0,0 +1,107 @@ +//===------- Offload API tests - olLaunchHostFunction ---------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "../common/Fixtures.hpp" +#include +#include +#include + +struct olLaunchHostFunctionTest : OffloadQueueTest {}; +OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olLaunchHostFunctionTest); + +struct olLaunchHostFunctionKernelTest : OffloadKernelTest {}; +OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olLaunchHostFunctionKernelTest); + +TEST_P(olLaunchHostFunctionTest, Success) { + ASSERT_SUCCESS(olLaunchHostFunction(Queue, [](void *) {}, nullptr)); +} + +TEST_P(olLaunchHostFunctionTest, SuccessSequence) { + uint32_t Buff[16] = {1, 1}; + + for (auto BuffPtr = &Buff[2]; BuffPtr != &Buff[16]; BuffPtr++) { + ASSERT_SUCCESS(olLaunchHostFunction( + Queue, + [](void *BuffPtr) { + uint32_t *AsU32 = reinterpret_cast(BuffPtr); + AsU32[0] = AsU32[-1] + AsU32[-2]; + }, + BuffPtr)); + } + + ASSERT_SUCCESS(olSyncQueue(Queue)); + + for (uint32_t i = 2; i < 16; i++) { + ASSERT_EQ(Buff[i], Buff[i - 1] + Buff[i - 2]); + } +} + +TEST_P(olLaunchHostFunctionKernelTest, SuccessBlocking) { + // Verify that a host kernel can block execution - A host task is created that + // only resolves when Block is set to false. + ol_kernel_launch_size_args_t LaunchArgs; + LaunchArgs.Dimensions = 1; + LaunchArgs.GroupSize = {64, 1, 1}; + LaunchArgs.NumGroups = {1, 1, 1}; + LaunchArgs.DynSharedMemory = 0; + + ol_queue_handle_t Queue; + ASSERT_SUCCESS(olCreateQueue(Device, &Queue)); + + void *Mem; + ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED, + LaunchArgs.GroupSize.x * sizeof(uint32_t), &Mem)); + + uint32_t *Data = (uint32_t *)Mem; + for (uint32_t i = 0; i < 64; i++) { + Data[i] = 0; + } + + volatile bool Block = true; + ASSERT_SUCCESS(olLaunchHostFunction( + Queue, + [](void *Ptr) { + volatile bool *Block = + reinterpret_cast(reinterpret_cast(Ptr)); + + while (*Block) + std::this_thread::yield(); + }, + const_cast(&Block))); + + struct { + void *Mem; + } Args{Mem}; + ASSERT_SUCCESS( + olLaunchKernel(Queue, Device, Kernel, &Args, sizeof(Args), &LaunchArgs)); + + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + for (uint32_t i = 0; i < 64; i++) { + ASSERT_EQ(Data[i], 0); + } + + Block = false; + ASSERT_SUCCESS(olSyncQueue(Queue)); + + for (uint32_t i = 0; i < 64; i++) { + ASSERT_EQ(Data[i], i); + } + + ASSERT_SUCCESS(olDestroyQueue(Queue)); + ASSERT_SUCCESS(olMemFree(Mem)); +} + +TEST_P(olLaunchHostFunctionTest, InvalidNullCallback) { + ASSERT_ERROR(OL_ERRC_INVALID_NULL_POINTER, + olLaunchHostFunction(Queue, nullptr, nullptr)); +} + +TEST_P(olLaunchHostFunctionTest, InvalidNullQueue) { + ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE, + olLaunchHostFunction(nullptr, [](void *) {}, nullptr)); +}