@@ -927,6 +927,8 @@ struct AMDGPUStreamTy {
927927 AMDGPUSignalManagerTy *SignalManager;
928928 };
929929
930+ using AMDGPUStreamCallbackTy = Error(void *Data);
931+
930932 // / The stream is composed of N stream's slots. The struct below represents
931933 // / the fields of each slot. Each slot has a signal and an optional action
932934 // / function. When appending an HSA asynchronous operation to the stream, one
@@ -942,65 +944,82 @@ struct AMDGPUStreamTy {
942944 // / operation as input signal.
943945 AMDGPUSignalTy *Signal;
944946
945- // / The action that must be performed after the operation's completion. Set
947+ // / The actions that must be performed after the operation's completion. Set
946948 // / to an empty list when there is no action to perform.
947- Error (*ActionFunction)( void *) ;
949+ llvm::SmallVector<AMDGPUStreamCallbackTy *> Callbacks ;
948950
949951 // / Space for the action's arguments. A pointer to these arguments is passed
950952 // / to the action function. Notice the space of arguments is limited.
951- union {
953+ union ActionArgsTy {
952954 MemcpyArgsTy MemcpyArgs;
953955 ReleaseBufferArgsTy ReleaseBufferArgs;
954956 ReleaseSignalArgsTy ReleaseSignalArgs;
955- } ActionArgs;
957+ void *CallbackArgs;
958+ };
959+
960+ llvm::SmallVector<ActionArgsTy> ActionArgs;
956961
957962 // / Create an empty slot.
958- StreamSlotTy () : Signal(nullptr ), ActionFunction( nullptr ) {}
963+ StreamSlotTy () : Signal(nullptr ), Callbacks({}), ActionArgs({} ) {}
959964
960965 // / Schedule a host memory copy action on the slot.
961966 Error schedHostMemoryCopy (void *Dst, const void *Src, size_t Size) {
962- ActionFunction = memcpyAction;
963- ActionArgs.MemcpyArgs = MemcpyArgsTy{Dst, Src, Size};
967+ Callbacks. emplace_back ( memcpyAction) ;
968+ ActionArgs.emplace_back (). MemcpyArgs = MemcpyArgsTy{Dst, Src, Size};
964969 return Plugin::success ();
965970 }
966971
967972 // / Schedule a release buffer action on the slot.
968973 Error schedReleaseBuffer (void *Buffer, AMDGPUMemoryManagerTy &Manager) {
969- ActionFunction = releaseBufferAction;
970- ActionArgs.ReleaseBufferArgs = ReleaseBufferArgsTy{Buffer, &Manager};
974+ Callbacks.emplace_back (releaseBufferAction);
975+ ActionArgs.emplace_back ().ReleaseBufferArgs =
976+ ReleaseBufferArgsTy{Buffer, &Manager};
971977 return Plugin::success ();
972978 }
973979
974980 // / Schedule a signal release action on the slot.
975981 Error schedReleaseSignal (AMDGPUSignalTy *SignalToRelease,
976982 AMDGPUSignalManagerTy *SignalManager) {
977- ActionFunction = releaseSignalAction;
978- ActionArgs.ReleaseSignalArgs =
983+ Callbacks. emplace_back ( releaseSignalAction) ;
984+ ActionArgs.emplace_back (). ReleaseSignalArgs =
979985 ReleaseSignalArgsTy{SignalToRelease, SignalManager};
980986 return Plugin::success ();
981987 }
982988
989+ // / Register a callback to be called on completion.
990+ Error schedCallback (AMDGPUStreamCallbackTy *Func, void *Data) {
991+ Callbacks.emplace_back (Func);
992+ ActionArgs.emplace_back ().CallbackArgs = Data;
993+
994+ return Plugin::success ();
995+ }
996+
983997 // Perform the action if needed.
984998 Error performAction () {
985- if (!ActionFunction )
999+ if (Callbacks. empty () )
9861000 return Plugin::success ();
9871001
988- // Perform the action.
989- if (ActionFunction == memcpyAction) {
990- if (auto Err = memcpyAction (&ActionArgs))
991- return Err;
992- } else if (ActionFunction == releaseBufferAction) {
993- if (auto Err = releaseBufferAction (&ActionArgs))
994- return Err;
995- } else if (ActionFunction == releaseSignalAction) {
996- if (auto Err = releaseSignalAction (&ActionArgs))
997- return Err;
998- } else {
999- return Plugin::error (" Unknown action function!" );
1002+ assert (Callbacks.size () == ActionArgs.size () && " Size mismatch" );
1003+ for (auto [Callback, ActionArg] : llvm::zip (Callbacks, ActionArgs)) {
1004+ // Perform the action.
1005+ if (Callback == memcpyAction) {
1006+ if (auto Err = memcpyAction (&ActionArg))
1007+ return Err;
1008+ } else if (Callback == releaseBufferAction) {
1009+ if (auto Err = releaseBufferAction (&ActionArg))
1010+ return Err;
1011+ } else if (Callback == releaseSignalAction) {
1012+ if (auto Err = releaseSignalAction (&ActionArg))
1013+ return Err;
1014+ } else if (Callback) {
1015+ if (auto Err = Callback (ActionArg.CallbackArgs ))
1016+ return Err;
1017+ }
10001018 }
10011019
10021020 // Invalidate the action.
1003- ActionFunction = nullptr ;
1021+ Callbacks.clear ();
1022+ ActionArgs.clear ();
10041023
10051024 return Plugin::success ();
10061025 }
0 commit comments