@@ -111,29 +111,28 @@ int ggml_cuda_get_device() {
111111
112112static cudaError_t ggml_cuda_device_malloc (void ** ptr, size_t size, int device) {
113113 ggml_cuda_set_device (device);
114- cudaError_t err ;
115- if (getenv ( " GGML_CUDA_ENABLE_UNIFIED_MEMORY " ) != nullptr ) {
116- err = cudaMallocManaged (ptr, size) ;
117- # if defined(GGML_USE_HIP)
118- if (err == hipSuccess ) {
119- CUDA_CHECK ( cudaMemAdvise (*ptr, size, hipMemAdviseSetCoarseGrain, device)) ;
114+ auto device_info = ggml_cuda_info (). devices [device] ;
115+ if (device_info. managed_memory ) {
116+ bool prefer_managed = device_info. integrated ;
117+ char * uma_optin = getenv ( " GGML_CUDA_ENABLE_UNIFIED_MEMORY " );
118+ if (uma_optin != nullptr ) {
119+ prefer_managed = std::string (uma_optin) == " 1 " ;
120120 }
121121
122- // fall back to cudaMalloc if not supported (e.g. on Windows)
123- if ( err == hipErrorNotSupported) {
124- static bool warned_unsupported = false ;
125- if (!warned_unsupported) {
126- GGML_LOG_WARN ( " hipMallocManaged unsupported, falling back to hipMalloc. \n " );
127- warned_unsupported = true ;
122+ if (prefer_managed) {
123+ cudaError_t err = cudaMallocManaged (ptr, size);
124+
125+ # if defined(GGML_USE_HIP)
126+ if (err == hipSuccess) {
127+ CUDA_CHECK ( cudaMemAdvise (*ptr, size, hipMemAdviseSetCoarseGrain, device)) ;
128128 }
129+ #endif // defined(GGML_USE_HIP)
129130
130- err = cudaMalloc (ptr, size) ;
131+ return err ;
131132 }
132- #endif // defined(GGML_USE_HIP)
133- } else {
134- err = cudaMalloc (ptr, size);
135133 }
136- return err;
134+
135+ return cudaMalloc (ptr, size);
137136}
138137
139138#if defined(GGML_USE_HIP)
@@ -233,10 +232,11 @@ static ggml_cuda_device_info ggml_cuda_init() {
233232
234233 info.default_tensor_split [id] = total_vram;
235234 total_vram += prop.totalGlobalMem ;
236- info.devices [id].integrated = false ; // Temporarily disabled due to issues with corrupted output (e.g. #15034)
237- info.devices [id].nsm = prop.multiProcessorCount ;
238- info.devices [id].smpb = prop.sharedMemPerBlock ;
239- info.devices [id].warp_size = prop.warpSize ;
235+ info.devices [id].integrated = prop.integrated ;
236+ info.devices [id].managed_memory = prop.managedMemory ;
237+ info.devices [id].nsm = prop.multiProcessorCount ;
238+ info.devices [id].smpb = prop.sharedMemPerBlock ;
239+ info.devices [id].warp_size = prop.warpSize ;
240240#if defined(GGML_USE_HIP)
241241 info.devices [id].smpbo = prop.sharedMemPerBlock ;
242242
0 commit comments