@@ -402,107 +402,42 @@ def __init__(self):
402402 self .device_stats = {device : {'memory_used' : 0 , 'active_models' : 0 } for device in self .available_devices }
403403
404404 def _detect_devices (self ) -> List [str ]:
405- """Detect available compute devices including AMD GPUs via ROCm"""
406405 devices = ['cpu' ]
407-
408- # Check for CUDA (NVIDIA) GPUs
409406 if torch .cuda .is_available ():
410- backend = torch .cuda .get_device_properties (0 ).platform
411- if backend == 'ROCm' :
412- # AMD GPUs via ROCm
413- devices .extend ([f'cuda:{ i } ' for i in range (torch .cuda .device_count ())])
414- logging .info ("Detected AMD GPU(s) using ROCm backend" )
415- else :
416- # NVIDIA GPUs
417- devices .extend ([f'cuda:{ i } ' for i in range (torch .cuda .device_count ())])
418- logging .info ("Detected NVIDIA GPU(s)" )
419-
420- # Check for Apple M-series GPU
407+ devices .extend ([f'cuda:{ i } ' for i in range (torch .cuda .device_count ())])
421408 if torch .backends .mps .is_available ():
422409 devices .append ('mps' )
423- logging .info ("Detected Apple M-series GPU" )
424-
425410 return devices
426411
def get_optimal_device(self, model_size: int = 0) -> str:
    """Select the best device for placing a model.

    Preference order: the CUDA device with the most free memory, then
    'mps', then 'cpu'.

    Args:
        model_size: Reserved for size-aware placement; currently unused.

    Returns:
        A device string drawn from ``self.available_devices`` ('cpu' when
        nothing else is available).
    """
    if not self.available_devices:
        return 'cpu'

    # Prefer CUDA devices (covers both NVIDIA and AMD/ROCm builds).
    cuda_devices = [d for d in self.available_devices if 'cuda' in d]
    if cuda_devices:
        max_free_memory = 0
        optimal_device = cuda_devices[0]
        for device in cuda_devices:
            idx = int(device.split(':')[1])
            # Memory queries can fail (driver quirks, notably on some
            # ROCm setups); skip that device rather than crash selection.
            try:
                total = torch.cuda.get_device_properties(idx).total_memory
                free_memory = total - torch.cuda.memory_allocated(idx)
            except Exception as exc:
                logging.warning("Error getting memory info for device %s: %s", device, exc)
                continue
            if free_memory > max_free_memory:
                max_free_memory = free_memory
                optimal_device = device
        return optimal_device

    # Fall back to MPS if available
    if 'mps' in self.available_devices:
        return 'mps'

    return 'cpu'
471437
def track_device_usage(self, device: str, memory_delta: int):
    """Accumulate a memory-usage delta for *device*; unknown devices are ignored."""
    stats = self.device_stats.get(device)
    if stats is not None:
        stats['memory_used'] += memory_delta
476-
def get_device_info(self, device: str) -> Dict[str, Any]:
    """Return a summary dict describing *device*.

    Keys always present: 'type', 'memory_total', 'memory_used',
    'memory_free' (the memory fields stay None for non-CUDA devices).
    For CUDA devices the dict additionally carries 'name', 'backend'
    ('ROCm' or 'CUDA') and 'compute_capability'.
    """
    info = {
        'type': 'cpu',
        'memory_total': None,
        'memory_used': None,
        'memory_free': None,
    }

    if 'cuda' in device:
        try:
            idx = int(device.split(':')[1])
            props = torch.cuda.get_device_properties(idx)
            used = torch.cuda.memory_allocated(idx)
            info.update({
                'type': 'gpu',
                'name': props.name,
                # ROCm builds of torch are identified by torch.version.hip;
                # device properties expose no reliable 'platform' field.
                'backend': 'ROCm' if getattr(torch.version, 'hip', None) else 'CUDA',
                'compute_capability': f"{props.major}.{props.minor}",
                'memory_total': props.total_memory,
                'memory_used': used,
                'memory_free': props.total_memory - used,
            })
        except Exception as exc:
            logging.warning("Error getting device info for %s: %s", device, exc)

    elif device == 'mps':
        info['type'] = 'mps'

    return info
506441
507442class ModelManager :
508443 def __init__ (self , cache_manager : CacheManager , device_manager : DeviceManager ):
0 commit comments