@@ -110,29 +110,28 @@ int ggml_cuda_get_device() {
 
 static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
     ggml_cuda_set_device(device);
-    cudaError_t err;
-    if (getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr) {
-        err = cudaMallocManaged(ptr, size);
-#if defined(GGML_USE_HIP)
-        if (err == hipSuccess) {
-            CUDA_CHECK(cudaMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device));
+    auto device_info = ggml_cuda_info().devices[device];
+    if (device_info.managed_memory) {
+        bool prefer_managed = device_info.integrated;
+        char * uma_optin = getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY");
+        if (uma_optin != nullptr) {
+            prefer_managed = std::string(uma_optin) == "1";
         }
 
-        // fall back to cudaMalloc if not supported (e.g. on Windows)
-        if (err == hipErrorNotSupported) {
-            static bool warned_unsupported = false;
-            if (!warned_unsupported) {
-                GGML_LOG_WARN("hipMallocManaged unsupported, falling back to hipMalloc.\n");
-                warned_unsupported = true;
+        if (prefer_managed) {
+            cudaError_t err = cudaMallocManaged(ptr, size);
+
+#if defined(GGML_USE_HIP)
+            if (err == hipSuccess) {
+                CUDA_CHECK(cudaMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device));
             }
+#endif // defined(GGML_USE_HIP)
 
-            err = cudaMalloc(ptr, size);
+            return err;
         }
-#endif // defined(GGML_USE_HIP)
-    } else {
-        err = cudaMalloc(ptr, size);
     }
-    return err;
+
+    return cudaMalloc(ptr, size);
 }
 
 #if defined(GGML_USE_HIP)
@@ -232,10 +231,11 @@ static ggml_cuda_device_info ggml_cuda_init() {
 
         info.default_tensor_split[id] = total_vram;
         total_vram += prop.totalGlobalMem;
-        info.devices[id].integrated     = false; // Temporarily disabled due to issues with corrupted output (e.g. #15034)
-        info.devices[id].nsm            = prop.multiProcessorCount;
-        info.devices[id].smpb           = prop.sharedMemPerBlock;
-        info.devices[id].warp_size      = prop.warpSize;
+        info.devices[id].integrated     = prop.integrated;
+        info.devices[id].managed_memory = prop.managedMemory;
+        info.devices[id].nsm            = prop.multiProcessorCount;
+        info.devices[id].smpb           = prop.sharedMemPerBlock;
+        info.devices[id].warp_size      = prop.warpSize;
 #if defined(GGML_USE_HIP)
         info.devices[id].smpbo          = prop.sharedMemPerBlock;
 
0 commit comments