@@ -1063,6 +1063,20 @@ struct AMDGPUStreamTy {
10631063 // / Indicate to spread data transfers across all available SDMAs
10641064 bool UseMultipleSdmaEngines;
10651065
1066+ // / Wrapper function for implementing host callbacks
1067+ static void CallbackWrapper (AMDGPUSignalTy *InputSignal,
1068+ AMDGPUSignalTy *OutputSignal,
1069+ void (*Callback)(void *), void *UserData) {
1070+ if (InputSignal)
1071+ if (auto Err = InputSignal->wait ())
1072+ // Wait shouldn't report an error
1073+ reportFatalInternalError (std::move (Err));
1074+
1075+ Callback (UserData);
1076+
1077+ OutputSignal->signal ();
1078+ }
1079+
10661080 // / Return the current number of asynchronous operations on the stream.
10671081 uint32_t size () const { return NextSlot; }
10681082
@@ -1495,6 +1509,31 @@ struct AMDGPUStreamTy {
14951509 OutputSignal->get ());
14961510 }
14971511
1512+ Error pushHostCallback (void (*Callback)(void *), void *UserData) {
1513+ // Retrieve an available signal for the operation's output.
1514+ AMDGPUSignalTy *OutputSignal = nullptr ;
1515+ if (auto Err = SignalManager.getResource (OutputSignal))
1516+ return Err;
1517+ OutputSignal->reset ();
1518+ OutputSignal->increaseUseCount ();
1519+
1520+ AMDGPUSignalTy *InputSignal;
1521+ {
1522+ std::lock_guard<std::mutex> Lock (Mutex);
1523+
1524+ // Consume stream slot and compute dependencies.
1525+ InputSignal = consume (OutputSignal).second ;
1526+ }
1527+
1528+ // "Leaking" the thread here is consistent with other work added to the
1529+ // queue. The input and output signals will remain valid until the output is
1530+ // signaled.
1531+ std::thread (CallbackWrapper, InputSignal, OutputSignal, Callback, UserData)
1532+ .detach ();
1533+
1534+ return Plugin::success ();
1535+ }
1536+
14981537 // / Synchronize with the stream. The current thread waits until all operations
14991538 // / are finalized and it performs the pending post actions (i.e., releasing
15001539 // / intermediate buffers).
@@ -2553,6 +2592,15 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
25532592 return Plugin::success ();
25542593 }
25552594
2595+ Error enqueueHostCallbackImpl (void (*Callback)(void *), void *UserData,
2596+ AsyncInfoWrapperTy &AsyncInfo) override {
2597+ AMDGPUStreamTy *Stream = nullptr ;
2598+ if (auto Err = getStream (AsyncInfo, Stream))
2599+ return Err;
2600+
2601+ return Stream->pushHostCallback (Callback, UserData);
2602+ };
2603+
25562604 // / Create an event.
25572605 Error createEventImpl (void **EventPtrStorage) override {
25582606 AMDGPUEventTy **Event = reinterpret_cast <AMDGPUEventTy **>(EventPtrStorage);
0 commit comments