diff --git a/BackendBench/backends.py b/BackendBench/backends.py index 1d2012a..152855c 100644 --- a/BackendBench/backends.py +++ b/BackendBench/backends.py @@ -2,6 +2,13 @@ import importlib.util from typing import Dict, Callable, List +# Import VLLM backend if available +try: + from .vllm_backend import VLLMBackend + VLLM_AVAILABLE = True +except ImportError: + VLLM_AVAILABLE = False + class Backend: def __init__(self, name): diff --git a/BackendBench/distributed_workers.py b/BackendBench/distributed_workers.py new file mode 100644 index 0000000..241dce3 --- /dev/null +++ b/BackendBench/distributed_workers.py @@ -0,0 +1,486 @@ +import os +import time +import asyncio +import multiprocessing as mp +import threading +from typing import Dict, List, Optional, Any, Tuple +from dataclasses import dataclass +import hashlib +import json +import redis +import torch +import traceback + +# Set multiprocessing start method to 'spawn' for CUDA compatibility +# This must be done before any multiprocessing operations +if mp.get_start_method(allow_none=True) != 'spawn': + mp.set_start_method('spawn', force=True) + +from .vllm_backend import KernelStore, KernelResult +from .backends import LLMBackend + + +@dataclass +class EvaluationTask: + """Task for kernel evaluation""" + operation_name: str + kernel_code: str + kernel_hash: str + test_cases: List[Any] + gpu_id: int + worker_id: str + + +@dataclass +class WorkerConfig: + """Configuration for distributed workers""" + redis_host: str = "localhost" + redis_port: int = 6379 + vllm_model_path: str = "codellama/CodeLlama-7b-Instruct-hf" + generation_gpus: List[int] = None # GPUs for VLLM generation + evaluation_gpus: List[int] = None # GPUs for kernel evaluation + max_workers: int = 8 + + def __post_init__(self): + if self.generation_gpus is None: + self.generation_gpus = [0, 1, 2, 3] # First 4 GPUs for generation + if self.evaluation_gpus is None: + self.evaluation_gpus = [4, 5, 6, 7] # Last 4 GPUs for evaluation + + +class EvaluationWorker: + """Worker process for evaluating kernels on dedicated GPUs""" + + def __init__(self, worker_id: str, gpu_id: int, config: WorkerConfig): + self.worker_id = worker_id + self.gpu_id = gpu_id + self.config = config + self.device = f"cuda:{gpu_id}" + + # Set CUDA device for this worker + os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) + torch.cuda.set_device(0) # After CUDA_VISIBLE_DEVICES, this is device 0 + + # Initialize components + self.redis = redis.Redis(host=config.redis_host, port=config.redis_port, decode_responses=True) + self.kernel_store = KernelStore(config.redis_host, config.redis_port) + + # Reuse LLMBackend's compilation and testing logic + self.llm_backend = LLMBackend() + + print(f"Worker {worker_id} initialized on GPU {gpu_id}") + + def run(self): + """Main worker loop - processes evaluation tasks from Redis queue""" + print(f"Worker {self.worker_id} starting on GPU {self.gpu_id}") + + while True: + try: + # Get task from Redis queue (blocking pop with timeout) + task_data = self.redis.blpop("eval_queue", timeout=5) + + if task_data is None: + continue # Timeout, check for shutdown signal + + # Check for shutdown signal + if self.redis.get(f"shutdown:{self.worker_id}"): + print(f"Worker {self.worker_id} received shutdown signal") + break + + # Parse task + _, task_json = task_data + task = self._parse_task(task_json) + + if task: + # Process the evaluation task + result = self._evaluate_kernel(task) + + # Store result + self.kernel_store.store_kernel_result(task.operation_name, task.kernel_hash, 
result) + + # Mark task as completed + self.redis.rpush("completed_tasks", json.dumps({ + "worker_id": self.worker_id, + "operation_name": task.operation_name, + "kernel_hash": task.kernel_hash, + "success": result.correctness_passed, + "speedup": result.speedup_factor, + "timestamp": int(time.time()) + })) + + except Exception as e: + print(f"Worker {self.worker_id} error: {e}") + traceback.print_exc() + time.sleep(1) # Brief pause before retrying + + print(f"Worker {self.worker_id} shutdown complete") + + def _parse_task(self, task_json: str) -> Optional[EvaluationTask]: + """Parse JSON task data""" + try: + data = json.loads(task_json) + return EvaluationTask( + operation_name=data["operation_name"], + kernel_code=data["kernel_code"], + kernel_hash=data["kernel_hash"], + test_cases=data.get("test_cases", []), # TODO: Deserialize properly + gpu_id=self.gpu_id, + worker_id=self.worker_id + ) + except Exception as e: + print(f"Failed to parse task: {e}") + return None + + def _evaluate_kernel(self, task: EvaluationTask) -> KernelResult: + """Evaluate a kernel and return detailed results""" + start_time = time.time() + + result = KernelResult( + kernel_code=task.kernel_code, + kernel_hash=task.kernel_hash, + correctness_passed=False, + speedup_factor=0.0, + timestamp=int(time.time()) + ) + + try: + print(f" Worker {self.worker_id} evaluating {task.operation_name}:{task.kernel_hash[:8]}") + + # Save kernel for debugging (even if it fails) + debug_file = f"/tmp/debug_kernel_{task.operation_name}_{task.kernel_hash[:8]}.py" + with open(debug_file, "w") as f: + f.write(f"# Debug kernel for {task.operation_name}\n") + f.write(f"# Hash: {task.kernel_hash}\n\n") + f.write(task.kernel_code) + print(f" Saved debug kernel to: {debug_file}") + + # Use existing LLMBackend testing logic + # TODO: Convert test_cases format properly + dummy_test_cases = [] # Placeholder - need to implement proper test case conversion + + is_correct, feedback = self.llm_backend.test_kernel_correctness( + task.operation_name, task.kernel_code, dummy_test_cases, attempt=1 + ) + + result.compilation_time_ms = int((time.time() - start_time) * 1000) + result.correctness_passed = is_correct + + if not is_correct: + if feedback.get("compilation_error"): + result.error = f"Compilation: {feedback['compilation_error']}" + print(f" ❌ Compilation failed: {feedback['compilation_error']}") + elif feedback.get("test_errors"): + result.error = f"Tests: {feedback['test_errors'][0]['error']}" + print(f" ❌ Tests failed: {feedback['test_errors'][0]['error']}") + else: + result.error = feedback.get("summary", "Unknown error") + print(f" ❌ Failed: {result.error}") + else: + # Basic performance test - just check if it runs + result.speedup_factor = 1.0 # Placeholder + # TODO: Implement proper benchmarking with triton.do_bench + print(f" ✅ Success: {result.correctness_passed}, Speedup: {result.speedup_factor:.2f}x") + + except Exception as e: + result.error = str(e) + print(f" ❌ Error: {e}") + + return result + + +class TaskDispatcher: + """Dispatches evaluation tasks to worker pool""" + + def __init__(self, config: WorkerConfig): + self.config = config + self.redis = redis.Redis(host=config.redis_host, port=config.redis_port, decode_responses=True) + + def submit_evaluation_task(self, operation_name: str, kernel_code: str, test_cases: List = None): + """Submit a kernel for evaluation""" + kernel_hash = hashlib.sha256(kernel_code.encode()).hexdigest()[:16] + + task = { + "operation_name": operation_name, + "kernel_code": kernel_code, + 
"kernel_hash": kernel_hash, + "test_cases": test_cases or [], + "timestamp": int(time.time()) + } + + # Add to evaluation queue + self.redis.rpush("eval_queue", json.dumps(task)) + + return kernel_hash + + def get_completed_tasks(self, timeout: int = 1) -> List[Dict]: + """Get completed evaluation results""" + completed = [] + + while True: + result = self.redis.blpop("completed_tasks", timeout=timeout) + if result is None: + break + + _, result_json = result + completed.append(json.loads(result_json)) + + return completed + + def wait_for_completion(self, expected_tasks: int, timeout: int = 300) -> List[Dict]: + """Wait for a specific number of tasks to complete""" + completed_tasks = [] + start_time = time.time() + + while len(completed_tasks) < expected_tasks: + if time.time() - start_time > timeout: + print(f"Timeout waiting for tasks. Got {len(completed_tasks)}/{expected_tasks}") + break + + batch = self.get_completed_tasks(timeout=5) + completed_tasks.extend(batch) + + if batch: + print(f"Completed {len(completed_tasks)}/{expected_tasks} tasks") + + return completed_tasks + + +class DistributedWorkerManager: + """Manages the distributed worker processes""" + + def __init__(self, config: WorkerConfig): + self.config = config + self.workers: List[mp.Process] = [] + self.dispatcher = TaskDispatcher(config) + + # Clear any existing shutdown signals + redis_client = redis.Redis(host=config.redis_host, port=config.redis_port) + for gpu_id in config.evaluation_gpus: + redis_client.delete(f"shutdown:eval_worker_{gpu_id}") + + def start_evaluation_workers(self): + """Start evaluation worker processes""" + print(f"Starting {len(self.config.evaluation_gpus)} evaluation workers...") + + for gpu_id in self.config.evaluation_gpus: + worker_id = f"eval_worker_{gpu_id}" + + # Create a serializable config dict instead of passing the full object + config_dict = { + 'redis_host': self.config.redis_host, + 'redis_port': self.config.redis_port, + 'vllm_model_path': self.config.vllm_model_path, + 'generation_gpus': self.config.generation_gpus, + 'evaluation_gpus': self.config.evaluation_gpus, + 'max_workers': self.config.max_workers + } + + # Create worker process + worker_process = mp.Process( + target=DistributedWorkerManager._run_evaluation_worker, + args=(worker_id, gpu_id, config_dict) + ) + + worker_process.start() + self.workers.append(worker_process) + + print(f"Started evaluation worker {worker_id} on GPU {gpu_id}") + + print("All evaluation workers started!") + + @staticmethod + def _run_evaluation_worker(worker_id: str, gpu_id: int, config_dict: dict): + """Run evaluation worker in separate process""" + try: + # Recreate WorkerConfig from dictionary inside worker process + config = WorkerConfig( + redis_host=config_dict['redis_host'], + redis_port=config_dict['redis_port'], + vllm_model_path=config_dict['vllm_model_path'], + generation_gpus=config_dict['generation_gpus'], + evaluation_gpus=config_dict['evaluation_gpus'], + max_workers=config_dict['max_workers'] + ) + + worker = EvaluationWorker(worker_id, gpu_id, config) + worker.run() + except Exception as e: + print(f"Evaluation worker {worker_id} failed: {e}") + traceback.print_exc() + + def shutdown_workers(self): + """Gracefully shutdown all workers""" + print("Shutting down workers...") + + redis_client = redis.Redis(host=self.config.redis_host, port=self.config.redis_port) + + # Send shutdown signals + for gpu_id in self.config.evaluation_gpus: + worker_id = f"eval_worker_{gpu_id}" + redis_client.set(f"shutdown:{worker_id}", "1", ex=60) 
+ + # Wait for workers to finish + for worker in self.workers: + worker.join(timeout=10) + if worker.is_alive(): + print(f"Force terminating worker {worker.pid}") + worker.terminate() + + print("All workers shut down") + + def submit_batch_evaluation(self, kernels: List[Tuple[str, str]]) -> List[str]: + """Submit a batch of kernels for evaluation""" + kernel_hashes = [] + + print(f"Submitting {len(kernels)} kernels for evaluation...") + + for operation_name, kernel_code in kernels: + kernel_hash = self.dispatcher.submit_evaluation_task(operation_name, kernel_code) + kernel_hashes.append(kernel_hash) + + return kernel_hashes + + def wait_for_batch_completion(self, expected_count: int, timeout: int = 300) -> List[Dict]: + """Wait for batch evaluation to complete""" + return self.dispatcher.wait_for_completion(expected_count, timeout) + + +class PrototypeOrchestrator: + """Main orchestrator for the 8-GPU prototype""" + + def __init__(self, config: WorkerConfig): + self.config = config + self.worker_manager = DistributedWorkerManager(config) + self.kernel_store = KernelStore(config.redis_host, config.redis_port) + + async def run_prototype(self, operations: List[str], base_prompts: Dict[str, str]): + """Run the complete 8-GPU prototype""" + + print("🚀 Starting 8-GPU Distributed VLLM Prototype") + print(f"Model: {self.config.vllm_model_path}") + print(f"Generation GPUs: {self.config.generation_gpus}") + print(f"Evaluation GPUs: {self.config.evaluation_gpus}") + print(f"Operations: {operations}") + + try: + # Step 1: Start evaluation workers + self.worker_manager.start_evaluation_workers() + + # Step 2: Initialize VLLM generation + from .vllm_backend import VLLMBackend + vllm_backend = VLLMBackend( + model_path=self.config.vllm_model_path, + tensor_parallel_size=len(self.config.generation_gpus), + redis_host=self.config.redis_host + ) + + # Step 3: Process each operation with rejection sampling + for operation_name in operations: + print(f"\n📝 Processing {operation_name}...") + + base_prompt = base_prompts.get( + operation_name, + f"Generate a high-performance kernel implementation for the {operation_name} operation" + ) + + # Generate candidates using VLLM + print(" Generating kernel candidates...") + candidates = await vllm_backend.generate_kernels_for_operation( + operation_name, base_prompt, num_candidates=20 + ) + + if not candidates: + print(f" ❌ No candidates generated for {operation_name}") + continue + + # Submit candidates for distributed evaluation + print(f" Submitting {len(candidates)} candidates for evaluation...") + kernel_batch = [(operation_name, kernel_code) for kernel_code in candidates] + kernel_hashes = self.worker_manager.submit_batch_evaluation(kernel_batch) + + # Wait for evaluation results + print(" Waiting for evaluation results...") + results = self.worker_manager.wait_for_batch_completion(len(candidates), timeout=120) + + # Analyze results + successful_kernels = [r for r in results if r["success"]] + if successful_kernels: + best_kernel = max(successful_kernels, key=lambda x: x["speedup"]) + print(f" ✅ Best kernel: {best_kernel['speedup']:.2f}x speedup") + else: + print(f" ❌ No successful kernels for {operation_name}") + + # Step 4: Print final summary + print("\n📊 Final Prototype Results:") + for operation_name in operations: + stats = self.kernel_store.get_operation_stats(operation_name) + print(f" {operation_name}:") + print(f" Total attempts: {stats['total_attempts']}") + print(f" Success rate: {stats['success_rate']:.2%}") + print(f" Best speedup: 
{stats['best_speedup']:.2f}x") + + best_kernels = self.kernel_store.get_best_kernels(operation_name, limit=1) + if best_kernels: + print(f" Best kernel hash: {best_kernels[0]['hash']}") + + finally: + # Cleanup + print("\n🧹 Cleaning up...") + self.worker_manager.shutdown_workers() + print("Prototype complete!") + + +# Utility functions for testing +def create_test_config() -> WorkerConfig: + """Create a test configuration for the 8-GPU prototype""" + return WorkerConfig( + redis_host="localhost", + redis_port=6379, + vllm_model_path="codellama/CodeLlama-7b-Instruct-hf", # Small model for testing + generation_gpus=[0, 1, 2, 3], # First 4 GPUs for VLLM + evaluation_gpus=[4, 5, 6, 7], # Last 4 GPUs for evaluation + max_workers=4 + ) + + +def create_test_prompts() -> Dict[str, str]: + """Create test prompts for common operations""" + return { + "relu": """ +Generate a high-performance Triton kernel for the ReLU activation function. + +Requirements: +- Function name: relu_kernel_impl +- Input: tensor x +- Output: tensor with same shape as x +- Apply max(0, x) element-wise +- Handle arbitrary tensor shapes +- Optimize for GPU memory access patterns + +Example usage: +```python +def relu_kernel_impl(x): + # Your implementation here + return result +``` +""", + + "add": """ +Generate a high-performance kernel for element-wise tensor addition. + +Requirements: +- Function name: add_kernel_impl +- Inputs: tensor a, tensor b +- Output: tensor a + b +- Handle broadcasting if shapes differ +- Optimize for memory coalescing +- Support different dtypes + +Example usage: +```python +def add_kernel_impl(a, b): + # Your implementation here + return result +``` +""" + } \ No newline at end of file diff --git a/BackendBench/prompts.py b/BackendBench/prompts.py index c5835ba..1b0d78d 100644 --- a/BackendBench/prompts.py +++ b/BackendBench/prompts.py @@ -17,7 +17,11 @@ - Handle both args and kwargs properly - Preserve original tensor devices and restore them for outputs -Generate complete, runnable code only - no framework will add device handling wrapper code.""" +Generate complete, runnable code only - no framework will add device handling wrapper code. + +RESPONSE FORMAT: You MUST respond with ONLY a Python code block. No explanations, no reasoning, no text outside the code block. + +Start your response immediately with ```python""" PYTORCH_KERNEL_PROMPT = """Generate a PyTorch implementation for: {op_name} diff --git a/BackendBench/vllm_backend.py b/BackendBench/vllm_backend.py new file mode 100644 index 0000000..2ee9332 --- /dev/null +++ b/BackendBench/vllm_backend.py @@ -0,0 +1,475 @@ +import os +import time +import hashlib +import asyncio +import re +from typing import Dict, List, Optional, Callable +from dataclasses import dataclass +import redis +import json + +try: + from vllm import LLM, SamplingParams + from vllm import AsyncLLMEngine + from vllm.engine.arg_utils import AsyncEngineArgs + VLLM_AVAILABLE = True +except ImportError as e: + VLLM_AVAILABLE = False + print(f"Warning: VLLM not available. Install with: pip install vllm. 
Error: {e}") + +from .backends import Backend + + +def extract_python_code(text: str) -> str: + """Extract Python code from markdown code blocks or return raw text""" + # Try to find ```python code blocks first (most permissive) + python_match = re.search(r'```python\s*\n(.*?)(?:\n```|$)', text, re.DOTALL) + if python_match: + return python_match.group(1).strip() + + # Try to find generic ``` code blocks + code_match = re.search(r'```\s*\n(.*?)(?:\n```|$)', text, re.DOTALL) + if code_match: + return code_match.group(1).strip() + + # Look for import statements as start of code + import_match = re.search(r'(import torch.*?)(?:\n\n[A-Z]|$)', text, re.DOTALL) + if import_match: + return import_match.group(1).strip() + + # If no code blocks found, return the original text + return text.strip() + + +@dataclass +class KernelResult: + """Result of kernel evaluation""" + kernel_code: str + kernel_hash: str + correctness_passed: bool + speedup_factor: float + error: str = "" + compilation_time_ms: int = 0 + execution_time_us: float = 0.0 + timestamp: int = 0 + + +class KernelStore: + """Redis-based kernel tracking system""" + + def __init__(self, redis_host="localhost", redis_port=6379): + self.redis = redis.Redis(host=redis_host, port=redis_port, decode_responses=True) + + def store_kernel_result(self, operation_name: str, kernel_hash: str, result: KernelResult): + """Store kernel evaluation result""" + key = f"kernel:{operation_name}:{kernel_hash}" + self.redis.hset(key, mapping={ + "code": result.kernel_code, + "correct": str(result.correctness_passed), + "speedup": str(result.speedup_factor), + "error": result.error, + "timestamp": str(int(time.time())), + "compilation_time_ms": str(result.compilation_time_ms), + "execution_time_us": str(result.execution_time_us) + }) + + # Add to operation's kernel list + self.redis.sadd(f"op_kernels:{operation_name}", kernel_hash) + + # Update operation stats + self.update_operation_stats(operation_name, result) + + def get_best_kernels(self, operation_name: str, limit: int = 10) -> List[Dict]: + """Get top performing kernels for an operation""" + kernel_hashes = self.redis.smembers(f"op_kernels:{operation_name}") + + kernels = [] + for kernel_hash in kernel_hashes: + kernel_data = self.redis.hgetall(f"kernel:{operation_name}:{kernel_hash}") + if kernel_data.get("correct") == "True": + kernels.append({ + "hash": kernel_hash, + "speedup": float(kernel_data.get("speedup", 0.0)), + "code": kernel_data.get("code", ""), + "timestamp": int(kernel_data.get("timestamp", 0)) + }) + + # Sort by speedup, return top N + return sorted(kernels, key=lambda x: x["speedup"], reverse=True)[:limit] + + def get_operation_stats(self, operation_name: str) -> Dict: + """Get aggregated stats for an operation""" + stats_key = f"op_stats:{operation_name}" + stats = self.redis.hgetall(stats_key) + + return { + "total_attempts": int(stats.get("total", 0)), + "successful_attempts": int(stats.get("successful", 0)), + "best_speedup": float(stats.get("best_speedup", 0.0)), + "avg_speedup": float(stats.get("avg_speedup", 0.0)), + "success_rate": float(stats.get("success_rate", 0.0)) + } + + def update_operation_stats(self, operation_name: str, result: KernelResult): + """Update running statistics for an operation""" + stats_key = f"op_stats:{operation_name}" + + # Atomic update using Lua script + lua_script = """ + local key = KEYS[1] + local correct = ARGV[1] == "True" + local speedup = tonumber(ARGV[2]) or 0 + + local total = tonumber(redis.call('HGET', key, 'total') or 0) + local 
successful = tonumber(redis.call('HGET', key, 'successful') or 0) + local best_speedup = tonumber(redis.call('HGET', key, 'best_speedup') or 0) + local total_speedup = tonumber(redis.call('HGET', key, 'total_speedup') or 0) + + total = total + 1 + if correct then + successful = successful + 1 + total_speedup = total_speedup + speedup + if speedup > best_speedup then + best_speedup = speedup + end + end + + local avg_speedup = successful > 0 and (total_speedup / successful) or 0 + local success_rate = total > 0 and (successful / total) or 0 + + redis.call('HMSET', key, + 'total', total, + 'successful', successful, + 'best_speedup', best_speedup, + 'total_speedup', total_speedup, + 'avg_speedup', avg_speedup, + 'success_rate', success_rate + ) + """ + + self.redis.eval(lua_script, 1, stats_key, + str(result.correctness_passed), + str(result.speedup_factor)) + + def get_failure_context(self, operation_name: str, limit: int = 5) -> str: + """Get context about recent failures to improve prompts""" + kernel_hashes = list(self.redis.smembers(f"op_kernels:{operation_name}"))[-limit:] + + failed_errors = [] + for kernel_hash in kernel_hashes: + kernel_data = self.redis.hgetall(f"kernel:{operation_name}:{kernel_hash}") + if kernel_data.get("correct") != "True" and kernel_data.get("error"): + failed_errors.append(kernel_data["error"]) + + if failed_errors: + return f"Recent failures: {'; '.join(failed_errors[:3])}" + return "" + + +class VLLMGenerationWorker: + """VLLM-based kernel generation worker""" + + def __init__(self, model_path: str, tensor_parallel_size: int = 1, max_model_len: int = 4096): + if not VLLM_AVAILABLE: + raise RuntimeError("VLLM not available. Install with: pip install vllm") + + self.model_path = model_path + self.tensor_parallel_size = tensor_parallel_size + + # Initialize VLLM AsyncLLMEngine + engine_args = AsyncEngineArgs( + model=model_path, + tensor_parallel_size=tensor_parallel_size, + max_model_len=max_model_len, + gpu_memory_utilization=0.85, # Lower for larger models like Qwen3-30B + enforce_eager=True, # Avoid graph compilation overhead + trust_remote_code=True, # Required for Qwen models + dtype="bfloat16" # Use bfloat16 for better memory efficiency + ) + + self.engine = AsyncLLMEngine.from_engine_args(engine_args) + + # Sampling parameters for kernel generation + self.sampling_params = SamplingParams( + temperature=0.7, + top_p=0.9, + max_tokens=2048, + n=1 # Generate one candidate at a time + ) + + async def generate_kernel_batch(self, prompts: List[str], num_candidates: int = 10) -> List[List[str]]: + """Generate multiple kernel candidates for each prompt""" + all_results = [] + + for prompt in prompts: + candidates = [] + + # Generate candidates one by one (more reliable than n=num_candidates) + for i in range(num_candidates): + sampling_params = SamplingParams( + temperature=0.3, + top_p=0.9, + max_tokens=4096, # Increase to ensure complete generation + n=1 # Generate one at a time + ) + + request_id = f"req_{hash(prompt)}_{time.time()}_{i}" + + # Generate single candidate + final_output = None + async for request_output in self.engine.generate(prompt, sampling_params, request_id): + final_output = request_output + + # Extract kernel code + if final_output and final_output.outputs and final_output.outputs[0].text.strip(): + raw_text = final_output.outputs[0].text.strip() + kernel_code = extract_python_code(raw_text) + candidates.append(kernel_code) + print(f" Generated candidate {i+1}/{num_candidates} ({len(kernel_code)} chars)") + else: + print(f" Failed to 
generate candidate {i+1}/{num_candidates}") + + print(f" Successfully generated {len(candidates)}/{num_candidates} candidates") + all_results.append(candidates) + + return all_results + + +class SimpleRepromptPolicy: + """Simple reprompting policy based on kernel performance""" + + def __init__(self, kernel_store: KernelStore): + self.store = kernel_store + + def should_generate_more(self, operation_name: str) -> bool: + """Simple policy: generate more if we haven't found good kernels""" + stats = self.store.get_operation_stats(operation_name) + + # Keep going if success rate low and haven't tried much + if stats["success_rate"] < 0.3 and stats["total_attempts"] < 50: + return True + + # Keep going if best speedup is poor + if stats["best_speedup"] < 1.5 and stats["total_attempts"] < 30: + return True + + # Stop if we have good results + return stats["total_attempts"] < 10 + + def get_adaptive_prompt(self, operation_name: str, base_prompt: str) -> str: + """Create context-aware prompts based on previous failures""" + failure_context = self.store.get_failure_context(operation_name) + + if failure_context: + enhanced_prompt = f"""{base_prompt} + +CRITICAL: Previous attempts failed with these errors: +{failure_context} +""" + else: + enhanced_prompt = base_prompt + + return enhanced_prompt + + +class VLLMBackend(Backend): + """VLLM-powered distributed kernel generation backend""" + + def __init__(self, model_path: str, tensor_parallel_size: int = 4, redis_host: str = "localhost"): + super().__init__("vllm") + + self.model_path = model_path + self.tensor_parallel_size = tensor_parallel_size + + # Initialize components + self.kernel_store = KernelStore(redis_host=redis_host) + self.reprompt_policy = SimpleRepromptPolicy(self.kernel_store) + self.compiled_kernels: Dict[str, Callable] = {} + + # Worker will be initialized when needed + self.generation_worker: Optional[VLLMGenerationWorker] = None + + print(f"VLLMBackend initialized with model: {model_path}") + print(f"Tensor parallel size: {tensor_parallel_size}") + + async def _ensure_worker_initialized(self): + """Lazy initialization of VLLM worker""" + if self.generation_worker is None: + print("Initializing VLLM generation worker...") + self.generation_worker = VLLMGenerationWorker( + self.model_path, + self.tensor_parallel_size + ) + print("VLLM worker ready!") + + async def generate_kernels_for_operation(self, operation_name: str, base_prompt: str, num_candidates: int = 10) -> List[str]: + """Generate kernel candidates for a specific operation""" + await self._ensure_worker_initialized() + + # Get adaptive prompt based on previous failures + adaptive_prompt = self.reprompt_policy.get_adaptive_prompt(operation_name, base_prompt) + + # Generate candidates + results = await self.generation_worker.generate_kernel_batch([adaptive_prompt], num_candidates) + + return results[0] if results else [] + + def evaluate_and_store_kernel(self, kernel_code: str, operation_name: str, test_cases: List) -> KernelResult: + """Evaluate a kernel and store results""" + kernel_hash = hashlib.sha256(kernel_code.encode()).hexdigest()[:16] + + # Check if already evaluated + existing = self.kernel_store.redis.hgetall(f"kernel:{operation_name}:{kernel_hash}") + if existing: + return KernelResult( + kernel_code=existing["code"], + kernel_hash=kernel_hash, + correctness_passed=existing["correct"] == "True", + speedup_factor=float(existing["speedup"]), + error=existing.get("error", ""), + compilation_time_ms=int(existing.get("compilation_time_ms", 0)), + 
execution_time_us=float(existing.get("execution_time_us", 0.0)), + timestamp=int(existing.get("timestamp", 0)) + ) + + result = KernelResult( + kernel_code=kernel_code, + kernel_hash=kernel_hash, + correctness_passed=False, + speedup_factor=0.0, + timestamp=int(time.time()) + ) + + try: + start_time = time.time() + + # Use existing LLMBackend compilation logic + from .backends import LLMBackend + temp_backend = LLMBackend() + + # Test correctness + is_correct, feedback = temp_backend.test_kernel_correctness( + operation_name, kernel_code, test_cases, attempt=1 + ) + + result.compilation_time_ms = int((time.time() - start_time) * 1000) + result.correctness_passed = is_correct + + if not is_correct: + if feedback.get("compilation_error"): + result.error = f"Compilation: {feedback['compilation_error']}" + elif feedback.get("test_errors"): + result.error = f"Tests: {feedback['test_errors'][0]['error']}" + else: + result.error = feedback.get("summary", "Unknown error") + else: + # TODO: Add performance benchmarking + result.speedup_factor = 1.0 # Placeholder + + except Exception as e: + result.error = str(e) + + # Store result + self.kernel_store.store_kernel_result(operation_name, kernel_hash, result) + + return result + + async def process_operation_with_rejection_sampling(self, operation_name: str, base_prompt: str, test_cases: List) -> Optional[str]: + """Process an operation with rejection sampling until we find a good kernel""" + + print(f"\n🚀 Processing operation: {operation_name}") + + while self.reprompt_policy.should_generate_more(operation_name): + stats = self.kernel_store.get_operation_stats(operation_name) + print(f" Current stats: {stats['total_attempts']} attempts, {stats['success_rate']:.2f} success rate, {stats['best_speedup']:.2f}x best speedup") + + # Generate candidates + print(f" Generating {10} kernel candidates...") + candidates = await self.generate_kernels_for_operation(operation_name, base_prompt, num_candidates=10) + + # Evaluate each candidate + best_result = None + for i, kernel_code in enumerate(candidates): + print(f" Evaluating candidate {i+1}/{len(candidates)}...") + result = self.evaluate_and_store_kernel(kernel_code, operation_name, test_cases) + + if result.correctness_passed and (best_result is None or result.speedup_factor > best_result.speedup_factor): + best_result = result + + if best_result: + print(f" ✅ Found working kernel with {best_result.speedup_factor:.2f}x speedup") + return best_result.kernel_code + else: + print(f" ❌ No working kernels in this batch") + + # Get best kernel found so far + best_kernels = self.kernel_store.get_best_kernels(operation_name, limit=1) + if best_kernels: + print(f" 📋 Using best kernel found: {best_kernels[0]['speedup']:.2f}x speedup") + return best_kernels[0]["code"] + + print(f" ⚠️ No working kernels found for {operation_name}") + return None + + def add_kernel(self, op, kernel_code: str, op_name: str): + """Add a compiled kernel to the backend (compatibility with existing interface)""" + from .backends import LLMBackend + temp_backend = LLMBackend() + compiled_kernel = temp_backend.compile_kernel_from_string(kernel_code, op_name, attempt=1) + self.compiled_kernels[op] = compiled_kernel + + def __getitem__(self, key): + if key in self.compiled_kernels: + return self.compiled_kernels[key] + raise KeyError(f"No kernel implementation found for {key}") + + def __contains__(self, key): + return key in self.compiled_kernels + + +class DistributedVLLMOrchestrator: + """Orchestrates distributed VLLM kernel generation across 
multiple workers""" + + def __init__(self, config: Dict): + self.config = config + self.kernel_store = KernelStore(redis_host=config.get("redis_host", "localhost")) + + # Initialize workers based on config + self.generation_workers = [] + self.evaluation_workers = [] + + async def run_8gpu_prototype(self, operations: List[str], base_prompts: Dict[str, str]): + """Run the 8-GPU prototype with distributed workers""" + + print("🔥 Starting 8-GPU VLLM Backend Prototype") + print(f"Operations to process: {len(operations)}") + + # Initialize VLLM backend (uses GPUs 0-3 for generation) + vllm_backend = VLLMBackend( + model_path=self.config["model_path"], + tensor_parallel_size=4, # Use 4 GPUs for VLLM + redis_host=self.config.get("redis_host", "localhost") + ) + + # Process each operation + for operation_name in operations: + base_prompt = base_prompts.get(operation_name, f"Implement a kernel for {operation_name}") + + # TODO: Get actual test cases from suite + dummy_test_cases = [] # Placeholder + + # Process with rejection sampling + best_kernel = await vllm_backend.process_operation_with_rejection_sampling( + operation_name, base_prompt, dummy_test_cases + ) + + if best_kernel: + print(f"✅ Successfully generated kernel for {operation_name}") + else: + print(f"❌ Failed to generate kernel for {operation_name}") + + # Print final statistics + print("\n📊 Final Results:") + for operation_name in operations: + stats = vllm_backend.kernel_store.get_operation_stats(operation_name) + print(f" {operation_name}: {stats['successful_attempts']}/{stats['total_attempts']} success, {stats['best_speedup']:.2f}x best speedup") \ No newline at end of file diff --git a/VLLM_PROTOTYPE_README.md b/VLLM_PROTOTYPE_README.md new file mode 100644 index 0000000..b261ad7 --- /dev/null +++ b/VLLM_PROTOTYPE_README.md @@ -0,0 +1,151 @@ +# VLLM Backend 8-GPU Prototype + +This prototype demonstrates a scalable self-hosted VLLM deployment for massively parallel kernel generation using rejection sampling. The system is designed to scale from 8 GPUs to 1000+ GPUs. 
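+
+As a quick orientation, the sketch below drives the prototype from Python instead of the CLI
+script. It is a minimal example, assuming Redis is reachable on localhost:6379 and 8 GPUs are
+visible; it only uses helpers that already exist in `BackendBench/distributed_workers.py`.
+
+```python
+import asyncio
+
+from BackendBench.distributed_workers import (
+    PrototypeOrchestrator,
+    create_test_config,
+    create_test_prompts,
+)
+
+config = create_test_config()    # GPUs 0-3 generate, GPUs 4-7 evaluate, Redis on localhost
+prompts = create_test_prompts()  # example prompts for "relu" and "add"
+
+orchestrator = PrototypeOrchestrator(config)
+asyncio.run(orchestrator.run_prototype(["relu", "add"], prompts))
+```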
+
+## Architecture Overview
+
+### 🏗️ System Components
+
+```
+┌────────────────────────────────────────────────────────────┐
+│                      Redis Task Queue                      │
+│ ┌─────────────────┐  ┌─────────────────┐  ┌──────────────┐ │
+│ │   Generation    │  │   Evaluation    │  │   Results    │ │
+│ │     Queue       │  │     Queue       │  │   Storage    │ │
+│ └─────────────────┘  └─────────────────┘  └──────────────┘ │
+└────────────────────────────────────────────────────────────┘
+         │                     │                     │
+         ▼                     ▼                     ▼
+┌─────────────────┐  ┌─────────────────┐  ┌─────────────────┐
+│ VLLM Generation │  │   Evaluation    │  │   Aggregation   │
+│     Workers     │  │     Workers     │  │    & Metrics    │
+│                 │  │                 │  │                 │
+│    GPUs 0-3     │  │    GPUs 4-7     │  │ Redis KV Store  │
+│    4-way TP     │  │   1 GPU each    │  │                 │
+│ Async batching  │  │  Isolated eval  │  │  Performance    │
+│                 │  │                 │  │    tracking     │
+└─────────────────┘  └─────────────────┘  └─────────────────┘
+```
+
+## Quick Start
+
+### Prerequisites
+
+```bash
+pip install -r requirements.txt
+pip install -r requirements-vllm.txt
+redis-server
+```
+
+### Running the Prototype
+
+```bash
+# Basic usage
+python scripts/run_vllm_prototype.py
+
+# Custom model and operations
+python scripts/run_vllm_prototype.py \
+    --model "Qwen/Qwen2.5-14B-Instruct" \
+    --operations "relu,add,mul,sigmoid" \
+    --candidates-per-op 50
+
+# Advanced usage with custom Redis
+python scripts/run_vllm_prototype.py \
+    --model "Qwen/Qwen2.5-7B-Instruct" \
+    --operations "matmul,softmax,layernorm" \
+    --redis-host "10.0.0.100" \
+    --candidates-per-op 100
+```
+
+## Detailed Architecture
+
+### 1. VLLM Generation Workers
+
+**Purpose**: Generate kernel candidates using distributed VLLM inference
+
+- **GPU Usage**: GPUs 0-3 with 4-way tensor parallelism
+- **Concurrency**: Async batching for high throughput
+- **Output**: 10-50 kernel candidates per operation (TODO: measure the actual count)
+
+### 2. Evaluation Workers
+
+**Purpose**: Test kernel correctness and performance in isolation
+
+- **GPU Usage**: GPUs 4-7, one worker per GPU
+- **Process Isolation**: Each worker runs in a separate process
+- **Testing**: Compilation → Correctness → Performance benchmarking
+
+### 3. Kernel Tracking System
+
+**Redis Schema**:
+```
+kernel:{operation}:{hash}  → {code, correct, speedup, error, timestamp}
+op_kernels:{operation}     → Set of kernel hashes
+op_stats:{operation}       → {total, successful, best_speedup, avg_speedup}
+eval_queue                 → JSON task queue
+completed_tasks            → JSON results queue
+```
+
+### 4. Intelligent Reprompting
+
+The system learns from failures to improve generation: recent error messages for an operation are read back from Redis (`KernelStore.get_failure_context`) and appended to the base prompt (`SimpleRepromptPolicy.get_adaptive_prompt`), so later candidates are generated with the previous failure modes in view.
+
+## Scaling to 1000 GPUs
+
+TODO
+
+```python
+SCALE_CONFIG = {
+    "generation_workers": {
+        "count": 125,
+        "tensor_parallel_size": 8,
+        "gpus_per_worker": 8
+    },
+    "evaluation_workers": {
+        "count": 500,
+        "gpus_per_worker": 1
+    },
+    "redis_cluster": {
+        "nodes": 10,
+        "shards": 16
+    },
+    "candidates_per_operation": 1000
+}
+```
+
+## Performance Expectations
+
+TODO: Measure how many kernels we generate and evaluate per operation, and the end-to-end speedup achieved when we find a kernel better than torch eager.
+
+## Monitoring and Debugging
+
+### Real-time Metrics
+
+```bash
+# Monitor Redis queues
+redis-cli LLEN eval_queue
+redis-cli LLEN completed_tasks
+
+# Check worker status
+redis-cli KEYS "shutdown:*"
+
+# Operation statistics
+redis-cli HGETALL "op_stats:relu"
+```
+
+### Performance Analysis
+
+```python
+from BackendBench.vllm_backend import KernelStore
+
+store = KernelStore()
+stats = store.get_operation_stats("relu")
+print(f"Success rate: {stats['success_rate']:.2%}")
+print(f"Best speedup: {stats['best_speedup']:.2f}x")
+
+best_kernels = store.get_best_kernels("relu", limit=5)
+for kernel in best_kernels:
+    print(f"Hash: {kernel['hash']}, Speedup: {kernel['speedup']:.2f}x")
+```
\ No newline at end of file
diff --git a/requirements-vllm.txt b/requirements-vllm.txt
new file mode 100644
index 0000000..dd80dad
--- /dev/null
+++ b/requirements-vllm.txt
@@ -0,0 +1,7 @@
+vllm>=0.3.0
+redis>=4.5.0
+asyncio-mqtt>=0.11.0
+psutil>=5.9.0
+rich>=13.0.0
+click>=8.1.0
+prometheus-client>=0.16.0
\ No newline at end of file
diff --git a/scripts/run_vllm_prototype.py b/scripts/run_vllm_prototype.py
new file mode 100755
index 0000000..87f7f3a
--- /dev/null
+++ b/scripts/run_vllm_prototype.py
@@ -0,0 +1,200 @@
+#!/usr/bin/env python3
+"""
+8-GPU VLLM Backend Prototype Runner
+
+This script demonstrates the distributed VLLM kernel generation system with:
+- GPUs 0-3: VLLM generation workers (4-way tensor parallelism)
+- GPUs 4-7: Evaluation workers (1 GPU each for isolated testing)
+- Redis: Task queue and results storage
+- Rejection sampling: Generate many candidates, keep the best
+
+Usage:
+    python scripts/run_vllm_prototype.py --model codellama/CodeLlama-7b-Instruct-hf --operations relu,add,mul
+
+Requirements:
+    - 8 GPUs available
+    - Redis server running
+    - VLLM installed: pip install vllm
+    - PyTorch with CUDA support
+"""
+
+import asyncio
+import argparse
+import sys
+import os
+import time
+from typing import List, Dict
+
+# Add BackendBench to path
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
+
+from BackendBench.distributed_workers import (
+    PrototypeOrchestrator,
+    WorkerConfig,
+    create_test_config,
+    create_test_prompts
+)
+
+
+def check_prerequisites():
+    """Check that all prerequisites are available"""
+    print("🔍 Checking prerequisites...")
+
+    # Check GPU availability
+    try:
+        import torch
+        if not torch.cuda.is_available():
+            print("❌ CUDA not available")
+            return False
+
+        gpu_count = torch.cuda.device_count()
+        if gpu_count < 8:
+            print(f"❌ Need 8 GPUs, found {gpu_count}")
+            return False
+
+        print(f"✅ Found {gpu_count} GPUs")
+
+        # Print GPU info
+        for i in range(min(8, gpu_count)):
+            gpu_name = torch.cuda.get_device_name(i)
+            memory_gb = torch.cuda.get_device_properties(i).total_memory / 1e9
+            print(f"   GPU {i}: {gpu_name} ({memory_gb:.1f}GB)")
+
+    except ImportError:
+        print("❌ PyTorch not available")
+        return False
+
+    # Check VLLM
+    try:
+        import vllm
+        print("✅ VLLM available")
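+        # NOTE: BackendBench/vllm_backend.py drives AsyncLLMEngine.from_engine_args()
+        # and engine.generate(prompt, sampling_params, request_id) directly, so the
+        # installed vLLM build needs to expose that async-engine interface
+        # (requirements-vllm.txt pins vllm>=0.3.0).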
+ except ImportError: + print("❌ VLLM not available. Install with: pip install vllm") + return False + + # Check Redis + try: + import redis + r = redis.Redis(host='localhost', port=6379, decode_responses=True) + r.ping() + print("✅ Redis server available") + except Exception as e: + print(f"❌ Redis not available: {e}") + print(" Start Redis with: redis-server") + return False + + return True + + +def parse_operations(operations_str: str) -> List[str]: + """Parse comma-separated operations string""" + return [op.strip() for op in operations_str.split(",") if op.strip()] + + +def create_prompt_for_operation(op_name: str) -> str: + """Create prompt using existing KernelTemplateManager with simple pattern""" + from BackendBench.kernel_templates import KernelTemplateManager + + # Use simple signature and description patterns + op_signature = f"{op_name}(input: Tensor, *args, **kwargs) -> Tensor" + op_description = f"Apply {op_name} operation element-wise" + + # Use existing template manager + template_manager = KernelTemplateManager() + return template_manager.create_prompt(op_name, op_signature, op_description, framework="triton") + + +async def main(): + parser = argparse.ArgumentParser(description="Run 8-GPU VLLM Backend Prototype") + parser.add_argument( + "--model", + default="deepseek-ai/DeepSeek-R1-Distill-Llama-70B", + help="VLLM model path (default: deepseek-ai/DeepSeek-R1-Distill-Llama-70B)" + ) + parser.add_argument( + "--operations", + default="relu,add,mul", + help="Comma-separated list of operations to test (default: relu,add,mul)" + ) + parser.add_argument( + "--redis-host", + default="localhost", + help="Redis host (default: localhost)" + ) + parser.add_argument( + "--candidates-per-op", + type=int, + default=20, + help="Number of kernel candidates to generate per operation (default: 20)" + ) + parser.add_argument( + "--skip-checks", + action="store_true", + help="Skip prerequisite checks (for testing)" + ) + + args = parser.parse_args() + + print("🚀 VLLM Backend 8-GPU Prototype") + print("=" * 50) + print(f"Model: {args.model}") + print(f"Operations: {args.operations}") + print(f"Candidates per operation: {args.candidates_per_op}") + print(f"Redis host: {args.redis_host}") + print() + + # Check prerequisites + if not args.skip_checks and not check_prerequisites(): + print("\n❌ Prerequisites not met. 
Exiting.") + return 1 + + # Parse operations + operations = parse_operations(args.operations) + if not operations: + print("❌ No valid operations specified") + return 1 + + print(f"✅ Will process {len(operations)} operations: {operations}") + + # Create configuration + config = WorkerConfig( + redis_host=args.redis_host, + vllm_model_path=args.model, + generation_gpus=[0, 1, 2, 3], # First 4 GPUs for VLLM + evaluation_gpus=[4, 5, 6, 7], # Last 4 GPUs for evaluation + ) + + # Create prompts for each operation using the existing simple pattern + prompts = {} + for op in operations: + prompts[op] = create_prompt_for_operation(op) + + try: + # Create and run orchestrator + print("\n🎯 Initializing orchestrator...") + orchestrator = PrototypeOrchestrator(config) + + start_time = time.time() + await orchestrator.run_prototype(operations, prompts) + end_time = time.time() + + print(f"\n✅ Prototype completed in {end_time - start_time:.1f} seconds") + return 0 + + except KeyboardInterrupt: + print("\n🛑 Interrupted by user") + return 1 + except Exception as e: + print(f"\n❌ Prototype failed: {e}") + import traceback + traceback.print_exc() + return 1 + + +if __name__ == "__main__": + # Set up proper asyncio event loop + if sys.platform == "win32": + asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy()) + + exit_code = asyncio.run(main()) + sys.exit(exit_code) \ No newline at end of file diff --git a/scripts/validate_vllm_setup.py b/scripts/validate_vllm_setup.py new file mode 100755 index 0000000..48be292 --- /dev/null +++ b/scripts/validate_vllm_setup.py @@ -0,0 +1,265 @@ +#!/usr/bin/env python3 +""" +Validation script for VLLM Backend setup + +This script validates that all components are properly installed and configured +for the 8-GPU VLLM prototype without actually running the full system. 
+ +Usage: + python scripts/validate_vllm_setup.py +""" + +import sys +import os + +# Add BackendBench to path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) + + +def test_imports(): + """Test that all required modules can be imported""" + print("🔍 Testing imports...") + + try: + import torch + print(f"✅ PyTorch {torch.__version__}") + + if torch.cuda.is_available(): + print(f" CUDA {torch.version.cuda} available") + print(f" {torch.cuda.device_count()} GPU(s) detected") + else: + print("⚠️ CUDA not available") + + except ImportError as e: + print(f"❌ PyTorch import failed: {e}") + return False + + try: + import redis + print(f"✅ Redis client available") + except ImportError as e: + print(f"❌ Redis import failed: {e}") + return False + + try: + import vllm + print(f"✅ VLLM {vllm.__version__} available") + except ImportError as e: + print(f"❌ VLLM import failed: {e}") + print(" Install with: pip install vllm") + return False + + try: + from BackendBench.vllm_backend import VLLMBackend, KernelStore + print("✅ VLLM Backend modules imported successfully") + except ImportError as e: + print(f"❌ VLLM Backend import failed: {e}") + return False + + try: + from BackendBench.distributed_workers import PrototypeOrchestrator + print("✅ Distributed workers imported successfully") + except ImportError as e: + print(f"❌ Distributed workers import failed: {e}") + return False + + return True + + +def test_redis_connection(): + """Test Redis server connection""" + print("\n🔍 Testing Redis connection...") + + try: + import redis + r = redis.Redis(host='localhost', port=6379, decode_responses=True) + + # Test basic operations + r.ping() + r.set("test_key", "test_value") + value = r.get("test_key") + r.delete("test_key") + + if value == "test_value": + print("✅ Redis server working correctly") + return True + else: + print("❌ Redis read/write test failed") + return False + + except Exception as e: + print(f"❌ Redis connection failed: {e}") + print(" Make sure Redis server is running: redis-server") + return False + + +def test_kernel_store(): + """Test kernel store functionality""" + print("\n🔍 Testing kernel store...") + + try: + from BackendBench.vllm_backend import KernelStore, KernelResult + + store = KernelStore() + + # Clean up any existing test data first + store.redis.delete("kernel:test_op_validation:test_hash_123") + store.redis.delete("op_kernels:test_op_validation") + store.redis.delete("op_stats:test_op_validation") + + # Test storing a kernel result (use unique operation name for validation) + test_result = KernelResult( + kernel_code="def test(): return 42", + kernel_hash="test_hash_123", + correctness_passed=True, + speedup_factor=1.5, + timestamp=1234567890 + ) + + store.store_kernel_result("test_op_validation", "test_hash_123", test_result) + + # Test retrieving stats + stats = store.get_operation_stats("test_op_validation") + + if stats["total_attempts"] == 1 and stats["best_speedup"] == 1.5: + print("✅ Kernel store working correctly") + + # Cleanup + store.redis.delete("kernel:test_op_validation:test_hash_123") + store.redis.delete("op_kernels:test_op_validation") + store.redis.delete("op_stats:test_op_validation") + + return True + else: + print(f"❌ Kernel store test failed: {stats}") + # Still cleanup even on failure + store.redis.delete("kernel:test_op_validation:test_hash_123") + store.redis.delete("op_kernels:test_op_validation") + store.redis.delete("op_stats:test_op_validation") + return False + + except Exception as e: + print(f"❌ Kernel store test failed: {e}") 
+ return False + + +def test_gpu_allocation(): + """Test GPU allocation strategy""" + print("\n🔍 Testing GPU allocation...") + + try: + import torch + + if not torch.cuda.is_available(): + print("⚠️ Skipping GPU tests - CUDA not available") + return True + + gpu_count = torch.cuda.device_count() + print(f" Available GPUs: {gpu_count}") + + if gpu_count < 8: + print(f"⚠️ Only {gpu_count} GPUs available (need 8 for full prototype)") + print(" Prototype will use available GPUs with reduced parallelism") + + # Test basic GPU operations + for i in range(min(gpu_count, 8)): + device = f"cuda:{i}" + try: + x = torch.randn(100, device=device) + y = x * 2 + assert y.device.type == "cuda" + print(f" ✅ GPU {i}: {torch.cuda.get_device_name(i)}") + except Exception as e: + print(f" ❌ GPU {i} test failed: {e}") + return False + + print("✅ GPU allocation tests passed") + return True + + except Exception as e: + print(f"❌ GPU allocation test failed: {e}") + return False + + +def test_vllm_initialization(): + """Test VLLM initialization (without loading large model)""" + print("\n🔍 Testing VLLM initialization...") + + try: + from vllm import AsyncLLMEngine + from vllm.engine.arg_utils import AsyncEngineArgs + + # Test with minimal configuration (don't actually load model) + print("✅ VLLM classes imported successfully") + print(" (Skipping actual model loading in validation)") + + return True + + except Exception as e: + print(f"❌ VLLM initialization test failed: {e}") + return False + + +def print_system_info(): + """Print system information""" + print("\n📊 System Information:") + + import platform + print(f" OS: {platform.system()} {platform.release()}") + print(f" Python: {sys.version.split()[0]}") + + try: + import torch + print(f" PyTorch: {torch.__version__}") + if torch.cuda.is_available(): + print(f" CUDA: {torch.version.cuda}") + for i in range(torch.cuda.device_count()): + name = torch.cuda.get_device_name(i) + memory = torch.cuda.get_device_properties(i).total_memory / 1e9 + print(f" GPU {i}: {name} ({memory:.1f}GB)") + except: + pass + + try: + import vllm + print(f" VLLM: {vllm.__version__}") + except: + pass + + +def main(): + print("🚀 VLLM Backend Setup Validation") + print("=" * 50) + + all_tests_passed = True + + # Run all validation tests + tests = [ + test_imports, + test_redis_connection, + test_kernel_store, + test_gpu_allocation, + test_vllm_initialization + ] + + for test in tests: + if not test(): + all_tests_passed = False + + print_system_info() + + print("\n" + "=" * 50) + if all_tests_passed: + print("✅ All validation tests passed!") + print("🚀 Ready to run VLLM prototype:") + print(" python scripts/run_vllm_prototype.py") + return 0 + else: + print("❌ Some validation tests failed.") + print(" Please fix the issues above before running the prototype.") + return 1 + + +if __name__ == "__main__": + exit_code = main() + sys.exit(exit_code) \ No newline at end of file