@@ -100,29 +100,28 @@ int ggml_cuda_get_device() {
100100
// ggml_cuda_device_malloc: makes `device` current via ggml_cuda_set_device, then allocates
// `size` bytes and stores the pointer in *ptr. Returns the cudaError_t produced by whichever
// allocator was used (cudaMallocManaged or cudaMalloc).
// This span is a diff: lines marked '-' are the previous implementation, lines marked '+'
// are its replacement, unmarked lines are shared context.
101101static cudaError_t ggml_cuda_device_malloc (void ** ptr, size_t size, int device) {
102102 ggml_cuda_set_device (device);
103- cudaError_t err ;
104- if (getenv ( " GGML_CUDA_ENABLE_UNIFIED_MEMORY " ) != nullptr ) {
105- err = cudaMallocManaged (ptr, size) ;
106- # if defined(GGML_USE_HIP)
107- if (err == hipSuccess ) {
108- CUDA_CHECK ( cudaMemAdvise (*ptr, size, hipMemAdviseSetCoarseGrain, device)) ;
// New version: fetch the info entry for this device.
// NOTE(review): plain `auto` deduces a non-reference type, so the entry is copied on every
// call; `const auto &` would avoid the copy - confirm and consider changing.
103+ auto device_info = ggml_cuda_info (). devices [device] ;
// Managed allocation is only considered when the device entry has managed_memory set
// (ggml_cuda_init fills that field from prop.managedMemory).
104+ if (device_info. managed_memory ) {
// Default preference: use managed memory when the device entry is flagged as integrated.
105+ bool prefer_managed = device_info. integrated ;
// When the environment variable is set, it replaces the default in both directions: the
// preference becomes the result of comparing its value with the literal on the line below,
// so any other value turns managed allocation off, even for an integrated device.
// NOTE(review): the removed lines enabled unified memory for ANY non-null value of this
// variable; confirm the stricter comparison is an intended behavior change.
106+ char * uma_optin = getenv ( " GGML_CUDA_ENABLE_UNIFIED_MEMORY " );
107+ if (uma_optin != nullptr ) {
108+ prefer_managed = std::string (uma_optin) == " 1 " ;
109109 }
110110
111- // fall back to cudaMalloc if not supported (e.g. on Windows)
112- if ( err == hipErrorNotSupported) {
113- static bool warned_unsupported = false ;
114- if (!warned_unsupported) {
115- GGML_LOG_WARN ( " hipMallocManaged unsupported, falling back to hipMalloc. \n " );
116- warned_unsupported = true ;
// Managed path of the new version.
111+ if (prefer_managed) {
112+ cudaError_t err = cudaMallocManaged (ptr, size);
113+
// HIP builds only: after a successful managed allocation, apply the
// hipMemAdviseSetCoarseGrain advice to the new range for this device.
// (CUDA_CHECK is defined outside this view - presumably it aborts on failure; verify.)
114+ # if defined(GGML_USE_HIP)
115+ if (err == hipSuccess) {
116+ CUDA_CHECK ( cudaMemAdvise (*ptr, size, hipMemAdviseSetCoarseGrain, device)) ;
117117 }
118+ #endif // defined(GGML_USE_HIP)
118119
119- err = cudaMalloc (ptr, size) ;
// The managed-allocation result is returned as-is, including failures.
// NOTE(review): the removed lines retried with cudaMalloc when the managed call returned
// hipErrorNotSupported; the '+' lines contain no such retry. Confirm that the
// managed_memory check above makes that fallback unnecessary on every platform.
120+ return err ;
120121 }
121- #endif // defined(GGML_USE_HIP)
122- } else {
123- err = cudaMalloc (ptr, size);
124122 }
125- return err;
123+
// Default path: managed memory not reported for the device or not preferred -> plain cudaMalloc.
124+ return cudaMalloc (ptr, size);
126125}
127126
128127#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
@@ -244,10 +243,11 @@ static ggml_cuda_device_info ggml_cuda_init() {
244243
245244 info.default_tensor_split [id] = total_vram;
246245 total_vram += prop.totalGlobalMem ;
247- info.devices [id].integrated = prop.integrated ;
248- info.devices [id].nsm = prop.multiProcessorCount ;
249- info.devices [id].smpb = prop.sharedMemPerBlock ;
250- info.devices [id].warp_size = prop.warpSize ;
246+ info.devices [id].integrated = prop.integrated ;
247+ info.devices [id].managed_memory = prop.managedMemory ;
248+ info.devices [id].nsm = prop.multiProcessorCount ;
249+ info.devices [id].smpb = prop.sharedMemPerBlock ;
250+ info.devices [id].warp_size = prop.warpSize ;
251251#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
252252 info.devices [id].smpbo = prop.sharedMemPerBlock ;
253253
0 commit comments