Commit 7dd105f

LLM Backend

1 parent a3f8d0f commit 7dd105f
File tree

9 files changed: +686 -5 lines changed

.gitignore

Lines changed: 1 addition & 0 deletions

@@ -2,3 +2,4 @@ __pycache__/
 .claude/
 .vscode/
 .ruff_cache/
+generated_kernels/

BackendBench/backends.py

Lines changed: 231 additions & 0 deletions

@@ -1,3 +1,8 @@
+import os
+import importlib.util
+from typing import Dict, Callable, List
+
+
 class Backend:
     def __init__(self, name):
         self.name = name
@@ -278,3 +283,229 @@ def __getitem__(self, key):
 
     def __contains__(self, key):
         return key in self.ops
+
+
+class LLMBackend(Backend):
+    def __init__(self) -> None:
+        super().__init__("llm")
+        self.compiled_kernels: Dict[str, Callable] = {}
+
+        # Create generated_kernels directory
+        import datetime
+
+        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+        self.kernels_dir = f"generated_kernels/run_{timestamp}"
+        os.makedirs(self.kernels_dir, exist_ok=True)
+
+        # Create README for this run
+        readme_path = os.path.join(self.kernels_dir, "README.md")
+        with open(readme_path, "w") as f:
+            f.write(f"""# Generated Kernels - {timestamp}
+
+This directory contains PyTorch/Triton kernels generated by the LLM Backend.
+
+## Run Info
+- Timestamp: {timestamp}
+- Backend: LLM
+
+## Files
+Each `<op_name>_kernel.py` file contains the complete generated kernel code for that operation, including:
+- All necessary imports
+- Triton kernel implementation (if applicable)
+- Wrapper function that matches the PyTorch operation signature
+
+## Usage
+You can inspect these files to debug kernel generation, manually test implementations, or understand what the LLM produced.
+""")
+
+        print(f"Saving generated kernels to: {self.kernels_dir}")
+
+    def compile_kernel_from_string(
+        self, kernel_code: str, op_name: str, attempt: int = 1
+    ) -> Callable:
+        """Compile a kernel from string code and return a callable."""
+        try:
+            is_triton = "triton.jit" in kernel_code or "@triton.jit" in kernel_code
+
+            if is_triton:
+                full_code = self._prepare_triton_code(kernel_code)
+            else:
+                full_code = self._prepare_torch_code(kernel_code)
+
+            kernel_file = os.path.join(self.kernels_dir, f"{op_name}_kernel_attempt_{attempt}.py")
+            with open(kernel_file, "w") as f:
+                f.write(full_code)
+
+            print(f"Saved kernel to: {kernel_file}")
+
+            spec = importlib.util.spec_from_file_location(f"kernel_{op_name}", kernel_file)
+            module = importlib.util.module_from_spec(spec)
+            spec.loader.exec_module(module)
+
+            kernel_func = self._find_kernel_function(module, op_name)
+
+            return kernel_func
+
+        except Exception as e:
+            raise RuntimeError(f"Failed to compile kernel for {op_name}: {str(e)}")
+
+    def _prepare_triton_code(self, kernel_code: str) -> str:
+        """Prepare Triton kernel code with necessary imports."""
+        imports = """
+import torch
+import triton
+import triton.language as tl
+"""
+        if "import torch" not in kernel_code:
+            kernel_code = imports + kernel_code
+        return kernel_code
+
+    def _prepare_torch_code(self, kernel_code: str) -> str:
+        """Prepare regular PyTorch kernel code with necessary imports."""
+        imports = """
+import torch
+import torch.nn.functional as F
+"""
+        if "import torch" not in kernel_code:
+            kernel_code = imports + kernel_code
+        return kernel_code
+
+    def _find_kernel_function(self, module, op_name: str) -> Callable:
+        """Find the main kernel function in the compiled module."""
+        expected_name = f"{op_name}_kernel_impl"
+
+        if hasattr(module, expected_name):
+            return getattr(module, expected_name)
+
+        available_functions = [
+            name
+            for name in dir(module)
+            if callable(getattr(module, name)) and not name.startswith("_")
+        ]
+
+        raise ValueError(
+            f"Expected function '{expected_name}' not found in kernel code for {op_name}. "
+            f"Available functions: {available_functions}. "
+            f"Please ensure the LLM-generated code follows the naming convention: {op_name}_kernel_impl"
+        )
+
+    def add_kernel(self, op, kernel_code: str, op_name: str):
+        """Add a kernel implementation for a specific operator."""
+        compiled_kernel = self.compile_kernel_from_string(kernel_code, op_name, attempt=1)
+        self.compiled_kernels[op] = compiled_kernel
+
+    def test_kernel_correctness(
+        self, op, kernel_code: str, test_cases: List, attempt: int = 1
+    ) -> tuple[bool, Dict]:
+        """Test kernel correctness and return detailed feedback."""
+        op_str = str(op)
+        if "aten." in op_str:
+            op_name = op_str.split("aten.")[-1].split(".")[0]
+        else:
+            op_name = op_str.split(".")[-1]
+
+        feedback_info = {
+            "compilation_error": None,
+            "test_errors": [],
+            "summary": None,
+        }
+
+        try:
+            kernel_file = os.path.join(self.kernels_dir, f"{op_name}_kernel_attempt_{attempt}.py")
+
+            if not os.path.exists(kernel_file):
+                is_triton = "triton.jit" in kernel_code or "@triton.jit" in kernel_code
+                if is_triton:
+                    full_code = self._prepare_triton_code(kernel_code)
+                else:
+                    full_code = self._prepare_torch_code(kernel_code)
+
+                with open(kernel_file, "w") as f:
+                    f.write(full_code)
+                print(f"Saved kernel to: {kernel_file}")
+
+            import sys
+
+            spec = importlib.util.spec_from_file_location(
+                f"test_kernel_{op_name}_{attempt}", kernel_file
+            )
+            module = importlib.util.module_from_spec(spec)
+
+            # Add to sys.modules so triton can find it
+            sys.modules[f"test_kernel_{op_name}_{attempt}"] = module
+
+            try:
+                spec.loader.exec_module(module)
+
+                expected_name = f"{op_name}_kernel_impl"
+                if hasattr(module, expected_name):
+                    compiled_kernel = getattr(module, expected_name)
+                else:
+                    available_functions = [
+                        name
+                        for name in dir(module)
+                        if callable(getattr(module, name)) and not name.startswith("_")
+                    ]
+                    raise ValueError(
+                        f"Expected function '{expected_name}' not found. Available: {available_functions}"
+                    )
+
+            finally:
+                if f"test_kernel_{op_name}_{attempt}" in sys.modules:
+                    del sys.modules[f"test_kernel_{op_name}_{attempt}"]
+
+            import torch
+
+            correct_count = 0
+            total_count = 0
+
+            for test in test_cases:
+                try:
+                    args = test.args
+                    kwargs = test.kwargs
+
+                    ref_result = op(*args, **kwargs)
+                    kernel_result = compiled_kernel(*args, **kwargs)
+
+                    torch.testing.assert_close(ref_result, kernel_result, equal_nan=True)
+                    correct_count += 1
+                    print(f"  ✓ Test passed: {ref_result.shape} {ref_result.dtype}")
+
+                except Exception as e:
+                    import traceback
+
+                    print(f"  ✗ Test failed: {str(e)}")
+
+                    feedback_info["test_errors"].append(
+                        {
+                            "test_input": f"args={[arg.shape if hasattr(arg, 'shape') else arg for arg in args]}, kwargs={kwargs}",
+                            "error": str(e),
+                            "error_type": type(e).__name__,
+                            "traceback": traceback.format_exc(),
+                        }
+                    )
+
+                total_count += 1
+
+            is_correct = correct_count == total_count and total_count > 0
+            if not is_correct:
+                feedback_info["summary"] = f"{correct_count}/{total_count} tests passed"
+
+            return is_correct, feedback_info
+
+        except Exception as e:
+            print("  ✗ Compilation failed:")
+            print(f"    Error: {str(e)}")
+
+            feedback_info["compilation_error"] = str(e)
+            feedback_info["summary"] = "Compilation failed"
+            return False, feedback_info
+
+    def __getitem__(self, key):
+        if key in self.compiled_kernels:
+            return self.compiled_kernels[key]
+        raise KeyError(f"No kernel implementation found for {key}")
+
+    def __contains__(self, key):
+        return key in self.compiled_kernels
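
A minimal usage sketch of the new backend (illustrative, not part of the commit; the ReLU kernel string and the ToyTest container are hypothetical stand-ins for whatever test objects the harness supplies, which only need `.args`/`.kwargs` attributes as test_kernel_correctness expects):

# Hypothetical driver for LLMBackend.
import torch

from BackendBench.backends import LLMBackend

kernel_code = """
def relu_kernel_impl(x):
    return torch.clamp(x, min=0)
"""

class ToyTest:
    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs

backend = LLMBackend()
op = torch.ops.aten.relu.default

is_correct, feedback = backend.test_kernel_correctness(
    op, kernel_code, [ToyTest(torch.randn(4, 4))], attempt=1
)
if is_correct:
    backend.add_kernel(op, kernel_code, "relu")
    out = backend[op](torch.randn(4, 4))  # dispatch via __getitem__
else:
    print(feedback["summary"], feedback["test_errors"])

Note that the kernel string deliberately omits `import torch`: _prepare_torch_code detects that and prepends the imports before the file is written and loaded.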

BackendBench/eval.py

Lines changed: 2 additions & 1 deletion

@@ -1,8 +1,10 @@
 import logging
 
 import torch
+
 from triton.testing import do_bench
 
+
 logger = logging.getLogger(__name__)
 
 
@@ -66,7 +68,6 @@ def eval_performance(op, impl, tests):
     test_times = [cpu_bench(lambda: impl(*test.args, **test.kwargs)) for test in tests]
 
     speedups = torch.tensor(test_times) / torch.tensor(base_times)
-    # geometric mean of speedups
     return speedups.log().mean().exp()
 
 
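The returned value is the geometric mean of the per-test speedups, which the deleted inline comment used to note: `speedups.log().mean().exp()` computes exp(mean(log s_i)) = (s_1 * ... * s_n)^(1/n). A quick illustrative check (not part of the commit):

import torch

speedups = torch.tensor([2.0, 0.5, 4.0])
geomean = speedups.log().mean().exp()
expected = (2.0 * 0.5 * 4.0) ** (1 / 3)  # = 4 ** (1/3), about 1.587
assert torch.isclose(geomean, torch.tensor(expected))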

BackendBench/kernel_templates.py

Lines changed: 118 additions & 0 deletions

@@ -0,0 +1,118 @@
+"""
+Kernel code templates and prompt engineering for LLM-based kernel generation.
+"""
+
+from typing import Dict
+
+from .prompts import (
+    TRITON_KERNEL_PROMPT,
+    PYTORCH_KERNEL_PROMPT,
+    TRITON_OPTIMIZATIONS,
+    TRITON_EXAMPLE_TEMPLATES,
+)
+
+
+class KernelTemplate:
+    """Base class for kernel templates."""
+
+    def __init__(self, name: str, framework: str):
+        self.name = name
+        self.framework = framework
+
+    def create_prompt(self, op_name: str, op_signature: str, op_description: str) -> str:
+        """Create a prompt for kernel generation."""
+        raise NotImplementedError
+
+
+class TritonKernelTemplate(KernelTemplate):
+    """Template for Triton kernel generation."""
+
+    def __init__(self):
+        super().__init__("triton", "triton")
+
+    def create_prompt(self, op_name: str, op_signature: str, op_description: str) -> str:
+        """Create a specialized prompt for Triton kernel generation."""
+
+        # Get operation-specific optimizations
+        optimizations = self._get_optimizations(op_name)
+
+        # Get example template
+        example = self._get_example_template(op_name)
+
+        return TRITON_KERNEL_PROMPT.format(
+            op_name=op_name,
+            op_signature=op_signature,
+            op_description=op_description,
+            optimizations=optimizations,
+            example=example,
+        )
+
+    def _get_optimizations(self, op_name: str) -> str:
+        """Get operation-specific optimization guidelines."""
+        return TRITON_OPTIMIZATIONS.get(op_name, TRITON_OPTIMIZATIONS["default"])
+
+    def _get_example_template(self, op_name: str) -> str:
+        """Get operation-specific code template."""
+        return TRITON_EXAMPLE_TEMPLATES["default"]
+
+
+class PyTorchKernelTemplate(KernelTemplate):
+    """Template for pure PyTorch kernel generation."""
+
+    def __init__(self):
+        super().__init__("pytorch", "pytorch")
+
+    def create_prompt(self, op_name: str, op_signature: str, op_description: str) -> str:
+        """Create a prompt for PyTorch kernel generation."""
+
+        return PYTORCH_KERNEL_PROMPT.format(
+            op_name=op_name, op_signature=op_signature, op_description=op_description
+        )
+
+
+class KernelTemplateManager:
+    """Manages kernel templates for different frameworks."""
+
+    def __init__(self):
+        self.templates: Dict[str, KernelTemplate] = {
+            "triton": TritonKernelTemplate(),
+            "pytorch": PyTorchKernelTemplate(),
+            # TODO: Add cuda, cutile, whatever we want
+        }
+
+    def get_template(self, framework: str) -> KernelTemplate:
+        """Get template for specified framework."""
+        if framework not in self.templates:
+            raise ValueError(f"Unknown framework: {framework}")
+        return self.templates[framework]
+
+    def create_prompt(
+        self, op_name: str, op_signature: str, op_description: str, framework: str = "triton"
+    ) -> str:
+        """Create a prompt using the specified template."""
+        template = self.get_template(framework)
+        return template.create_prompt(op_name, op_signature, op_description)
+
+    def create_refinement_prompt(
+        self,
+        op_name: str,
+        op_signature: str,
+        op_description: str,
+        framework: str = "triton",
+        feedback: str = "",
+    ) -> str:
+        """Create a refinement prompt with feedback from previous attempts."""
+        base_prompt = self.create_prompt(op_name, op_signature, op_description, framework)
+
+        if feedback and feedback.strip():
+            refinement_prompt = f"""{feedback}
+
+{base_prompt}
+
+Fix the above errors and generate corrected code."""
+        else:
+            # Fallback if no feedback
+            refinement_prompt = f"""{base_prompt}
+
+The previous attempt failed. Please generate a corrected version."""
+
+        return refinement_prompt
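
A sketch of how the prompt layer might be exercised (illustrative, not part of the commit; it assumes BackendBench/prompts.py, imported at the top of this file, is available, and the op signature/description strings are made up):

# Hypothetical example driving KernelTemplateManager.
from BackendBench.kernel_templates import KernelTemplateManager

manager = KernelTemplateManager()

# First attempt: plain generation prompt for a Triton kernel.
prompt = manager.create_prompt(
    op_name="relu",
    op_signature="relu(Tensor self) -> Tensor",
    op_description="Element-wise max(x, 0)",
    framework="triton",
)

# Retry: fold compiler/test feedback from LLMBackend back into the prompt.
retry_prompt = manager.create_refinement_prompt(
    op_name="relu",
    op_signature="relu(Tensor self) -> Tensor",
    op_description="Element-wise max(x, 0)",
    framework="triton",
    feedback="Compilation failed: name 'tl' is not defined",
)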
