diff --git a/offload/plugins-nextgen/amdgpu/src/rtl.cpp b/offload/plugins-nextgen/amdgpu/src/rtl.cpp index f0cc0c2e4d08e..bdb33d4f4ab27 100644 --- a/offload/plugins-nextgen/amdgpu/src/rtl.cpp +++ b/offload/plugins-nextgen/amdgpu/src/rtl.cpp @@ -927,6 +927,8 @@ struct AMDGPUStreamTy { AMDGPUSignalManagerTy *SignalManager; }; + using AMDGPUStreamCallbackTy = Error(void *Data); + /// The stream is composed of N stream's slots. The struct below represents /// the fields of each slot. Each slot has a signal and an optional action /// function. When appending an HSA asynchronous operation to the stream, one @@ -942,65 +944,82 @@ struct AMDGPUStreamTy { /// operation as input signal. AMDGPUSignalTy *Signal; - /// The action that must be performed after the operation's completion. Set + /// The actions that must be performed after the operation's completion. Set /// to nullptr when there is no action to perform. - Error (*ActionFunction)(void *); + llvm::SmallVector Callbacks; /// Space for the action's arguments. A pointer to these arguments is passed /// to the action function. Notice the space of arguments is limited. - union { + union ActionArgsTy { MemcpyArgsTy MemcpyArgs; ReleaseBufferArgsTy ReleaseBufferArgs; ReleaseSignalArgsTy ReleaseSignalArgs; - } ActionArgs; + void *CallbackArgs; + }; + + llvm::SmallVector ActionArgs; /// Create an empty slot. - StreamSlotTy() : Signal(nullptr), ActionFunction(nullptr) {} + StreamSlotTy() : Signal(nullptr), Callbacks({}), ActionArgs({}) {} /// Schedule a host memory copy action on the slot. Error schedHostMemoryCopy(void *Dst, const void *Src, size_t Size) { - ActionFunction = memcpyAction; - ActionArgs.MemcpyArgs = MemcpyArgsTy{Dst, Src, Size}; + Callbacks.emplace_back(memcpyAction); + ActionArgs.emplace_back().MemcpyArgs = MemcpyArgsTy{Dst, Src, Size}; return Plugin::success(); } /// Schedule a release buffer action on the slot. Error schedReleaseBuffer(void *Buffer, AMDGPUMemoryManagerTy &Manager) { - ActionFunction = releaseBufferAction; - ActionArgs.ReleaseBufferArgs = ReleaseBufferArgsTy{Buffer, &Manager}; + Callbacks.emplace_back(releaseBufferAction); + ActionArgs.emplace_back().ReleaseBufferArgs = + ReleaseBufferArgsTy{Buffer, &Manager}; return Plugin::success(); } /// Schedule a signal release action on the slot. Error schedReleaseSignal(AMDGPUSignalTy *SignalToRelease, AMDGPUSignalManagerTy *SignalManager) { - ActionFunction = releaseSignalAction; - ActionArgs.ReleaseSignalArgs = + Callbacks.emplace_back(releaseSignalAction); + ActionArgs.emplace_back().ReleaseSignalArgs = ReleaseSignalArgsTy{SignalToRelease, SignalManager}; return Plugin::success(); } + /// Register a callback to be called on compleition + Error schedCallback(AMDGPUStreamCallbackTy *Func, void *Data) { + Callbacks.emplace_back(Func); + ActionArgs.emplace_back().CallbackArgs = Data; + + return Plugin::success(); + } + // Perform the action if needed. Error performAction() { - if (!ActionFunction) + if (Callbacks.empty()) return Plugin::success(); - // Perform the action. - if (ActionFunction == memcpyAction) { - if (auto Err = memcpyAction(&ActionArgs)) - return Err; - } else if (ActionFunction == releaseBufferAction) { - if (auto Err = releaseBufferAction(&ActionArgs)) - return Err; - } else if (ActionFunction == releaseSignalAction) { - if (auto Err = releaseSignalAction(&ActionArgs)) - return Err; - } else { - return Plugin::error("Unknown action function!"); + assert(Callbacks.size() == ActionArgs.size() && "Size mismatch"); + for (auto [Callback, ActionArg] : llvm::zip(Callbacks, ActionArgs)) { + // Perform the action. + if (Callback == memcpyAction) { + if (auto Err = memcpyAction(&ActionArg)) + return Err; + } else if (Callback == releaseBufferAction) { + if (auto Err = releaseBufferAction(&ActionArg)) + return Err; + } else if (Callback == releaseSignalAction) { + if (auto Err = releaseSignalAction(&ActionArg)) + return Err; + } else if (Callback) { + if (auto Err = Callback(ActionArg.CallbackArgs)) + return Err; + } } // Invalidate the action. - ActionFunction = nullptr; + Callbacks.clear(); + ActionArgs.clear(); return Plugin::success(); }