@@ -67,9 +67,13 @@ hipError_t hipDeviceGetMemPool(hipMemPool_t* mem_pool, int device) {
6767// ================================================================================================
6868hipError_t hipMallocAsync (void ** dev_ptr, size_t size, hipStream_t stream) {
6969 HIP_INIT_API (hipMallocAsync, dev_ptr, size, stream);
70- if ((dev_ptr == nullptr ) || (size == 0 ) || ( !hip::isValid (stream))) {
70+ if ((dev_ptr == nullptr ) || (!hip::isValid (stream))) {
7171 HIP_RETURN (hipErrorInvalidValue);
7272 }
73+ if (size == 0 ) {
74+ *dev_ptr = nullptr ;
75+ HIP_RETURN (hipSuccess);
76+ }
7377 auto hip_stream = (stream == nullptr ) ? hip::getCurrentDevice ()->NullStream () :
7478 reinterpret_cast <hip::Stream*>(stream);
7579 auto device = hip_stream->GetDevice ();
@@ -235,9 +239,13 @@ hipError_t hipMallocFromPoolAsync(
235239 hipMemPool_t mem_pool,
236240 hipStream_t stream) {
237241 HIP_INIT_API (hipMallocFromPoolAsync, dev_ptr, size, mem_pool, stream);
238- if ((dev_ptr == nullptr ) || (size == 0 ) || ( mem_pool == nullptr ) || (!hip::isValid (stream))) {
242+ if ((dev_ptr == nullptr ) || (mem_pool == nullptr ) || (!hip::isValid (stream))) {
239243 HIP_RETURN (hipErrorInvalidValue);
240244 }
245+ if (size == 0 ) {
246+ *dev_ptr = nullptr ;
247+ HIP_RETURN (hipSuccess);
248+ }
241249 STREAM_CAPTURE (hipMallocAsync, stream, mem_pool, size, dev_ptr);
242250
243251 auto mpool = reinterpret_cast <hip::MemoryPool*>(mem_pool);
0 commit comments