import datetime
import importlib.util
import logging
import os
import sys
import torch
import traceback
from typing import Callable, Dict, List

from .base import Backend

logger = logging.getLogger(__name__)


class LLMRelayBackend(Backend):
    """
    Backend that uses LLMKernelGenerator to communicate with the local plugboard server.
    This backend will eventually replace the LLMBackend, which calls the Anthropic API directly.
    """

    def __init__(self, model: str = "gcp-claude-4-sonnet") -> None:
        super().__init__("llm-relay")
        self.compiled_kernels: Dict[str, Callable] = {}
        self.model = model

        # Create generated_kernels directory
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        self.kernels_dir = f"generated_kernels/llm_relay_run_{timestamp}"
        os.makedirs(self.kernels_dir, exist_ok=True)
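        # e.g. self.kernels_dir == "generated_kernels/llm_relay_run_20250101_120000"
        #      (illustrative timestamp)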

        # Create README for this run
        readme_path = os.path.join(self.kernels_dir, "README.md")
        with open(readme_path, "w") as f:
            f.write(
                f"""# Generated Kernels - LLM Relay - {timestamp}

This directory contains PyTorch/Triton kernels generated by the LLM Relay Backend.

## Run Info
- Timestamp: {timestamp}
- Backend: LLM Relay
- Model: {model}
- Server: Local plugboard server (localhost:11434)

## Files
Each `<op_name>_kernel_attempt_<n>.py` file contains the complete generated kernel code for that operation, including:
- All necessary imports
- Triton kernel implementation (if applicable)
- Wrapper function that matches the PyTorch operation signature

## Server Setup
This backend requires the plugboard server to be running:
```
buck run @//mode/inplace run_plugboard_server -- --model gcp-claude-4-sonnet --pipeline usecase-dev-ai-user
```

## Usage
You can inspect these files to debug kernel generation, manually test implementations, or understand what the LLM produced.
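
For example, to exercise a generated kernel by hand (hypothetical file and function names shown):

```
import importlib.util
import torch

spec = importlib.util.spec_from_file_location("kernel", "add_kernel_attempt_1.py")
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
print(module.add_kernel_impl(torch.ones(2), torch.ones(2)))
```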
| 59 | +""" |
| 60 | + ) |
| 61 | + |
| 62 | + logger.info(f"Saving LLM Relay generated kernels to: {self.kernels_dir}") |
| 63 | + |
    def compile_kernel_from_string(
        self, kernel_code: str, op_name: str, attempt: int = 1
    ) -> Callable:
        """Compile a kernel from string code and return a callable."""
        try:
            # "triton.jit" also matches the decorated form "@triton.jit"
            is_triton = "triton.jit" in kernel_code

            if is_triton:
                full_code = self._prepare_triton_code(kernel_code)
            else:
                full_code = self._prepare_torch_code(kernel_code)

            kernel_file = os.path.join(self.kernels_dir, f"{op_name}_kernel_attempt_{attempt}.py")
            with open(kernel_file, "w") as f:
                f.write(full_code)

            logger.debug(f"Saved kernel to: {kernel_file}")

            spec = importlib.util.spec_from_file_location(f"kernel_{op_name}", kernel_file)
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)

            kernel_func = self._find_kernel_function(module, op_name)

            return kernel_func

        except Exception as e:
            # Chain the original exception so the full traceback is preserved
            raise RuntimeError(f"Failed to compile kernel for {op_name}: {str(e)}") from e

    def _prepare_triton_code(self, kernel_code: str) -> str:
        """Prepare Triton kernel code with necessary imports."""
        imports = """
import torch
import triton
import triton.language as tl
"""
        # Check for triton as well: generated code may import torch but not triton.
        # A duplicate import is harmless if both are already present.
        if "import torch" not in kernel_code or "import triton" not in kernel_code:
            kernel_code = imports + kernel_code
        return kernel_code

    def _prepare_torch_code(self, kernel_code: str) -> str:
        """Prepare regular PyTorch kernel code with necessary imports."""
        imports = """
import torch
import torch.nn.functional as F
"""
        if "import torch" not in kernel_code:
            kernel_code = imports + kernel_code
        return kernel_code

    def _find_kernel_function(self, module, op_name: str) -> Callable:
        """Find the main kernel function in the compiled module."""
        expected_name = f"{op_name}_kernel_impl"

        if hasattr(module, expected_name):
            return getattr(module, expected_name)

        available_functions = [
            name
            for name in dir(module)
            if callable(getattr(module, name)) and not name.startswith("_")
        ]

        raise ValueError(
            f"Expected function '{expected_name}' not found in kernel code for {op_name}. "
            f"Available functions: {available_functions}. "
            f"Please ensure the LLM-generated code follows the naming convention: {op_name}_kernel_impl"
        )
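
    # A minimal sketch of a kernel module that satisfies the naming convention,
    # assuming an op named "add" (hypothetical example, not LLM output):
    #
    #     def add_kernel_impl(a, b, *, alpha=1):
    #         return a + alpha * b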

    def add_kernel(self, op, kernel_code: str, op_name: str) -> None:
        """Add a kernel implementation for a specific operator."""
        compiled_kernel = self.compile_kernel_from_string(kernel_code, op_name, attempt=1)
        self.compiled_kernels[op] = compiled_kernel

    def test_kernel_correctness(
        self, op, kernel_code: str, test_cases: List, attempt: int = 1
    ) -> tuple[bool, Dict]:
        """Test kernel correctness and return detailed feedback."""
        op_str = str(op)
        if "aten." in op_str:
            op_name = op_str.split("aten.")[-1].split(".")[0]
        else:
            op_name = op_str.split(".")[-1]
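        # e.g. str(torch.ops.aten.add.Tensor) == "aten.add.Tensor" -> op_name "add"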

        feedback_info = {
            "compilation_error": None,
            "test_errors": [],
            "summary": None,
        }

        try:
            kernel_file = os.path.join(self.kernels_dir, f"{op_name}_kernel_attempt_{attempt}.py")

            if not os.path.exists(kernel_file):
                # "triton.jit" also matches the decorated form "@triton.jit"
                is_triton = "triton.jit" in kernel_code
                if is_triton:
                    full_code = self._prepare_triton_code(kernel_code)
                else:
                    full_code = self._prepare_torch_code(kernel_code)

                with open(kernel_file, "w") as f:
                    f.write(full_code)
                logger.debug(f"Saved kernel to: {kernel_file}")

            module_name = f"test_kernel_{op_name}_{attempt}"
            spec = importlib.util.spec_from_file_location(module_name, kernel_file)
            module = importlib.util.module_from_spec(spec)

            # Register in sys.modules so triton can find the module during JIT compilation
            sys.modules[module_name] = module

            try:
                spec.loader.exec_module(module)

                expected_name = f"{op_name}_kernel_impl"
                if hasattr(module, expected_name):
                    compiled_kernel = getattr(module, expected_name)
                else:
                    available_functions = [
                        name
                        for name in dir(module)
                        if callable(getattr(module, name)) and not name.startswith("_")
                    ]
                    raise ValueError(
                        f"Expected function '{expected_name}' not found. Available: {available_functions}"
                    )

            finally:
                sys.modules.pop(module_name, None)

            # Clear CUDA cache and synchronize to prevent memory buildup across attempts
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()

            correct_count = 0
            total_count = 0

            for test in test_cases:
                try:
                    args = test.args
                    kwargs = test.kwargs

                    ref_result = op(*args, **kwargs)

                    # Clear the CUDA cache between the reference and kernel runs so the
                    # generated kernel cannot reuse freed memory that still holds the
                    # reference result
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()

                    kernel_result = compiled_kernel(*args, **kwargs)

                    torch.testing.assert_close(ref_result, kernel_result, equal_nan=True)
                    correct_count += 1
                    logger.debug(f"  ✓ Test passed: {ref_result.shape} {ref_result.dtype}")

                except Exception as e:
                    logger.debug(f"  ✗ Test failed: {str(e)}")

                    feedback_info["test_errors"].append(
                        {
                            "test_input": f"args={[arg.shape if hasattr(arg, 'shape') else arg for arg in args]}, kwargs={kwargs}",
                            "error": str(e),
                            "error_type": type(e).__name__,
                            "traceback": traceback.format_exc(),
                        }
                    )

                finally:
                    # Free test tensors promptly; guard with locals() since
                    # args/kwargs may be unbound if test.args itself raised
                    if "args" in locals():
                        del args
                    if "kwargs" in locals():
                        del kwargs

                    total_count += 1

            is_correct = correct_count == total_count and total_count > 0
            if not is_correct:
                feedback_info["summary"] = f"{correct_count}/{total_count} tests passed"

            return is_correct, feedback_info

        except Exception as e:
            logger.error("  ✗ Compilation failed:")
            logger.error(f"    Error: {str(e)}")

            feedback_info["compilation_error"] = str(e)
            feedback_info["summary"] = "Compilation failed"
            return False, feedback_info

    def __getitem__(self, key):
        if key in self.compiled_kernels:
            return self.compiled_kernels[key]
        raise KeyError(f"No kernel implementation found for {key}")

    def __contains__(self, key):
        return key in self.compiled_kernels
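

# A minimal smoke test, assuming the module is executed inside its package
# (it uses a relative import, so run it via "python -m ..."). The op, kernel
# body, and input below are hypothetical illustrations, not LLM output; no
# plugboard server call is made, since add_kernel only compiles the code it
# is given.
if __name__ == "__main__":
    backend = LLMRelayBackend()
    relu_code = "def relu_kernel_impl(x):\n    return torch.relu(x)\n"
    backend.add_kernel(torch.ops.aten.relu.default, relu_code, "relu")

    x = torch.randn(8)
    assert torch.ops.aten.relu.default in backend
    print(backend[torch.ops.aten.relu.default](x))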