Commit 75a5d21

Revert "Update inference.py"

This reverts commit c0f0893.

1 parent c0f0893 · commit 75a5d21

1 file changed

optillm/inference.py

Lines changed: 11 additions & 76 deletions
@@ -402,107 +402,42 @@ def __init__(self):
         self.device_stats = {device: {'memory_used': 0, 'active_models': 0} for device in self.available_devices}
 
     def _detect_devices(self) -> List[str]:
-        """Detect available compute devices including AMD GPUs via ROCm"""
         devices = ['cpu']
-
-        # Check for CUDA (NVIDIA) GPUs
         if torch.cuda.is_available():
-            backend = torch.cuda.get_device_properties(0).platform
-            if backend == 'ROCm':
-                # AMD GPUs via ROCm
-                devices.extend([f'cuda:{i}' for i in range(torch.cuda.device_count())])
-                logging.info("Detected AMD GPU(s) using ROCm backend")
-            else:
-                # NVIDIA GPUs
-                devices.extend([f'cuda:{i}' for i in range(torch.cuda.device_count())])
-                logging.info("Detected NVIDIA GPU(s)")
-
-        # Check for Apple M-series GPU
+            devices.extend([f'cuda:{i}' for i in range(torch.cuda.device_count())])
         if torch.backends.mps.is_available():
             devices.append('mps')
-            logging.info("Detected Apple M-series GPU")
-
         return devices
 
     def get_optimal_device(self, model_size: int = 0) -> str:
-        """Select the optimal device considering AMD GPU support"""
         if not self.available_devices:
             return 'cpu'
 
-        # Get CUDA devices (both NVIDIA and AMD via ROCm)
+        # Prefer CUDA devices if available
         cuda_devices = [d for d in self.available_devices if 'cuda' in d]
-
         if cuda_devices:
-            # Find device with most free memory
+            # Find CUDA device with most free memory
             max_free_memory = 0
             optimal_device = cuda_devices[0]
 
-            try:
-                for device in cuda_devices:
-                    idx = int(device.split(':')[1])
-                    # Get memory info safely handling both NVIDIA and AMD
-                    try:
-                        total_memory = torch.cuda.get_device_properties(idx).total_memory
-                        used_memory = torch.cuda.memory_allocated(idx)
-                        free_memory = total_memory - used_memory
-                    except Exception as e:
-                        logging.warning(f"Error getting memory info for device {device}: {e}")
-                        continue
-
-                    if free_memory > max_free_memory:
-                        max_free_memory = free_memory
-                        optimal_device = device
-
-                logging.info(f"Selected optimal CUDA device: {optimal_device} with {max_free_memory/1e9:.2f}GB free memory")
-                return optimal_device
-
-            except Exception as e:
-                logging.error(f"Error selecting optimal CUDA device: {e}")
-                # Fall back to first CUDA device if memory query fails
-                return cuda_devices[0]
+            for device in cuda_devices:
+                idx = int(device.split(':')[1])
+                free_memory = torch.cuda.get_device_properties(idx).total_memory - torch.cuda.memory_allocated(idx)
+                if free_memory > max_free_memory:
+                    max_free_memory = free_memory
+                    optimal_device = device
+
+            return optimal_device
 
         # Fall back to MPS if available
         if 'mps' in self.available_devices:
             return 'mps'
 
-        # Final fallback to CPU
-        logging.info("No GPU detected, using CPU")
         return 'cpu'
 
     def track_device_usage(self, device: str, memory_delta: int):
-        """Track memory usage for the device"""
         if device in self.device_stats:
            self.device_stats[device]['memory_used'] += memory_delta
-
-    def get_device_info(self, device: str) -> Dict[str, Any]:
-        """Get detailed information about a device"""
-        info = {
-            'type': 'cpu',
-            'memory_total': None,
-            'memory_used': None,
-            'memory_free': None
-        }
-
-        if 'cuda' in device:
-            try:
-                idx = int(device.split(':')[1])
-                props = torch.cuda.get_device_properties(idx)
-                info.update({
-                    'type': 'gpu',
-                    'name': props.name,
-                    'backend': 'ROCm' if hasattr(props, 'platform') and props.platform == 'ROCm' else 'CUDA',
-                    'compute_capability': f"{props.major}.{props.minor}",
-                    'memory_total': props.total_memory,
-                    'memory_used': torch.cuda.memory_allocated(idx),
-                    'memory_free': props.total_memory - torch.cuda.memory_allocated(idx)
-                })
-            except Exception as e:
-                logging.warning(f"Error getting device info for {device}: {e}")
-
-        elif device == 'mps':
-            info['type'] = 'mps'
-
-        return info
 
 class ModelManager:
     def __init__(self, cache_manager: CacheManager, device_manager: DeviceManager):
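
Note: the reverted _detect_devices read a platform attribute off torch.cuda.get_device_properties(0), which is not among the documented fields of PyTorch's device properties, so the backend check likely failed at runtime. For reference, a minimal sketch of how a ROCm build of PyTorch is commonly told apart from a CUDA build; the helper name is illustrative, but torch.version.hip is a real attribute (a version string on ROCm wheels, None on CUDA wheels):

import torch

def detect_gpu_backend() -> str:
    # Illustrative helper, not part of optillm: ROCm builds of PyTorch
    # expose the HIP version via torch.version.hip; CUDA builds leave it None.
    if not torch.cuda.is_available():
        return 'none'
    return 'ROCm' if torch.version.hip is not None else 'CUDA'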
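
A caveat on the retained selection loop: torch.cuda.memory_allocated(idx) only counts memory allocated by the current process through PyTorch's caching allocator, so total_memory - memory_allocated can overstate what is actually free on a shared GPU. A sketch of an alternative, not what this commit does, using driver-level accounting via torch.cuda.mem_get_info (a real PyTorch API returning (free, total) in bytes); the function name is illustrative:

from typing import List

import torch

def freest_cuda_device(cuda_devices: List[str]) -> str:
    # Illustrative alternative, not part of this commit: pick the CUDA device
    # with the most free memory as reported by the driver, which also counts
    # memory held by other processes.
    best_device, best_free = cuda_devices[0], -1
    for device in cuda_devices:
        idx = int(device.split(':')[1])
        free, _total = torch.cuda.mem_get_info(idx)  # driver-level (free, total) in bytes
        if free > best_free:
            best_free, best_device = free, device
    return best_device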
