@@ -923,6 +923,10 @@ struct AMDGPUQueueTy {
923923// / devices. This class relies on signals to implement streams and define the
924924// / dependencies between asynchronous operations.
925925struct AMDGPUStreamTy {
926+ public:
927+ /// Function pointer type for `pushHostCallback`: takes one opaque `void *` argument and returns nothing.
928+ using HostFnType = void (*)(void *);
929+
926930private:
927931 // / Utility struct holding arguments for async H2H memory copies.
928932 struct MemcpyArgsTy {
@@ -1084,18 +1088,19 @@ struct AMDGPUStreamTy {
10841088 // / Indicate to spread data transfers across all available SDMAs
10851089 bool UseMultipleSdmaEngines;
10861090
1091+ struct CallbackDataType { // Arguments bundled for callbackWrapper, which receives them through a single void *.
1092+ HostFnType UserFn; ///< Host function that callbackWrapper invokes.
1093+ void *UserData; ///< Opaque argument passed to UserFn.
1094+ AMDGPUSignalTy *OutputSignal; ///< Signaled by callbackWrapper after UserFn returns.
1095+ };
10871096 // / Wrapper function for implementing host callbacks
1088- static void CallbackWrapper (AMDGPUSignalTy *InputSignal,
1089- AMDGPUSignalTy *OutputSignal,
1090- void (*Callback)(void *), void *UserData) {
1091- // The wait call will not error in this context.
1092- if (InputSignal)
1093- if (auto Err = InputSignal->wait ())
1094- reportFatalInternalError (std::move (Err));
1095-
1096- Callback (UserData);
1097-
1098- OutputSignal->signal ();
1097+ static bool callbackWrapper ([[maybe_unused]] hsa_signal_value_t Signal, // Signal value is not used by this wrapper.
1098+ void *UserData) { // UserData is expected to point to a CallbackDataType.
1099+ auto CallbackData = reinterpret_cast <CallbackDataType *>(UserData); // NOTE(review): static_cast is sufficient for a void * conversion.
1100+ CallbackData->UserFn (CallbackData->UserData ); // Run the user's host function with its own argument.
1101+ CallbackData->OutputSignal ->signal (); // Signal the output only after the user function has returned.
1102+ delete CallbackData; // The wrapper owns the callback data and frees it here.
1103+ return false ; // NOTE(review): presumably false tells the HSA runtime not to re-invoke this handler; confirm against the hsa_amd_signal_async_handler documentation.
10991104 }
11001105
11011106 // / Return the current number of asynchronous operations on the stream.
@@ -1540,7 +1545,7 @@ struct AMDGPUStreamTy {
15401545 OutputSignal->get ());
15411546 }
15421547
1543- Error pushHostCallback (void (* Callback)( void *) , void *UserData) {
1548+ Error pushHostCallback (HostFnType Callback, void *UserData) {
15441549 // Retrieve an available signal for the operation's output.
15451550 AMDGPUSignalTy *OutputSignal = nullptr ;
15461551 if (auto Err = SignalManager.getResource (OutputSignal))
@@ -1556,12 +1561,21 @@ struct AMDGPUStreamTy {
15561561 InputSignal = consume (OutputSignal).second ;
15571562 }
15581563
1559- // "Leaking" the thread here is consistent with other work added to the
1560- // queue. The input and output signals will remain valid until the output is
1561- // signaled.
1562- std::thread (CallbackWrapper, InputSignal, OutputSignal, Callback, UserData)
1563- . detach ( );
1564+ auto *CallbackData = new CallbackDataType{Callback, UserData, OutputSignal};
1565+ if (InputSignal && InputSignal-> load ()) {
1566+ hsa_status_t Status = hsa_amd_signal_async_handler (
1567+ InputSignal-> get (), HSA_SIGNAL_CONDITION_EQ, 0 , callbackWrapper,
1568+ CallbackData );
15641569
1570+ return Plugin::check (Status, " error in hsa_amd_signal_async_handler: %s" );
1571+ }
1572+
1573+ // No dependencies - schedule it now.
1574+ // Using a separate thread because this function should run asynchronously
1575+ // and not block the main thread.
1576+ std::thread ([](void *CallbackData) { callbackWrapper (0 , CallbackData); },
1577+ CallbackData)
1578+ .detach ();
15651579 return Plugin::success ();
15661580 }
15671581
@@ -2733,7 +2747,7 @@ struct AMDGPUDeviceTy : public GenericDeviceTy, AMDGenericDeviceTy {
27332747 return Plugin::success ();
27342748 }
27352749
2736- Error enqueueHostCallImpl (void (* Callback)( void *) , void *UserData,
2750+ Error enqueueHostCallImpl (AMDGPUStreamTy::HostFnType Callback, void *UserData,
27372751 AsyncInfoWrapperTy &AsyncInfo) override {
27382752 AMDGPUStreamTy *Stream = nullptr ;
27392753 if (auto Err = getStream (AsyncInfo, Stream))
0 commit comments