@@ -62,7 +62,7 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
6262[[noreturn]]
6363void ggml_cuda_error (const char * stmt, const char * func, const char * file, int line, const char * msg) {
6464 int id = -1 ; // in case cudaGetDevice fails
65- cudaGetDevice (&id);
65+ ( void ) cudaGetDevice (&id);
6666
6767 GGML_LOG_ERROR (GGML_CUDA_NAME " error: %s\n " , msg);
6868 GGML_LOG_ERROR (" current device: %d, in function %s at %s:%d\n " , id, func, file, line);
@@ -152,7 +152,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
152152 for (int id = 0 ; id < info.device_count ; ++id) {
153153 int device_vmm = 0 ;
154154
155- #if !defined(GGML_USE_HIP) && !defined( GGML_CUDA_NO_VMM)
155+ #if !defined(GGML_CUDA_NO_VMM)
156156 CUdevice device;
157157 CU_CHECK (cuDeviceGet (&device, id));
158158 CU_CHECK (cuDeviceGetAttribute (&device_vmm, CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED, device));
@@ -164,7 +164,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
164164 alloc_prop.location .id = id;
165165 CU_CHECK (cuMemGetAllocationGranularity (&info.devices [id].vmm_granularity , &alloc_prop, CU_MEM_ALLOC_GRANULARITY_RECOMMENDED));
166166 }
167- #endif // !defined(GGML_USE_HIP) && !defined( GGML_CUDA_NO_VMM)
167+ #endif // !defined(GGML_CUDA_NO_VMM)
168168 info.devices [id].vmm = !!device_vmm;
169169
170170 cudaDeviceProp prop;
@@ -300,7 +300,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
300300};
301301
302302// pool with virtual memory
303- #if !defined(GGML_USE_HIP) && !defined( GGML_CUDA_NO_VMM)
303+ #if !defined(GGML_CUDA_NO_VMM)
304304struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
305305 static const size_t CUDA_POOL_VMM_MAX_SIZE = 1ull << 35 ; // 32 GB
306306
@@ -309,6 +309,7 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
309309 size_t pool_used = 0 ;
310310 size_t pool_size = 0 ;
311311 size_t granularity;
312+ std::vector<std::pair<CUdeviceptr, size_t >> mappings;
312313
313314 explicit ggml_cuda_pool_vmm (int device) :
314315 device(device),
@@ -317,7 +318,9 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
317318
318319 ~ggml_cuda_pool_vmm () {
319320 if (pool_addr != 0 ) {
320- CU_CHECK (cuMemUnmap (pool_addr, pool_size));
321+ for (std::pair<CUdeviceptr, size_t >& mapping : mappings) {
322+ CU_CHECK (cuMemUnmap (mapping.first , mapping.second ));
323+ }
321324 CU_CHECK (cuMemAddressFree (pool_addr, CUDA_POOL_VMM_MAX_SIZE));
322325 }
323326 }
@@ -350,7 +353,9 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
350353 }
351354
352355 // map at the end of the pool
353- CU_CHECK (cuMemMap (pool_addr + pool_size, reserve_size, 0 , handle, 0 ));
356+ CUdeviceptr start_ptr = reinterpret_cast <CUdeviceptr>(reinterpret_cast <char *>(pool_addr) + pool_size);
357+ CU_CHECK (cuMemMap (start_ptr, reserve_size, 0 , handle, 0 ));
358+ mappings.push_back ({start_ptr, reserve_size});
354359
355360 // the memory allocation handle is no longer needed after mapping
356361 CU_CHECK (cuMemRelease (handle));
@@ -360,7 +365,7 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
360365 access.location .type = CU_MEM_LOCATION_TYPE_DEVICE;
361366 access.location .id = device;
362367 access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
363- CU_CHECK (cuMemSetAccess (pool_addr + pool_size, reserve_size, &access, 1 ));
368+ CU_CHECK (cuMemSetAccess (reinterpret_cast <CUdeviceptr>( reinterpret_cast < char *>( pool_addr) + pool_size) , reserve_size, &access, 1 ));
364369
365370 // add to the pool
366371 pool_size += reserve_size;
@@ -372,7 +377,7 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
372377
373378 GGML_ASSERT (pool_addr != 0 );
374379
375- void * ptr = (void *) (pool_addr + pool_used);
380+ void * ptr = (void *) (reinterpret_cast <CUdeviceptr>( reinterpret_cast < char *>( pool_addr) + pool_used) );
376381 *actual_size = size;
377382 pool_used += size;
378383
@@ -391,17 +396,17 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
391396 pool_used -= size;
392397
393398 // all deallocations must be in reverse order of the allocations
394- GGML_ASSERT (ptr == (void *) (pool_addr + pool_used));
399+ GGML_ASSERT (ptr == (void *) (reinterpret_cast < char *>( pool_addr) + pool_used));
395400 }
396401};
397- #endif // !defined(GGML_USE_HIP) && !defined( GGML_CUDA_NO_VMM)
402+ #endif // !defined(GGML_CUDA_NO_VMM)
398403
399404std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device (int device) {
400- #if !defined(GGML_USE_HIP) && !defined( GGML_CUDA_NO_VMM)
405+ #if !defined(GGML_CUDA_NO_VMM)
401406 if (ggml_cuda_info ().devices [device].vmm ) {
402407 return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm (device));
403408 }
404- #endif // !defined(GGML_USE_HIP) && !defined( GGML_CUDA_NO_VMM)
409+ #endif // !defined(GGML_CUDA_NO_VMM)
405410 return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg (device));
406411}
407412
@@ -547,7 +552,7 @@ static ggml_backend_buffer_t ggml_backend_cuda_buffer_type_alloc_buffer(ggml_bac
547552 cudaError_t err = ggml_cuda_device_malloc (&dev_ptr, size, buft_ctx->device );
548553 if (err != cudaSuccess) {
549554 // clear the error
550- cudaGetLastError ();
555+ ( void ) cudaGetLastError ();
551556 GGML_LOG_ERROR (" %s: allocating %.2f MiB on device %d: cudaMalloc failed: %s\n " , __func__, size / 1024.0 / 1024.0 , buft_ctx->device , cudaGetErrorString (err));
552557 return nullptr ;
553558 }
@@ -962,7 +967,7 @@ static void * ggml_cuda_host_malloc(size_t size) {
962967 cudaError_t err = cudaMallocHost ((void **) &ptr, size);
963968 if (err != cudaSuccess) {
964969 // clear the error
965- cudaGetLastError ();
970+ ( void ) cudaGetLastError ();
966971 GGML_LOG_DEBUG (" %s: failed to allocate %.2f MiB of pinned memory: %s\n " , __func__,
967972 size / 1024.0 / 1024.0 , cudaGetErrorString (err));
968973 return nullptr ;
@@ -1197,15 +1202,15 @@ static void ggml_cuda_set_peer_access(const int n_tokens, int main_device) {
11971202 CUDA_CHECK (err);
11981203 } else {
11991204 // reset the error
1200- cudaGetLastError ();
1205+ ( void ) cudaGetLastError ();
12011206 }
12021207 } else {
12031208 cudaError_t err = cudaDeviceDisablePeerAccess (id_other);
12041209 if (err != cudaErrorPeerAccessNotEnabled) {
12051210 CUDA_CHECK (err);
12061211 } else {
12071212 // reset the error
1208- cudaGetLastError ();
1213+ ( void ) cudaGetLastError ();
12091214 }
12101215 }
12111216 }
@@ -2438,7 +2443,7 @@ static void maintain_cuda_graph(ggml_backend_cuda_context * cuda_ctx, std::vecto
24382443 if (stat == cudaErrorInvalidDeviceFunction) {
24392444 // Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
24402445 // We don't need to update blas nodes, so clear error and move on.
2441- cudaGetLastError ();
2446+ ( void ) cudaGetLastError ();
24422447 } else {
24432448 GGML_ASSERT (stat == cudaSuccess);
24442449 }
@@ -2506,7 +2511,7 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
25062511
25072512 // The pre-existing graph exec cannot be updated due to violated constraints
25082513 // so instead clear error and re-instantiate
2509- cudaGetLastError ();
2514+ ( void ) cudaGetLastError ();
25102515 CUDA_CHECK (cudaGraphExecDestroy (cuda_ctx->cuda_graph ->instance ));
25112516 cuda_ctx->cuda_graph ->instance = nullptr ;
25122517 CUDA_CHECK (cudaGraphInstantiate (&cuda_ctx->cuda_graph ->instance , cuda_ctx->cuda_graph ->graph , NULL , NULL , 0 ));
@@ -2734,7 +2739,7 @@ bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) {
27342739 cudaError_t err = cudaHostRegister (buffer, size, cudaHostRegisterPortable | cudaHostRegisterReadOnly);
27352740 if (err != cudaSuccess) {
27362741 // clear the error
2737- cudaGetLastError ();
2742+ ( void ) cudaGetLastError ();
27382743
27392744 GGML_LOG_DEBUG (" %s: failed to register %.2f MiB of pinned memory: %s\n " , __func__,
27402745 size / 1024.0 / 1024.0 , cudaGetErrorString (err));
@@ -2754,7 +2759,7 @@ void ggml_backend_cuda_unregister_host_buffer(void * buffer) {
27542759 cudaError_t err = cudaHostUnregister (buffer);
27552760 if (err != cudaSuccess) {
27562761 // clear the error
2757- cudaGetLastError ();
2762+ ( void ) cudaGetLastError ();
27582763 }
27592764}
27602765
0 commit comments