File tree Expand file tree Collapse file tree 2 files changed +40
-8
lines changed
Expand file tree Collapse file tree 2 files changed +40
-8
lines changed Original file line number Diff line number Diff line change @@ -76,18 +76,34 @@ void ConcreteAPI::syncStreamWithEvent(void* streamPtr, void* eventPtr) {
7676}
7777
7878namespace {
79- static void streamCallback (void * data) {
79+ void streamCallbackEpheremal (void * data) {
8080 auto * function = reinterpret_cast <std::function<void ()>*>(data);
8181 (*function)();
8282 delete function;
8383}
84+
85+ void streamCallbackPermanent (void * data) {
86+ auto * function = reinterpret_cast <std::function<void ()>*>(data);
87+ (*function)();
88+ }
8489} // namespace
8590
8691void ConcreteAPI::streamHostFunction (void * streamPtr, const std::function<void ()>& function) {
8792 cudaStream_t stream = static_cast <cudaStream_t>(streamPtr);
88- auto * functionData = new std::function<void ()>(function);
89- cudaLaunchHostFunc (stream, &streamCallback, functionData);
90- CHECK_ERR;
93+
94+ cudaStreamCaptureStatus status{};
95+ cudaStreamIsCapturing (stream, &status);
96+
97+ if (status != cudaStreamCaptureStatusInvalidated) {
98+ auto * functionData = new std::function<void ()>(function);
99+ if (status == cudaStreamCaptureStatusActive) {
100+ cudaLaunchHostFunc (stream, &streamCallbackPermanent, functionData);
101+ }
102+ else {
103+ cudaLaunchHostFunc (stream, &streamCallbackEpheremal, functionData);
104+ }
105+ CHECK_ERR;
106+ }
91107}
92108
93109namespace {
Original file line number Diff line number Diff line change @@ -76,18 +76,34 @@ void ConcreteAPI::syncStreamWithEvent(void* streamPtr, void* eventPtr) {
7676}
7777
7878namespace {
79- static void streamCallback (void * data) {
79+ void streamCallbackEpheremal (void * data) {
8080 auto * function = reinterpret_cast <std::function<void ()>*>(data);
8181 (*function)();
8282 delete function;
8383}
84+
85+ void streamCallbackPermanent (void * data) {
86+ auto * function = reinterpret_cast <std::function<void ()>*>(data);
87+ (*function)();
88+ }
8489} // namespace
8590
8691void ConcreteAPI::streamHostFunction (void * streamPtr, const std::function<void ()>& function) {
8792 hipStream_t stream = static_cast <hipStream_t>(streamPtr);
88- auto * functionData = new std::function<void ()>(function);
89- hipLaunchHostFunc (stream, streamCallback, functionData);
90- CHECK_ERR;
93+
94+ hipStreamCaptureStatus status{};
95+ hipStreamIsCapturing (stream, &status);
96+
97+ if (status != hipStreamCaptureStatusInvalidated) {
98+ auto * functionData = new std::function<void ()>(function);
99+ if (status == hipStreamCaptureStatusActive) {
100+ hipLaunchHostFunc (stream, &streamCallbackPermanent, functionData);
101+ }
102+ else {
103+ hipLaunchHostFunc (stream, &streamCallbackEpheremal, functionData);
104+ }
105+ CHECK_ERR;
106+ }
91107}
92108
93109namespace {
You can’t perform that action at this time.
0 commit comments