Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 32 additions & 18 deletions offload/plugins-nextgen/amdgpu/src/rtl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -923,6 +923,10 @@ struct AMDGPUQueueTy {
/// devices. This class relies on signals to implement streams and define the
/// dependencies between asynchronous operations.
struct AMDGPUStreamTy {
public:
/// Function pointer type for `pushHostCallback`
using HostFnType = void (*)(void *);

private:
/// Utility struct holding arguments for async H2H memory copies.
struct MemcpyArgsTy {
Expand Down Expand Up @@ -1084,18 +1088,19 @@ struct AMDGPUStreamTy {
/// Indicate to spread data transfers across all available SDMAs
bool UseMultipleSdmaEngines;

struct CallbackDataType {
HostFnType UserFn;
void *UserData;
AMDGPUSignalTy *OutputSignal;
};
/// Wrapper function for implementing host callbacks
static void CallbackWrapper(AMDGPUSignalTy *InputSignal,
AMDGPUSignalTy *OutputSignal,
void (*Callback)(void *), void *UserData) {
// The wait call will not error in this context.
if (InputSignal)
if (auto Err = InputSignal->wait())
reportFatalInternalError(std::move(Err));

Callback(UserData);

OutputSignal->signal();
static bool callbackWrapper([[maybe_unused]] hsa_signal_value_t Signal,
void *UserData) {
auto CallbackData = reinterpret_cast<CallbackDataType *>(UserData);
CallbackData->UserFn(CallbackData->UserData);
CallbackData->OutputSignal->signal();
delete CallbackData;
return false;
}

/// Return the current number of asynchronous operations on the stream.
Expand Down Expand Up @@ -1540,7 +1545,7 @@ struct AMDGPUStreamTy {
OutputSignal->get());
}

Error pushHostCallback(void (*Callback)(void *), void *UserData) {
Error pushHostCallback(HostFnType Callback, void *UserData) {
// Retrieve an available signal for the operation's output.
AMDGPUSignalTy *OutputSignal = nullptr;
if (auto Err = SignalManager.getResource(OutputSignal))
Expand All @@ -1556,12 +1561,21 @@ struct AMDGPUStreamTy {
InputSignal = consume(OutputSignal).second;
}

// "Leaking" the thread here is consistent with other work added to the
// queue. The input and output signals will remain valid until the output is
// signaled.
std::thread(CallbackWrapper, InputSignal, OutputSignal, Callback, UserData)
.detach();
auto *CallbackData = new CallbackDataType{Callback, UserData, OutputSignal};
if (InputSignal && InputSignal->load()) {
hsa_status_t Status = hsa_amd_signal_async_handler(
InputSignal->get(), HSA_SIGNAL_CONDITION_EQ, 0, callbackWrapper,
CallbackData);

return Plugin::check(Status, "error in hsa_amd_signal_async_handler: %s");
}

// No dependencies - schedule it now.
// Using a seperate thread because this function should run asynchronously
// and not block the main thread.
std::thread([](void *CallbackData) { callbackWrapper(0, CallbackData); },
CallbackData)
.detach();
return Plugin::success();
}

Expand Down Expand Up @@ -2733,7 +2747,7 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
return Plugin::success();
}

Error enqueueHostCallImpl(void (*Callback)(void *), void *UserData,
Error enqueueHostCallImpl(AMDGPUStreamTy::HostFnType Callback, void *UserData,
AsyncInfoWrapperTy &AsyncInfo) override {
AMDGPUStreamTy *Stream = nullptr;
if (auto Err = getStream(AsyncInfo, Stream))
Expand Down
Loading