Skip to content

Commit 30c7951

Browse files
authored
[Offload] olLaunchHostFunction (#152482)
Add an `olLaunchHostFunction` method that allows enqueueing host work to the stream.
1 parent 5985620 commit 30c7951

File tree

10 files changed

+233
-2
lines changed

10 files changed

+233
-2
lines changed

offload/liboffload/API/APIDefs.td

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,13 @@ class IsHandleType<string Type> {
3131
!ne(!find(Type, "_handle_t", !sub(!size(Type), 9)), -1));
3232
}
3333

34+
// Does the type end with '_cb_t'?
35+
class IsCallbackType<string Type> {
36+
// size("_cb_t") == 5
37+
bit ret = !if(!lt(!size(Type), 5), 0,
38+
!ne(!find(Type, "_cb_t", !sub(!size(Type), 5)), -1));
39+
}
40+
3441
// Does the type end with '*'?
3542
class IsPointerType<string Type> {
3643
bit ret = !ne(!find(Type, "*", !sub(!size(Type), 1)), -1);
@@ -58,6 +65,7 @@ class Param<string Type, string Name, string Desc, bits<3> Flags = 0> {
5865
TypeInfo type_info = TypeInfo<"", "">;
5966
bit IsHandle = IsHandleType<type>.ret;
6067
bit IsPointer = IsPointerType<type>.ret;
68+
bit IsCallback = IsCallbackType<type>.ret;
6169
}
6270

6371
// A parameter whose range is described by other parameters in the function.
@@ -81,7 +89,7 @@ class ShouldCheckHandle<Param P> {
8189
}
8290

8391
class ShouldCheckPointer<Param P> {
84-
bit ret = !and(P.IsPointer, !eq(!and(PARAM_OPTIONAL, P.flags), 0));
92+
bit ret = !and(!or(P.IsPointer, P.IsCallback), !eq(!and(PARAM_OPTIONAL, P.flags), 0));
8593
}
8694

8795
// For a list of returns that contains a specific return code, find and append

offload/liboffload/API/Queue.td

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,3 +108,29 @@ def : Function {
108108
Return<"OL_ERRC_INVALID_QUEUE">
109109
];
110110
}
111+
112+
def : FptrTypedef {
113+
let name = "ol_host_function_cb_t";
114+
let desc = "Host function for use by `olLaunchHostFunction`.";
115+
let params = [
116+
Param<"void *", "UserData", "user specified data passed into `olLaunchHostFunction`.", PARAM_IN>,
117+
];
118+
let return = "void";
119+
}
120+
121+
def : Function {
122+
let name = "olLaunchHostFunction";
123+
let desc = "Enqueue a callback function on the host.";
124+
let details = [
125+
"The provided function will be called from the same process as the one that called `olLaunchHostFunction`.",
126+
"The callback will not run until all previous work submitted to the queue has completed.",
127+
"The callback must return before any work submitted to the queue after it is started.",
128+
"The callback must not call any liboffload API functions or any backend specific functions (such as Cuda or HSA library functions).",
129+
];
130+
let params = [
131+
Param<"ol_queue_handle_t", "Queue", "handle of the queue", PARAM_IN>,
132+
Param<"ol_host_function_cb_t", "Callback", "the callback function to call on the host", PARAM_IN>,
133+
Param<"void *", "UserData", "a pointer that will be passed verbatim to the callback function", PARAM_IN_OPTIONAL>,
134+
];
135+
let returns = [];
136+
}

offload/liboffload/src/OffloadImpl.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -833,5 +833,12 @@ Error olGetSymbolInfoSize_impl(ol_symbol_handle_t Symbol,
833833
return olGetSymbolInfoImplDetail(Symbol, PropName, 0, nullptr, PropSizeRet);
834834
}
835835

836+
/// Implementation of `olLaunchHostFunction`.
///
/// Forwards \p Callback and \p UserData to the plugin device backing
/// \p Queue, enqueueing the host call on the queue's async info.
Error olLaunchHostFunction_impl(ol_queue_handle_t Queue,
                                ol_host_function_cb_t Callback,
                                void *UserData) {
  // The liboffload device handle wraps the plugin device that does the work.
  auto &PluginDevice = *Queue->Device->Device;
  return PluginDevice.enqueueHostCall(Callback, UserData, Queue->AsyncInfo);
}
842+
836843
} // namespace offload
837844
} // namespace llvm

offload/plugins-nextgen/amdgpu/src/rtl.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1063,6 +1063,20 @@ struct AMDGPUStreamTy {
10631063
/// Indicate to spread data transfers across all available SDMAs
10641064
bool UseMultipleSdmaEngines;
10651065

1066+
/// Wrapper function for implementing host callbacks
1067+
static void CallbackWrapper(AMDGPUSignalTy *InputSignal,
1068+
AMDGPUSignalTy *OutputSignal,
1069+
void (*Callback)(void *), void *UserData) {
1070+
// The wait call will not error in this context.
1071+
if (InputSignal)
1072+
if (auto Err = InputSignal->wait())
1073+
reportFatalInternalError(std::move(Err));
1074+
1075+
Callback(UserData);
1076+
1077+
OutputSignal->signal();
1078+
}
1079+
10661080
/// Return the current number of asynchronous operations on the stream.
10671081
uint32_t size() const { return NextSlot; }
10681082

@@ -1495,6 +1509,31 @@ struct AMDGPUStreamTy {
14951509
OutputSignal->get());
14961510
}
14971511

1512+
/// Push a host callback onto the stream.
///
/// Acquires an output signal, consumes a stream slot to obtain the input
/// signal the callback depends on, and then runs \p Callback with
/// \p UserData on a detached thread through CallbackWrapper.
///
/// \returns an error if no signal could be obtained, success otherwise.
Error pushHostCallback(void (*Callback)(void *), void *UserData) {
  // Retrieve an available signal that will mark the callback's completion.
  AMDGPUSignalTy *DoneSignal = nullptr;
  if (auto Err = SignalManager.getResource(DoneSignal))
    return Err;
  DoneSignal->reset();
  DoneSignal->increaseUseCount();

  // Consume a stream slot and compute dependencies while holding the stream
  // mutex.
  AMDGPUSignalTy *DependencySignal = nullptr;
  {
    std::lock_guard<std::mutex> Guard(Mutex);
    DependencySignal = consume(DoneSignal).second;
  }

  // "Leaking" the thread here is consistent with other work added to the
  // queue. The input and output signals will remain valid until the output is
  // signaled.
  std::thread Worker(CallbackWrapper, DependencySignal, DoneSignal, Callback,
                     UserData);
  Worker.detach();

  return Plugin::success();
}
1536+
14981537
/// Synchronize with the stream. The current thread waits until all operations
14991538
/// are finalized and it performs the pending post actions (i.e., releasing
15001539
/// intermediate buffers).
@@ -2553,6 +2592,15 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
25532592
return Plugin::success();
25542593
}
25552594

2595+
/// Enqueue a host callback on the AMDGPU stream associated with
/// \p AsyncInfo.
///
/// \returns the error from stream lookup, or the result of pushing the
/// callback onto the stream.
Error enqueueHostCallImpl(void (*Callback)(void *), void *UserData,
                          AsyncInfoWrapperTy &AsyncInfo) override {
  AMDGPUStreamTy *Stream = nullptr;
  auto StreamErr = getStream(AsyncInfo, Stream);
  if (StreamErr)
    return StreamErr;

  return Stream->pushHostCallback(Callback, UserData);
}
2603+
25562604
/// Create an event.
25572605
Error createEventImpl(void **EventPtrStorage) override {
25582606
AMDGPUEventTy **Event = reinterpret_cast<AMDGPUEventTy **>(EventPtrStorage);

offload/plugins-nextgen/common/include/PluginInterface.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -965,6 +965,12 @@ struct GenericDeviceTy : public DeviceAllocatorTy {
965965
Error initDeviceInfo(__tgt_device_info *DeviceInfo);
966966
virtual Error initDeviceInfoImpl(__tgt_device_info *DeviceInfo) = 0;
967967

968+
/// Enqueue a host call to AsyncInfo
969+
Error enqueueHostCall(void (*Callback)(void *), void *UserData,
970+
__tgt_async_info *AsyncInfo);
971+
virtual Error enqueueHostCallImpl(void (*Callback)(void *), void *UserData,
972+
AsyncInfoWrapperTy &AsyncInfo) = 0;
973+
968974
/// Create an event.
969975
Error createEvent(void **EventPtrStorage);
970976
virtual Error createEventImpl(void **EventPtrStorage) = 0;

offload/plugins-nextgen/common/src/PluginInterface.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1589,6 +1589,15 @@ Error GenericDeviceTy::initAsyncInfo(__tgt_async_info **AsyncInfoPtr) {
15891589
return Err;
15901590
}
15911591

1592+
/// Enqueue a host call to \p AsyncInfo.
///
/// Wraps \p AsyncInfo, dispatches to the plugin-specific
/// enqueueHostCallImpl, and hands the resulting error to the wrapper's
/// finalize step before returning it to the caller.
Error GenericDeviceTy::enqueueHostCall(void (*Callback)(void *), void *UserData,
                                       __tgt_async_info *AsyncInfo) {
  AsyncInfoWrapperTy Wrapper(*this, AsyncInfo);

  auto Result = enqueueHostCallImpl(Callback, UserData, Wrapper);
  // finalize() receives the error by reference, so it must run before the
  // error is returned.
  Wrapper.finalize(Result);
  return Result;
}
1600+
15921601
Error GenericDeviceTy::initDeviceInfo(__tgt_device_info *DeviceInfo) {
15931602
assert(DeviceInfo && "Invalid device info");
15941603

offload/plugins-nextgen/cuda/src/rtl.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -873,6 +873,19 @@ struct CUDADeviceTy : public GenericDeviceTy {
873873
return Plugin::success();
874874
}
875875

876+
/// Enqueue a host callback on the CUDA stream associated with \p AsyncInfo.
///
/// Sets the device context, looks up the stream, and submits \p Callback
/// with \p UserData through cuLaunchHostFunc.
///
/// \returns an error if setting the context, obtaining the stream, or the
/// driver call fails.
Error enqueueHostCallImpl(void (*Callback)(void *), void *UserData,
                          AsyncInfoWrapperTy &AsyncInfo) override {
  if (auto Err = setContext())
    return Err;

  CUstream Stream;
  if (auto Err = getStream(AsyncInfo, Stream))
    return Err;

  CUresult Res = cuLaunchHostFunc(Stream, Callback, UserData);
  // Report the driver entry point that was actually called.
  return Plugin::check(Res, "error in cuLaunchHostFunc: %s");
}
888+
876889
/// Create an event.
877890
Error createEventImpl(void **EventPtrStorage) override {
878891
CUevent *Event = reinterpret_cast<CUevent *>(EventPtrStorage);

offload/plugins-nextgen/host/src/rtl.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,12 @@ struct GenELF64DeviceTy : public GenericDeviceTy {
320320
"initDeviceInfoImpl not supported");
321321
}
322322

323+
/// Run a host callback for the host plugin.
///
/// The callback is invoked directly, before this function returns;
/// \p AsyncInfo is not consulted.
Error enqueueHostCallImpl(void (*Callback)(void *), void *UserData,
                          AsyncInfoWrapperTy &AsyncInfo) override {
  // Invoke the user's function synchronously.
  Callback(UserData);

  return Plugin::success();
}
328+
323329
/// This plugin does not support the event API. Do nothing without failing.
324330
Error createEventImpl(void **EventPtrStorage) override {
325331
*EventPtrStorage = nullptr;

offload/unittests/OffloadAPI/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ add_offload_unittest("queue"
4141
queue/olDestroyQueue.cpp
4242
queue/olGetQueueInfo.cpp
4343
queue/olGetQueueInfoSize.cpp
44-
queue/olWaitEvents.cpp)
44+
queue/olWaitEvents.cpp
45+
queue/olLaunchHostFunction.cpp)
4546

4647
add_offload_unittest("symbol"
4748
symbol/olGetSymbol.cpp
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
//===------- Offload API tests - olLaunchHostFunction ---------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "../common/Fixtures.hpp"
#include <OffloadAPI.h>
#include <atomic>
#include <chrono>
#include <gtest/gtest.h>
#include <thread>
13+
14+
struct olLaunchHostFunctionTest : OffloadQueueTest {};
15+
OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olLaunchHostFunctionTest);
16+
17+
struct olLaunchHostFunctionKernelTest : OffloadKernelTest {};
18+
OFFLOAD_TESTS_INSTANTIATE_DEVICE_FIXTURE(olLaunchHostFunctionKernelTest);
19+
20+
TEST_P(olLaunchHostFunctionTest, Success) {
21+
ASSERT_SUCCESS(olLaunchHostFunction(Queue, [](void *) {}, nullptr));
22+
}
23+
24+
TEST_P(olLaunchHostFunctionTest, SuccessSequence) {
  // Each enqueued host function fills one element from the two before it, so
  // the final contents are only correct if the callbacks ran in submission
  // order.
  uint32_t Buff[16] = {1, 1};

  auto FillFromPrevious = [](void *Slot) {
    uint32_t *Elem = static_cast<uint32_t *>(Slot);
    Elem[0] = Elem[-1] + Elem[-2];
  };

  for (uint32_t Idx = 2; Idx < 16; Idx++) {
    ASSERT_SUCCESS(olLaunchHostFunction(Queue, FillFromPrevious, &Buff[Idx]));
  }

  ASSERT_SUCCESS(olSyncQueue(Queue));

  for (uint32_t Idx = 2; Idx < 16; Idx++) {
    ASSERT_EQ(Buff[Idx], Buff[Idx - 1] + Buff[Idx - 2]);
  }
}
43+
44+
TEST_P(olLaunchHostFunctionKernelTest, SuccessBlocking) {
  // Verify that a host function can block execution - A host task is created
  // that only resolves when Block is set to false. The kernel enqueued after
  // it must not have touched the buffer until then.
  ol_kernel_launch_size_args_t LaunchArgs;
  LaunchArgs.Dimensions = 1;
  LaunchArgs.GroupSize = {64, 1, 1};
  LaunchArgs.NumGroups = {1, 1, 1};
  LaunchArgs.DynSharedMemory = 0;

  ol_queue_handle_t Queue;
  ASSERT_SUCCESS(olCreateQueue(Device, &Queue));

  void *Mem;
  ASSERT_SUCCESS(olMemAlloc(Device, OL_ALLOC_TYPE_MANAGED,
                            LaunchArgs.GroupSize.x * sizeof(uint32_t), &Mem));

  uint32_t *Data = static_cast<uint32_t *>(Mem);
  for (uint32_t i = 0; i < 64; i++) {
    Data[i] = 0;
  }

  // The flag is written by this thread and read by whichever thread runs the
  // callback, so it must be atomic; `volatile` gives no inter-thread
  // guarantees.
  std::atomic<bool> Block{true};
  ASSERT_SUCCESS(olLaunchHostFunction(
      Queue,
      [](void *Ptr) {
        auto *Flag = static_cast<std::atomic<bool> *>(Ptr);

        while (Flag->load())
          std::this_thread::yield();
      },
      &Block));

  struct {
    void *Mem;
  } Args{Mem};
  ASSERT_SUCCESS(
      olLaunchKernel(Queue, Device, Kernel, &Args, sizeof(Args), &LaunchArgs));

  // Give the kernel a chance to (incorrectly) run while the host function is
  // still blocked.
  std::this_thread::sleep_for(std::chrono::milliseconds(500));
  for (uint32_t i = 0; i < 64; i++) {
    ASSERT_EQ(Data[i], 0u);
  }

  // Release the host function and let the rest of the queue drain.
  Block.store(false);
  ASSERT_SUCCESS(olSyncQueue(Queue));

  for (uint32_t i = 0; i < 64; i++) {
    ASSERT_EQ(Data[i], i);
  }

  ASSERT_SUCCESS(olDestroyQueue(Queue));
  ASSERT_SUCCESS(olMemFree(Mem));
}
98+
99+
TEST_P(olLaunchHostFunctionTest, InvalidNullCallback) {
100+
ASSERT_ERROR(OL_ERRC_INVALID_NULL_POINTER,
101+
olLaunchHostFunction(Queue, nullptr, nullptr));
102+
}
103+
104+
TEST_P(olLaunchHostFunctionTest, InvalidNullQueue) {
105+
ASSERT_ERROR(OL_ERRC_INVALID_NULL_HANDLE,
106+
olLaunchHostFunction(nullptr, [](void *) {}, nullptr));
107+
}

0 commit comments

Comments
 (0)