@@ -923,6 +923,10 @@ struct AMDGPUQueueTy {
923923// / devices. This class relies on signals to implement streams and define the
924924// / dependencies between asynchronous operations.
925925struct  AMDGPUStreamTy  {
926+ public: 
927+   // / Function pointer type for `pushHostCallback`
928+   using  HostFnType = void  (*)(void  *);
929+ 
926930private: 
927931  // / Utility struct holding arguments for async H2H memory copies.
928932  struct  MemcpyArgsTy  {
@@ -1084,18 +1088,19 @@ struct AMDGPUStreamTy {
10841088  // / Indicate to spread data transfers across all available SDMAs
10851089  bool  UseMultipleSdmaEngines;
10861090
1091+   struct  CallbackDataType  {
1092+     HostFnType UserFn;
1093+     void  *UserData;
1094+     AMDGPUSignalTy *OutputSignal;
1095+   };
10871096  // / Wrapper function for implementing host callbacks
1088-   static  void  CallbackWrapper (AMDGPUSignalTy *InputSignal,
1089-                               AMDGPUSignalTy *OutputSignal,
1090-                               void  (*Callback)(void  *), void *UserData) {
1091-     //  The wait call will not error in this context.
1092-     if  (InputSignal)
1093-       if  (auto  Err = InputSignal->wait ())
1094-         reportFatalInternalError (std::move (Err));
1095- 
1096-     Callback (UserData);
1097- 
1098-     OutputSignal->signal ();
1097+   static  bool  callbackWrapper ([[maybe_unused]] hsa_signal_value_t  Signal,
1098+                               void  *UserData) {
1099+     auto  CallbackData = reinterpret_cast <CallbackDataType *>(UserData);
1100+     CallbackData->UserFn (CallbackData->UserData );
1101+     CallbackData->OutputSignal ->signal ();
1102+     delete  CallbackData;
1103+     return  false ;
10991104  }
11001105
11011106  // / Return the current number of asynchronous operations on the stream.
@@ -1540,7 +1545,7 @@ struct AMDGPUStreamTy {
15401545                                   OutputSignal->get ());
15411546  }
15421547
1543-   Error pushHostCallback (void  (* Callback)( void  *) , void *UserData) {
1548+   Error pushHostCallback (HostFnType  Callback, void  *UserData) {
15441549    //  Retrieve an available signal for the operation's output.
15451550    AMDGPUSignalTy *OutputSignal = nullptr ;
15461551    if  (auto  Err = SignalManager.getResource (OutputSignal))
@@ -1556,12 +1561,21 @@ struct AMDGPUStreamTy {
15561561      InputSignal = consume (OutputSignal).second ;
15571562    }
15581563
1559-     //  "Leaking" the thread here is consistent with other work added to the 
1560-     //  queue. The input and output signals will remain valid until the output is 
1561-     //  signaled. 
1562-     std::thread (CallbackWrapper,  InputSignal, OutputSignal, Callback, UserData) 
1563-         . detach ( );
1564+     auto  *CallbackData =  new  CallbackDataType{Callback, UserData, OutputSignal}; 
1565+     if  (InputSignal && InputSignal-> load ()) { 
1566+        hsa_status_t  Status =  hsa_amd_signal_async_handler ( 
1567+            InputSignal-> get (), HSA_SIGNAL_CONDITION_EQ,  0 , callbackWrapper, 
1568+           CallbackData );
15641569
1570+       return  Plugin::check (Status, " error in hsa_amd_signal_async_handler: %s" );
1571+     }
1572+ 
1573+     //  No dependencies - schedule it now.
1574+     //  Using a separate thread because this function should run asynchronously
1575+     //  and not block the main thread.
1576+     std::thread ([](void  *CallbackData) { callbackWrapper (0 , CallbackData); },
1577+                 CallbackData)
1578+         .detach ();
15651579    return  Plugin::success ();
15661580  }
15671581
@@ -2733,7 +2747,7 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
27332747    return  Plugin::success ();
27342748  }
27352749
2736-   Error enqueueHostCallImpl (void  (* Callback)( void  *) , void *UserData,
2750+   Error enqueueHostCallImpl (AMDGPUStreamTy::HostFnType  Callback, void  *UserData,
27372751                            AsyncInfoWrapperTy &AsyncInfo) override  {
27382752    AMDGPUStreamTy *Stream = nullptr ;
27392753    if  (auto  Err = getStream (AsyncInfo, Stream))
0 commit comments