@@ -92,11 +92,19 @@ hipError_t hipMallocAsync(void** dev_ptr, size_t size, hipStream_t stream) {
9292 *dev_ptr = nullptr ;
9393 HIP_RETURN (hipSuccess);
9494 }
95+ hip::Stream* s = reinterpret_cast <hip::Stream*>(stream);
9596 auto hip_stream = (stream == nullptr || stream == hipStreamLegacy) ?
96- hip::getCurrentDevice ()->NullStream () : reinterpret_cast <hip::Stream*>(stream) ;
97+ hip::getCurrentDevice ()->NullStream () : s ;
9798 auto device = hip_stream->GetDevice ();
9899 auto mem_pool = device->GetCurrentMemoryPool ();
99100
101+ // Return error if any stream other than the current stream is in capture mode
102+ if (device->StreamCaptureBlocking ()) {
103+ if (s->GetCaptureStatus () != hipStreamCaptureStatusActive) {
104+ return hipErrorStreamCaptureUnsupported;
105+ }
106+ }
107+
100108 STREAM_CAPTURE (hipMallocAsync, stream, reinterpret_cast <hipMemPool_t>(mem_pool), size, dev_ptr);
101109
102110 *dev_ptr = mem_pool->AllocateMemory (size, hip_stream);
@@ -138,17 +146,28 @@ class FreeAsyncCommand : public amd::Command {
138146// ================================================================================================
139147hipError_t hipFreeAsync (void * dev_ptr, hipStream_t stream) {
140148 HIP_INIT_API (hipFreeAsync, dev_ptr, stream);
141- if (dev_ptr == nullptr ) {
142- HIP_RETURN (hipErrorInvalidValue);
143- }
149+
144150 if (!hip::isValid (stream)) {
145151 HIP_RETURN (hipErrorInvalidHandle);
146152 }
147153
148- STREAM_CAPTURE (hipFreeAsync, stream, dev_ptr);
149-
154+ hip::Stream* s = reinterpret_cast <hip::Stream*>(stream);
150155 auto hip_stream = (stream == nullptr || stream == hipStreamLegacy) ?
151- hip::getCurrentDevice ()->NullStream (): reinterpret_cast <hip::Stream*>(stream);
156+ hip::getCurrentDevice ()->NullStream (): s;
157+
158+ auto device = hip_stream->GetDevice ();
159+ // Return error if any stream other than the current stream is in capture mode
160+ if (device->StreamCaptureBlocking ()) {
161+ if (s->GetCaptureStatus () != hipStreamCaptureStatusActive) {
162+ return hipErrorStreamCaptureUnsupported;
163+ }
164+ }
165+
166+ if (dev_ptr == nullptr ) {
167+ HIP_RETURN (hipErrorInvalidValue);
168+ }
169+
170+ STREAM_CAPTURE (hipFreeAsync, stream, dev_ptr);
152171
153172 hip::Event* event = nullptr ;
154173 bool graph_in_use = false ;
0 commit comments