@@ -98,32 +98,28 @@ int ggml_cuda_get_device() {
 
 static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
     ggml_cuda_set_device(device);
-    cudaError_t err;
-    if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr)
-    {
-        err = cudaMallocManaged(ptr, size);
-#if defined(GGML_USE_HIP)
-        if (err == hipSuccess) {
-            CUDA_CHECK(cudaMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device));
+    auto device_info = ggml_cuda_info().devices[device];
+    if (device_info.managed_memory) {
+        bool prefer_managed = device_info.integrated;
+        char * uma_optin = getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY");
+        if (uma_optin != nullptr) {
+            prefer_managed = std::string(uma_optin) == "1";
         }
 
-        // fall back to cudaMalloc if not supported (e.g. on Windows)
-        if (err == hipErrorNotSupported) {
-            static bool warned_unsupported = false;
-            if (!warned_unsupported) {
-                GGML_LOG_WARN("hipMallocManaged unsupported, falling back to hipMalloc.\n");
-                warned_unsupported = true;
+        if (prefer_managed) {
+            cudaError_t err = cudaMallocManaged(ptr, size);
+
+#if defined(GGML_USE_HIP)
+            if (err == hipSuccess) {
+                CUDA_CHECK(cudaMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device));
             }
+#endif // defined(GGML_USE_HIP)
 
-            err = cudaMalloc(ptr, size);
+            return err;
         }
-#endif // defined(GGML_USE_HIP)
-    }
-    else
-    {
-        err = cudaMalloc(ptr, size);
     }
-    return err;
+
+    return cudaMalloc(ptr, size);
 }
 
 #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
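
Note (not part of the patch): the hunk above changes the allocation policy so that managed (unified) memory is only considered on devices that report support for it, defaults to on for integrated GPUs, and can be overridden either way with GGML_CUDA_ENABLE_UNIFIED_MEMORY ("1" forces managed allocation on, any other value forces it off). A minimal standalone sketch of that decision follows; prefer_managed_alloc is a hypothetical helper name used only for illustration.

// Standalone sketch of the allocation policy introduced above; not a function in the patch.
#include <cstdlib>
#include <string>

static bool prefer_managed_alloc(bool device_supports_managed, bool device_is_integrated) {
    if (!device_supports_managed) {
        return false;  // never attempt cudaMallocManaged on devices without managed memory support
    }
    bool prefer_managed = device_is_integrated;  // integrated GPUs default to managed memory
    if (const char * uma_optin = std::getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY")) {
        prefer_managed = std::string(uma_optin) == "1";  // explicit env var setting overrides the default
    }
    return prefer_managed;
}
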
@@ -245,10 +241,11 @@ static ggml_cuda_device_info ggml_cuda_init() {
 
         info.default_tensor_split[id] = total_vram;
         total_vram += prop.totalGlobalMem;
-        info.devices[id].integrated = prop.integrated;
-        info.devices[id].nsm        = prop.multiProcessorCount;
-        info.devices[id].smpb       = prop.sharedMemPerBlock;
-        info.devices[id].warp_size  = prop.warpSize;
+        info.devices[id].integrated     = prop.integrated;
+        info.devices[id].managed_memory = prop.managedMemory;
+        info.devices[id].nsm            = prop.multiProcessorCount;
+        info.devices[id].smpb           = prop.sharedMemPerBlock;
+        info.devices[id].warp_size      = prop.warpSize;
 #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
         info.devices[id].smpbo = prop.sharedMemPerBlock;
 
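
Note (not part of the patch): the new managed_memory flag filled in here mirrors cudaDeviceProp::managedMemory, which the device-info struct presumably gains as a new field (the struct itself lives in ggml-cuda/common.cuh and is not shown in this diff). An abbreviated, hypothetical sketch of the per-device fields touched by this hunk:

// Abbreviated sketch only; the real per-device struct in ggml-cuda/common.cuh has
// more fields and a different layout. Field names below match the diff.
struct cuda_device_info_sketch {
    bool   integrated;      // iGPU that shares physical memory with the host
    bool   managed_memory;  // device supports cudaMallocManaged (cudaDeviceProp::managedMemory)
    int    nsm;             // number of streaming multiprocessors
    size_t smpb;            // max. shared memory per block
    int    warp_size;       // warp (HIP: wavefront) size
};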