@@ -62,7 +62,7 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
6262[[noreturn]]
6363void ggml_cuda_error (const char * stmt, const char * func, const char * file, int line, const char * msg) {
6464 int id = -1 ; // in case cudaGetDevice fails
65- cudaGetDevice (&id);
65+ ( void ) cudaGetDevice (&id);
6666
6767 GGML_LOG_ERROR (GGML_CUDA_NAME " error: %s\n " , msg);
6868 GGML_LOG_ERROR (" current device: %d, in function %s at %s:%d\n " , id, func, file, line);
@@ -152,7 +152,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
152152 for (int id = 0 ; id < info.device_count ; ++id) {
153153 int device_vmm = 0 ;
154154
155- #if !defined(GGML_USE_HIP) && !defined( GGML_CUDA_NO_VMM)
155+ #if !defined(GGML_CUDA_NO_VMM)
156156 CUdevice device;
157157 CU_CHECK (cuDeviceGet (&device, id));
158158 CU_CHECK (cuDeviceGetAttribute (&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device));
@@ -164,7 +164,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
164164 alloc_prop.location .id = id;
165165 CU_CHECK (cuMemGetAllocationGranularity (&info.devices [id].vmm_granularity , &alloc_prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED));
166166 }
167- #endif // !defined(GGML_USE_HIP) && !defined( GGML_CUDA_NO_VMM)
167+ #endif // !defined(GGML_CUDA_NO_VMM)
168168 info.devices [id].vmm = !!device_vmm;
169169
170170 cudaDeviceProp prop;
@@ -300,7 +300,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
300300};
301301
302302// pool with virtual memory
303- #if !defined(GGML_USE_HIP) && !defined( GGML_CUDA_NO_VMM)
303+ #if !defined(GGML_CUDA_NO_VMM)
304304struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
305305 static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35 ; // 32 GB
306306
@@ -309,6 +309,9 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
309309 size_t pool_used = 0 ;
310310 size_t pool_size = 0 ;
311311 size_t granularity;
312+ #if defined(GGML_USE_HIP)
313+ std::vector<std::pair<CUdeviceptr, size_t >> mappings;
314+ #endif
312315
313316 explicit ggml_cuda_pool_vmm (int device) :
314317 device(device),
@@ -317,7 +320,14 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
317320
318321 ~ggml_cuda_pool_vmm () {
319322 if (pool_addr != 0 ) {
323+ #if defined(GGML_USE_HIP)
324+ // Workaround for https://github.com/ROCm/ROCR-Runtime/issues/285
325+ for (std::pair<CUdeviceptr, size_t > & mapping : mappings) {
326+ CU_CHECK (cuMemUnmap (mapping.first , mapping.second ));
327+ }
328+ #else
320329 CU_CHECK (cuMemUnmap (pool_addr, pool_size));
330+ #endif
321331 CU_CHECK (cuMemAddressFree (pool_addr, CUDA_POOL_VMM_MAX_SIZE));
322332 }
323333 }
@@ -350,7 +360,11 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
350360 }
351361
352362 // map at the end of the pool
353- CU_CHECK (cuMemMap (pool_addr + pool_size, reserve_size, 0 , handle, 0 ));
363+ CUdeviceptr start_ptr = (CUdeviceptr)((char *)(pool_addr) + pool_size);
364+ CU_CHECK (cuMemMap (start_ptr, reserve_size, 0 , handle, 0 ));
365+ #if defined(GGML_USE_HIP)
366+ mappings.push_back ({start_ptr, reserve_size});
367+ #endif
354368
355369 // the memory allocation handle is no longer needed after mapping
356370 CU_CHECK (cuMemRelease (handle));
@@ -360,7 +374,7 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
360374 access.location .type = CU_MEM_LOCATION_TYPE_DEVICE;
361375 access.location .id = device;
362376 access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
363- CU_CHECK (cuMemSetAccess (pool_addr + pool_size, reserve_size, &access, 1 ));
377+ CU_CHECK (cuMemSetAccess ((CUdeviceptr)(( char *)( pool_addr) + pool_size) , reserve_size, &access, 1 ));
364378
365379 // add to the pool
366380 pool_size += reserve_size;
@@ -372,7 +386,7 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
372386
373387 GGML_ASSERT (pool_addr != 0 );
374388
375- void * ptr = (void *) (pool_addr + pool_used);
389+ void * ptr = (void *) ((CUdeviceptr)(( char *)( pool_addr) + pool_used) );
376390 *actual_size = size;
377391 pool_used += size;
378392
@@ -391,17 +405,17 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
391405 pool_used -= size;
392406
393407 // all deallocations must be in reverse order of the allocations
394- GGML_ASSERT (ptr == (void *) (pool_addr + pool_used));
408+ GGML_ASSERT (ptr == (void *) (( char *)( pool_addr) + pool_used));
395409 }
396410};
397- #endif // !defined(GGML_USE_HIP) && !defined( GGML_CUDA_NO_VMM)
411+ #endif // !defined(GGML_CUDA_NO_VMM)
398412
399413std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device (int device) {
400- #if !defined(GGML_USE_HIP) && !defined( GGML_CUDA_NO_VMM)
414+ #if !defined(GGML_CUDA_NO_VMM)
401415 if (ggml_cuda_info ().devices [device].vmm ) {
402416 return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm (device));
403417 }
404- #endif // !defined(GGML_USE_HIP) && !defined( GGML_CUDA_NO_VMM)
418+ #endif // !defined(GGML_CUDA_NO_VMM)
405419 return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg (device));
406420}
407421
@@ -547,7 +561,7 @@ static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_bac
547561 cudaError_t err = ggml_cuda_device_malloc (&dev_ptr, size, buft_ctx->device );
548562 if (err != cudaSuccess) {
549563 // clear the error
550- cudaGetLastError ();
564+ ( void ) cudaGetLastError ();
551565 GGML_LOG_ERROR (" %s: allocating %.2f MiB on device %d: cudaMalloc failed: %s\n " , __func__, size / 1024.0 / 1024.0 , buft_ctx->device , cudaGetErrorString (err));
552566 return nullptr ;
553567 }
@@ -962,7 +976,7 @@ static void * ggml_cuda_host_malloc(size_t size) {
962976 cudaError_t err = cudaMallocHost ((void **) &ptr, size);
963977 if (err != cudaSuccess) {
964978 // clear the error
965- cudaGetLastError ();
979+ ( void ) cudaGetLastError ();
966980 GGML_LOG_DEBUG (" %s: failed to allocate %.2f MiB of pinned memory: %s\n " , __func__,
967981 size / 1024.0 / 1024.0 , cudaGetErrorString (err));
968982 return nullptr ;
@@ -1197,15 +1211,15 @@ static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) {
11971211 CUDA_CHECK (err);
11981212 } else {
11991213 // reset the error
1200- cudaGetLastError ();
1214+ ( void ) cudaGetLastError ();
12011215 }
12021216 } else {
12031217 cudaError_t err = cudaDeviceDisablePeerAccess (id_other);
12041218 if (err != cudaErrorPeerAccessNotEnabled) {
12051219 CUDA_CHECK (err);
12061220 } else {
12071221 // reset the error
1208- cudaGetLastError ();
1222+ ( void ) cudaGetLastError ();
12091223 }
12101224 }
12111225 }
@@ -2438,7 +2452,7 @@ static void maintain_cuda_graph(ggml_backend_cuda_context * cuda_ctx, std::vecto
24382452 if (stat == cudaErrorInvalidDeviceFunction) {
24392453 // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
24402454 // We don't need to update blas nodes, so clear error and move on.
2441- cudaGetLastError ();
2455+ ( void ) cudaGetLastError ();
24422456 } else {
24432457 GGML_ASSERT (stat == cudaSuccess);
24442458 }
@@ -2506,7 +2520,7 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
25062520
25072521 // The pre-existing graph exec cannot be updated due to violated constraints
25082522 // so instead clear error and re-instantiate
2509- cudaGetLastError ();
2523+ ( void ) cudaGetLastError ();
25102524 CUDA_CHECK (cudaGraphExecDestroy (cuda_ctx->cuda_graph ->instance ));
25112525 cuda_ctx->cuda_graph ->instance = nullptr ;
25122526 CUDA_CHECK (cudaGraphInstantiate (&cuda_ctx->cuda_graph ->instance , cuda_ctx->cuda_graph ->graph , NULL , NULL , 0 ));
@@ -2734,7 +2748,7 @@ bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) {
27342748 cudaError_t err = cudaHostRegister (buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly);
27352749 if (err != cudaSuccess) {
27362750 // clear the error
2737- cudaGetLastError ();
2751+ ( void ) cudaGetLastError ();
27382752
27392753 GGML_LOG_DEBUG (" %s: failed to register %.2f MiB of pinned memory: %s\n " , __func__,
27402754 size / 1024.0 / 1024.0 , cudaGetErrorString (err));
@@ -2754,7 +2768,7 @@ void ggml_backend_cuda_unregister_host_buffer(void * buffer) {
27542768 cudaError_t err = cudaHostUnregister (buffer);
27552769 if (err != cudaSuccess) {
27562770 // clear the error
2757- cudaGetLastError ();
2771+ ( void ) cudaGetLastError ();
27582772 }
27592773}
27602774
0 commit comments