
Commit 95e1a03

Add LLM Relay Backend for local plugboard server integration for model evals (#64)
* Add LLMRelayBackend class in backends/llm_relay.py
* Update backends/__init__.py to include LLMRelayBackend
* Add llm-relay backend support to main.py with a setup_llm_relay_backend function
* Add a --llm-relay-model CLI option for model configuration
* Update LLMKernelGenerator to support a configurable model parameter
1 parent cd66e79 commit 95e1a03
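
For context, a hedged sketch of how a run against the relay might be launched. Only the `--llm-relay-model` option and the plugboard server command are confirmed by this commit; the entry point and the `--backend llm-relay` flag are assumptions based on the backend name registered below.

```
# Start the local plugboard relay first (command taken from the README this backend generates):
buck run @//mode/inplace run_plugboard_server -- --model gcp-claude-4-sonnet --pipeline usecase-dev-ai-user

# Then run the evals through main.py (entry point and --backend flag are assumed, not shown in this diff):
python BackendBench/main.py --backend llm-relay --llm-relay-model gcp-claude-4-sonnet
```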

File tree: 4 files changed (+597, -9 lines)

BackendBench/backends/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -12,12 +12,14 @@
 from .flag_gems import FlagGemsBackend
 from .kernel_agent import KernelAgentBackend
 from .llm import LLMBackend
+from .llm_relay import LLMRelayBackend

 __all__ = [
     "Backend",
     "DirectoryBackend",
     "AtenBackend",
     "FlagGemsBackend",
     "LLMBackend",
+    "LLMRelayBackend",
     "KernelAgentBackend",
 ]
```
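
With the new export in place, the relay backend can be imported alongside the existing backends. A minimal construction sketch (the model string mirrors the constructor default in `llm_relay.py` below):

```python
from BackendBench.backends import LLMRelayBackend

# Default relay model; pass a different string to match whatever the plugboard server is serving.
backend = LLMRelayBackend(model="gcp-claude-4-sonnet")
```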

BackendBench/backends/llm_relay.py

Lines changed: 262 additions & 0 deletions
````python
import datetime
import importlib.util
import logging
import os
import sys
import torch
import traceback
from typing import Callable, Dict, List

from .base import Backend

logger = logging.getLogger(__name__)


class LLMRelayBackend(Backend):
    """
    Backend that uses LLMKernelGenerator to communicate with a local plugboard server.
    This backend will eventually replace the LLMBackend that uses direct Anthropic API calls.
    """

    def __init__(self, model: str = "gcp-claude-4-sonnet") -> None:
        super().__init__("llm-relay")
        self.compiled_kernels: Dict[str, Callable] = {}
        self.model = model

        # Create generated_kernels directory
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        self.kernels_dir = f"generated_kernels/llm_relay_run_{timestamp}"
        os.makedirs(self.kernels_dir, exist_ok=True)

        # Create README for this run
        readme_path = os.path.join(self.kernels_dir, "README.md")
        with open(readme_path, "w") as f:
            f.write(
                f"""# Generated Kernels - LLM Relay - {timestamp}

This directory contains PyTorch/Triton kernels generated by the LLM Relay Backend.

## Run Info
- Timestamp: {timestamp}
- Backend: LLM Relay
- Model: {model}
- Server: Local plugboard server (localhost:11434)

## Files
Each `<op_name>_kernel.py` file contains the complete generated kernel code for that operation, including:
- All necessary imports
- Triton kernel implementation (if applicable)
- Wrapper function that matches the PyTorch operation signature

## Server Setup
This backend requires the plugboard server to be running:
```
buck run @//mode/inplace run_plugboard_server -- --model gcp-claude-4-sonnet --pipeline usecase-dev-ai-user
```

## Usage
You can inspect these files to debug kernel generation, manually test implementations, or understand what the LLM produced.
"""
            )

        logger.info(f"Saving LLM Relay generated kernels to: {self.kernels_dir}")

    def compile_kernel_from_string(
        self, kernel_code: str, op_name: str, attempt: int = 1
    ) -> Callable:
        """Compile a kernel from string code and return a callable."""
        try:
            is_triton = "triton.jit" in kernel_code or "@triton.jit" in kernel_code

            if is_triton:
                full_code = self._prepare_triton_code(kernel_code)
            else:
                full_code = self._prepare_torch_code(kernel_code)

            kernel_file = os.path.join(self.kernels_dir, f"{op_name}_kernel_attempt_{attempt}.py")
            with open(kernel_file, "w") as f:
                f.write(full_code)

            logger.debug(f"Saved kernel to: {kernel_file}")

            spec = importlib.util.spec_from_file_location(f"kernel_{op_name}", kernel_file)
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)

            kernel_func = self._find_kernel_function(module, op_name)

            return kernel_func

        except Exception as e:
            raise RuntimeError(f"Failed to compile kernel for {op_name}: {str(e)}")

    def _prepare_triton_code(self, kernel_code: str) -> str:
        """Prepare Triton kernel code with necessary imports."""
        imports = """
import torch
import triton
import triton.language as tl
"""
        if "import torch" not in kernel_code:
            kernel_code = imports + kernel_code
        return kernel_code

    def _prepare_torch_code(self, kernel_code: str) -> str:
        """Prepare regular PyTorch kernel code with necessary imports."""
        imports = """
import torch
import torch.nn.functional as F
"""
        if "import torch" not in kernel_code:
            kernel_code = imports + kernel_code
        return kernel_code

    def _find_kernel_function(self, module, op_name: str) -> Callable:
        """Find the main kernel function in the compiled module."""
        expected_name = f"{op_name}_kernel_impl"

        if hasattr(module, expected_name):
            return getattr(module, expected_name)

        available_functions = [
            name
            for name in dir(module)
            if callable(getattr(module, name)) and not name.startswith("_")
        ]

        raise ValueError(
            f"Expected function '{expected_name}' not found in kernel code for {op_name}. "
            f"Available functions: {available_functions}. "
            f"Please ensure the LLM generated code follows the naming convention: {op_name}_kernel_impl"
        )

    def add_kernel(self, op, kernel_code: str, op_name: str):
        """Add a kernel implementation for a specific operator."""
        compiled_kernel = self.compile_kernel_from_string(kernel_code, op_name, attempt=1)
        self.compiled_kernels[op] = compiled_kernel

    def test_kernel_correctness(
        self, op, kernel_code: str, test_cases: List, attempt: int = 1
    ) -> tuple[bool, Dict]:
        """Test kernel correctness and return detailed feedback."""
        op_str = str(op)
        if "aten." in op_str:
            op_name = op_str.split("aten.")[-1].split(".")[0]
        else:
            op_name = op_str.split(".")[-1]

        feedback_info = {
            "compilation_error": None,
            "test_errors": [],
            "summary": None,
        }

        try:
            kernel_file = os.path.join(self.kernels_dir, f"{op_name}_kernel_attempt_{attempt}.py")

            if not os.path.exists(kernel_file):
                is_triton = "triton.jit" in kernel_code or "@triton.jit" in kernel_code
                if is_triton:
                    full_code = self._prepare_triton_code(kernel_code)
                else:
                    full_code = self._prepare_torch_code(kernel_code)

                with open(kernel_file, "w") as f:
                    f.write(full_code)
                logger.debug(f"Saved kernel to: {kernel_file}")

            spec = importlib.util.spec_from_file_location(
                f"test_kernel_{op_name}_{attempt}", kernel_file
            )
            module = importlib.util.module_from_spec(spec)

            # Add to sys.modules so triton can find it
            sys.modules[f"test_kernel_{op_name}_{attempt}"] = module

            try:
                spec.loader.exec_module(module)

                expected_name = f"{op_name}_kernel_impl"
                if hasattr(module, expected_name):
                    compiled_kernel = getattr(module, expected_name)
                else:
                    available_functions = [
                        name
                        for name in dir(module)
                        if callable(getattr(module, name)) and not name.startswith("_")
                    ]
                    raise ValueError(
                        f"Expected function '{expected_name}' not found. Available: {available_functions}"
                    )

            finally:
                if f"test_kernel_{op_name}_{attempt}" in sys.modules:
                    del sys.modules[f"test_kernel_{op_name}_{attempt}"]

            # Clear CUDA cache and synchronize to prevent memory buildup
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()

            correct_count = 0
            total_count = 0

            for test in test_cases:
                try:
                    args = test.args
                    kwargs = test.kwargs

                    ref_result = op(*args, **kwargs)

                    # Clear CUDA cache after running each kernel to prevent grabbing previous solutions
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()

                    kernel_result = compiled_kernel(*args, **kwargs)

                    torch.testing.assert_close(ref_result, kernel_result, equal_nan=True)
                    correct_count += 1
                    logger.debug(f"  ✓ Test passed: {ref_result.shape} {ref_result.dtype}")

                except Exception as e:
                    logger.debug(f"  ✗ Test failed: {str(e)}")

                    feedback_info["test_errors"].append(
                        {
                            "test_input": f"args={[arg.shape if hasattr(arg, 'shape') else arg for arg in args]}, kwargs={kwargs}",
                            "error": str(e),
                            "error_type": type(e).__name__,
                            "traceback": traceback.format_exc(),
                        }
                    )

                finally:
                    # Clean up memory by deleting args and kwargs if they exist
                    if "args" in locals():
                        del args
                    if "kwargs" in locals():
                        del kwargs

                    total_count += 1

            is_correct = correct_count == total_count and total_count > 0
            if not is_correct:
                feedback_info["summary"] = f"{correct_count}/{total_count} tests passed"

            return is_correct, feedback_info

        except Exception as e:
            logger.error("  ✗ Compilation failed:")
            logger.error(f"    Error: {str(e)}")

            feedback_info["compilation_error"] = str(e)
            feedback_info["summary"] = "Compilation failed"
            return False, feedback_info

    def __getitem__(self, key):
        if key in self.compiled_kernels:
            return self.compiled_kernels[key]
        raise KeyError(f"No kernel implementation found for {key}")

    def __contains__(self, key):
        return key in self.compiled_kernels
````
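
To illustrate the flow above end to end, here is a hedged sketch of how a caller might drive this backend. The toy `relu` kernel string and the `_TestCase` stand-in are illustrative assumptions (real test cases come from BackendBench's own suites), but the method calls and the `<op_name>_kernel_impl` naming convention follow the class above.

```python
import torch

from BackendBench.backends import LLMRelayBackend

backend = LLMRelayBackend()

# Illustrative LLM output following the required <op_name>_kernel_impl convention.
kernel_code = """
import torch

def relu_kernel_impl(x):
    return torch.clamp(x, min=0)
"""


class _TestCase:
    # Stand-in for BackendBench test cases, which expose .args and .kwargs (assumption).
    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs


op = torch.ops.aten.relu.default
tests = [_TestCase(torch.randn(16)), _TestCase(torch.randn(4, 4))]

# Compile the candidate, run it against the reference op, and collect per-test feedback.
is_correct, feedback = backend.test_kernel_correctness(op, kernel_code, tests, attempt=1)

if is_correct:
    backend.add_kernel(op, kernel_code, "relu")
    out = backend[op](torch.randn(8))  # dispatch to the compiled kernel via __getitem__
else:
    print(feedback["summary"], feedback["test_errors"])
```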

0 commit comments
