Skip to content

Commit e433dfc

Browse files
authored
Merge pull request #54 from SeisSol/davschneller/hostfunc-graph
Fix CUDA/HIP host functions on graphs
2 parents 6d73b9c + a78f558 commit e433dfc

File tree

2 files changed

+40
-8
lines changed

2 files changed

+40
-8
lines changed

interfaces/cuda/Streams.cu

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,18 +76,34 @@ void ConcreteAPI::syncStreamWithEvent(void* streamPtr, void* eventPtr) {
7676
}
7777

7878
namespace {
79-
static void streamCallback(void* data) {
79+
// Host callback that runs a heap-allocated std::function<void()> exactly once.
// `data` is reinterpreted as a std::function<void()>*; the callable is invoked
// and then deleted, so the pointer must not be used after this returns.
// NOTE(review): "Epheremal" looks like a misspelling of "Ephemeral"; a rename
// would also have to touch the call site in streamHostFunction.
void streamCallbackEpheremal(void* data) {
8080
auto* function = reinterpret_cast<std::function<void()>*>(data);
8181
(*function)();
8282
// The callback owns the callable: release it once it has run.
delete function;
8383
}
84+
85+
// Host callback that runs a heap-allocated std::function<void()> but leaves it
// allocated. `data` is reinterpreted as a std::function<void()>* and invoked.
// NOTE(review): streamHostFunction passes this callback while the stream is
// being captured - presumably so the callable survives repeated graph
// launches. Nothing visible here frees the object; confirm who owns it.
void streamCallbackPermanent(void* data) {
86+
auto* function = reinterpret_cast<std::function<void()>*>(data);
87+
(*function)();
88+
}
8489
} // namespace
8590

8691
void ConcreteAPI::streamHostFunction(void* streamPtr, const std::function<void()>& function) {
8792
cudaStream_t stream = static_cast<cudaStream_t>(streamPtr);
88-
auto* functionData = new std::function<void()>(function);
89-
cudaLaunchHostFunc(stream, &streamCallback, functionData);
90-
CHECK_ERR;
93+
94+
cudaStreamCaptureStatus status{};
95+
cudaStreamIsCapturing(stream, &status);
96+
97+
if (status != cudaStreamCaptureStatusInvalidated) {
98+
auto* functionData = new std::function<void()>(function);
99+
if (status == cudaStreamCaptureStatusActive) {
100+
cudaLaunchHostFunc(stream, &streamCallbackPermanent, functionData);
101+
}
102+
else {
103+
cudaLaunchHostFunc(stream, &streamCallbackEpheremal, functionData);
104+
}
105+
CHECK_ERR;
106+
}
91107
}
92108

93109
namespace {

interfaces/hip/Streams.cpp

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -76,18 +76,34 @@ void ConcreteAPI::syncStreamWithEvent(void* streamPtr, void* eventPtr) {
7676
}
7777

7878
namespace {
79-
static void streamCallback(void* data) {
79+
// Host callback that runs a heap-allocated std::function<void()> exactly once.
// `data` is reinterpreted as a std::function<void()>*; the callable is invoked
// and then deleted, so the pointer must not be used after this returns.
// NOTE(review): "Epheremal" looks like a misspelling of "Ephemeral"; a rename
// would also have to touch the call site in streamHostFunction.
void streamCallbackEpheremal(void* data) {
8080
auto* function = reinterpret_cast<std::function<void()>*>(data);
8181
(*function)();
8282
// The callback owns the callable: release it once it has run.
delete function;
8383
}
84+
85+
// Host callback that runs a heap-allocated std::function<void()> but leaves it
// allocated. `data` is reinterpreted as a std::function<void()>* and invoked.
// NOTE(review): streamHostFunction passes this callback while the stream is
// being captured - presumably so the callable survives repeated graph
// launches. Nothing visible here frees the object; confirm who owns it.
void streamCallbackPermanent(void* data) {
86+
auto* function = reinterpret_cast<std::function<void()>*>(data);
87+
(*function)();
88+
}
8489
} // namespace
8590

8691
void ConcreteAPI::streamHostFunction(void* streamPtr, const std::function<void()>& function) {
8792
hipStream_t stream = static_cast<hipStream_t>(streamPtr);
88-
auto* functionData = new std::function<void()>(function);
89-
hipLaunchHostFunc(stream, streamCallback, functionData);
90-
CHECK_ERR;
93+
94+
hipStreamCaptureStatus status{};
95+
hipStreamIsCapturing(stream, &status);
96+
97+
if (status != hipStreamCaptureStatusInvalidated) {
98+
auto* functionData = new std::function<void()>(function);
99+
if (status == hipStreamCaptureStatusActive) {
100+
hipLaunchHostFunc(stream, &streamCallbackPermanent, functionData);
101+
}
102+
else {
103+
hipLaunchHostFunc(stream, &streamCallbackEpheremal, functionData);
104+
}
105+
CHECK_ERR;
106+
}
91107
}
92108

93109
namespace {

0 commit comments

Comments
 (0)