@@ -99,29 +99,28 @@ int ggml_cuda_get_device() {
9999
100100static cudaError_t ggml_cuda_device_malloc (void ** ptr, size_t size, int device) {
101101 ggml_cuda_set_device (device);
102- cudaError_t err ;
103- if (getenv ( " GGML_CUDA_ENABLE_UNIFIED_MEMORY " ) != nullptr ) {
104- err = cudaMallocManaged (ptr, size) ;
105- # if defined(GGML_USE_HIP)
106- if (err == hipSuccess ) {
107- CUDA_CHECK ( cudaMemAdvise (*ptr, size, hipMemAdviseSetCoarseGrain, device)) ;
102+ auto device_info = ggml_cuda_info (). devices [device] ;
103+ if (device_info. managed_memory ) {
104+ bool prefer_managed = device_info. integrated ;
105+ char * uma_optin = getenv ( " GGML_CUDA_ENABLE_UNIFIED_MEMORY " );
106+ if (uma_optin != nullptr ) {
107+ prefer_managed = std::string (uma_optin) == " 1 " ;
108108 }
109109
110- // fall back to cudaMalloc if not supported (e.g. on Windows)
111- if ( err == hipErrorNotSupported) {
112- static bool warned_unsupported = false ;
113- if (!warned_unsupported) {
114- GGML_LOG_WARN ( " hipMallocManaged unsupported, falling back to hipMalloc. \n " );
115- warned_unsupported = true ;
110+ if (prefer_managed) {
111+ cudaError_t err = cudaMallocManaged (ptr, size);
112+
113+ # if defined(GGML_USE_HIP)
114+ if (err == hipSuccess) {
115+ CUDA_CHECK ( cudaMemAdvise (*ptr, size, hipMemAdviseSetCoarseGrain, device)) ;
116116 }
117+ #endif // defined(GGML_USE_HIP)
117118
118- err = cudaMalloc (ptr, size) ;
119+ return err ;
119120 }
120- #endif // defined(GGML_USE_HIP)
121- } else {
122- err = cudaMalloc (ptr, size);
123121 }
124- return err;
122+
123+ return cudaMalloc (ptr, size);
125124}
126125
127126#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
@@ -243,10 +242,11 @@ static ggml_cuda_device_info ggml_cuda_init() {
243242
244243 info.default_tensor_split [id] = total_vram;
245244 total_vram += prop.totalGlobalMem ;
246- info.devices [id].integrated = prop.integrated ;
247- info.devices [id].nsm = prop.multiProcessorCount ;
248- info.devices [id].smpb = prop.sharedMemPerBlock ;
249- info.devices [id].warp_size = prop.warpSize ;
245+ info.devices [id].integrated = prop.integrated ;
246+ info.devices [id].managed_memory = prop.managedMemory ;
247+ info.devices [id].nsm = prop.multiProcessorCount ;
248+ info.devices [id].smpb = prop.sharedMemPerBlock ;
249+ info.devices [id].warp_size = prop.warpSize ;
250250#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
251251 info.devices [id].smpbo = prop.sharedMemPerBlock ;
252252
0 commit comments