Skip to content

Commit ba6128d

Browse files
authored
Add KernelAgent backend integration (#22)
1 parent 4717e58 commit ba6128d

File tree

5 files changed

+518
-2
lines changed

5 files changed

+518
-2
lines changed

.gitmodules

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
[submodule "KernelAgent"]
	path = KernelAgent
	url = git@github.com:pytorch-labs/KernelAgent.git

BackendBench/backends.py

Lines changed: 311 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -524,3 +524,314 @@ def __getitem__(self, key):
524524

525525
def __contains__(self, key):
526526
return key in self.compiled_kernels
527+
528+
529+
class KernelAgentBackend(Backend):
    """
    Backend that uses KernelAgent for sophisticated parallel kernel generation.

    This backend leverages KernelAgent's advanced features:
    - Parallel workers with iterative refinement
    - Multi-turn conversation history
    - Comprehensive prompt engineering with Triton guidelines
    - Automatic test generation
    """

    def __init__(self) -> None:
        super().__init__("kernel_agent")
        # NOTE(review): keys are the op objects passed to add_kernel(), not
        # strings — annotation kept as-is for interface compatibility; verify
        # against Backend's other subclasses before tightening it.
        self.compiled_kernels: Dict[str, Callable] = {}

        # Timestamped run directory so successive runs never overwrite each other.
        import datetime

        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        self.kernels_dir = f"generated_kernels/kernel_agent_run_{timestamp}"
        os.makedirs(self.kernels_dir, exist_ok=True)

        self._write_run_readme(timestamp)

        print(f"Saving KernelAgent generated kernels to: {self.kernels_dir}")

        # Initialize KernelAgent lazily (see _get_kernel_agent) to avoid
        # import errors when the submodule is absent but this backend unused.
        self.kernel_agent = None
        self.num_workers = 4  # Default values, can be overridden via set_config
        self.max_rounds = 10

    def _write_run_readme(self, timestamp: str) -> None:
        """Write a README.md into the run directory describing its contents."""
        readme_path = os.path.join(self.kernels_dir, "README.md")
        with open(readme_path, "w") as f:
            f.write(f"""# Generated Kernels - KernelAgent - {timestamp}

This directory contains PyTorch/Triton kernels generated by the KernelAgent Backend.

## Run Info
- Timestamp: {timestamp}
- Backend: KernelAgent
- Features: Parallel workers, iterative refinement, conversation history

## Files
Each `<op_name>_kernel.py` file contains the complete generated kernel code for that operation.
KernelAgent session directories contain detailed logs, worker outputs, and generation artifacts.

## KernelAgent Features Used
- Parallel workers for increased success rate
- Iterative refinement with multi-turn dialogue
- Comprehensive Triton programming guidelines
- Automatic test generation and validation
- Session logging and artifact preservation

## Usage
You can inspect these files to debug kernel generation, analyze the parallel worker outputs,
or understand the sophisticated generation process used by KernelAgent.
""")

    def set_config(self, num_workers: int, max_rounds: int) -> None:
        """Set configuration for KernelAgent.

        Takes effect on the next _get_kernel_agent() call; has no effect on an
        already-initialized agent.
        """
        self.num_workers = num_workers
        self.max_rounds = max_rounds

    def _get_kernel_agent(self):
        """Lazily initialize and return the TritonKernelAgent instance.

        Raises:
            ImportError: if the KernelAgent submodule cannot be imported.
        """
        if self.kernel_agent is None:
            try:
                # Import KernelAgent from the git submodule checkout.
                import sys

                # Fix: normalize to an absolute path BEFORE the membership
                # test — the original tested the relative path but inserted
                # the absolute one, so the dedup check never matched.
                kernel_agent_path = os.path.abspath(
                    os.path.join(os.path.dirname(__file__), "..", "KernelAgent")
                )
                if kernel_agent_path not in sys.path:
                    sys.path.insert(0, kernel_agent_path)

                from triton_kernel_agent import TritonKernelAgent

                # Keep KernelAgent's own logs inside this run's directory.
                agent_log_dir = os.path.join(self.kernels_dir, "agent_logs")
                os.makedirs(agent_log_dir, exist_ok=True)

                self.kernel_agent = TritonKernelAgent(
                    log_dir=agent_log_dir, num_workers=self.num_workers, max_rounds=self.max_rounds
                )

                print(f"✓ KernelAgent initialized with log directory: {agent_log_dir}")

            except ImportError as e:
                # Chain the original error so the root cause stays in the traceback.
                raise ImportError(
                    f"Failed to import KernelAgent: {e}\n"
                    f"Please ensure KernelAgent submodule is properly initialized.\n"
                    f"Run: git submodule update --init --recursive"
                ) from e

        return self.kernel_agent

    def _create_problem_description_from_op(self, op, op_name: str) -> str:
        """
        Create a problem description for KernelAgent based on the PyTorch operation.

        Args:
            op: PyTorch operation
            op_name: Operation name extracted from op

        Returns:
            Problem description string for KernelAgent
        """
        # Create a comprehensive problem description that KernelAgent can understand
        problem_description = f"""
Implement a high-performance Triton kernel for the PyTorch operation: {op_name}

Operation details:
- PyTorch operation: {op}
- Operation name: {op_name}
- Framework target: OpenAI Triton

Requirements:
1. The kernel must be functionally equivalent to the PyTorch operation
2. Implement using Triton language primitives (tl.load, tl.store, etc.)
3. Handle all tensor shapes and data types that the original operation supports
4. Optimize for GPU performance with proper memory coalescing
5. Include proper boundary condition handling
6. Follow Triton best practices for kernel design

The generated kernel should:
- Take the same input arguments as the PyTorch operation
- Return outputs with identical shapes, dtypes, and numerical values
- Be optimized for common tensor shapes and memory layouts
- Handle edge cases gracefully

Please generate a complete, production-ready Triton kernel implementation.
"""
        return problem_description

    def _adapt_kernel_function_name(self, kernel_code: str, op_name: str) -> str:
        """
        Adapt KernelAgent's 'kernel_function' to BackendBench's expected naming convention.

        KernelAgent generates kernels with 'kernel_function' as the main entry point.
        BackendBench expects '{op_name}_kernel_impl' as the function name.

        Args:
            kernel_code: Original kernel code from KernelAgent
            op_name: Operation name for the expected function name

        Returns:
            Modified kernel code with correct function name
        """
        expected_name = f"{op_name}_kernel_impl"

        # Replace 'def kernel_function' with 'def {op_name}_kernel_impl'
        if "def kernel_function(" in kernel_code:
            adapted_code = kernel_code.replace("def kernel_function(", f"def {expected_name}(")

            # Also replace any docstring references
            adapted_code = adapted_code.replace(
                '"""Wrapper function that handles kernel launch."""',
                f'"""{op_name} kernel implementation using Triton."""',
            )

            return adapted_code
        else:
            # If kernel_function is not found, add a wrapper that calls the
            # existing function. NOTE(review): if the generated code defines
            # no 'kernel_function' at all, this wrapper raises NameError when
            # called — the compile-time function lookup will still succeed.
            wrapper_code = f'''

def {expected_name}(*args, **kwargs):
    """{op_name} kernel implementation using Triton - BackendBench adapter."""
    # Call the original kernel_function from KernelAgent
    return kernel_function(*args, **kwargs)
'''
            return kernel_code + wrapper_code

    def compile_kernel_from_string(
        self, kernel_code: str, op_name: str, attempt: int = 1
    ) -> Callable:
        """Compile a kernel from string code and return a callable.

        Args:
            kernel_code: Kernel source as emitted by KernelAgent.
            op_name: Operation name; determines the expected entry-point name.
            attempt: Retained for interface compatibility; currently unused.

        Raises:
            RuntimeError: if the code cannot be written, imported, or does not
                define the expected '{op_name}_kernel_impl' function.
        """
        try:
            # Adapt the function name for BackendBench compatibility
            adapted_code = self._adapt_kernel_function_name(kernel_code, op_name)

            # Prepare the code with necessary imports. A single substring test
            # suffices: "@triton.jit" always contains "triton.jit".
            if "triton.jit" in adapted_code:
                full_code = self._prepare_triton_code(adapted_code)
            else:
                full_code = self._prepare_torch_code(adapted_code)

            # Save the kernel to file
            kernel_file = os.path.join(self.kernels_dir, f"{op_name}_kernel.py")
            with open(kernel_file, "w") as f:
                f.write(full_code)

            print(f"Saved KernelAgent kernel to: {kernel_file}")

            # Import and compile the kernel
            spec = importlib.util.spec_from_file_location(f"kernel_agent_{op_name}", kernel_file)
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)

            # Find the expected function
            expected_name = f"{op_name}_kernel_impl"
            if hasattr(module, expected_name):
                return getattr(module, expected_name)
            else:
                available_functions = [
                    name
                    for name in dir(module)
                    if callable(getattr(module, name)) and not name.startswith("_")
                ]
                raise ValueError(
                    f"Expected function '{expected_name}' not found in KernelAgent kernel. "
                    f"Available: {available_functions}"
                )

        except Exception as e:
            # Chain so the inner compile/import error stays in the traceback.
            raise RuntimeError(f"Failed to compile KernelAgent kernel for {op_name}: {str(e)}") from e

    def _prepare_triton_code(self, kernel_code: str) -> str:
        """Prepare Triton kernel code with necessary imports."""
        imports = """
import torch
import triton
import triton.language as tl
"""
        if "import torch" not in kernel_code:
            kernel_code = imports + kernel_code
        return kernel_code

    def _prepare_torch_code(self, kernel_code: str) -> str:
        """Prepare regular PyTorch kernel code with necessary imports."""
        imports = """
import torch
import torch.nn.functional as F
"""
        if "import torch" not in kernel_code:
            kernel_code = imports + kernel_code
        return kernel_code

    def add_kernel(self, op, kernel_code: str, op_name: str) -> None:
        """Compile and register a kernel implementation for a specific operator."""
        compiled_kernel = self.compile_kernel_from_string(kernel_code, op_name, attempt=1)
        self.compiled_kernels[op] = compiled_kernel

        # Save the original KernelAgent code as well
        original_file = os.path.join(self.kernels_dir, f"{op_name}_original_kernel_agent.py")
        with open(original_file, "w") as f:
            f.write(kernel_code)

    def generate_kernel_with_agent(self, op, op_name: str) -> tuple[str, bool]:
        """
        Generate a kernel using KernelAgent's sophisticated generation system.

        Args:
            op: PyTorch operation
            op_name: Operation name

        Returns:
            tuple: (kernel_code, success) — ("", False) on any failure; this
            method deliberately never raises so one op cannot abort a run.
        """
        try:
            agent = self._get_kernel_agent()

            # Create problem description
            problem_description = self._create_problem_description_from_op(op, op_name)

            print(
                f"🚀 Generating {op_name} kernel with KernelAgent (parallel workers + refinement)"
            )

            # Generate kernel using KernelAgent
            result = agent.generate_kernel(
                problem_description=problem_description,
                test_code=None,  # Let KernelAgent auto-generate the test
            )

            if result["success"]:
                print(f"✅ KernelAgent succeeded for {op_name}!")
                print(
                    f"   Worker {result['worker_id']} found solution in {result['rounds']} rounds"
                )
                print(f"   Session: {result['session_dir']}")

                # Copy the session directory to our kernels directory for preservation
                import shutil

                session_name = os.path.basename(result["session_dir"])
                preserved_session = os.path.join(
                    self.kernels_dir, f"{op_name}_session_{session_name}"
                )
                try:
                    shutil.copytree(result["session_dir"], preserved_session)
                    print(f"   Session preserved: {preserved_session}")
                except Exception as e:
                    # Best-effort: preservation failure must not fail generation.
                    print(f"   Warning: Could not preserve session: {e}")

                return result["kernel_code"], True
            else:
                print(f"❌ KernelAgent failed for {op_name}: {result['message']}")
                return "", False

        except Exception as e:
            print(f"❌ KernelAgent error for {op_name}: {e}")
            return "", False

    def __getitem__(self, key):
        if key in self.compiled_kernels:
            return self.compiled_kernels[key]
        raise KeyError(f"No KernelAgent kernel implementation found for {key}")

    def __contains__(self, key):
        return key in self.compiled_kernels

KernelAgent

Submodule KernelAgent added at 2b26ae0

README.md

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,30 @@ Run LLM evaluation on smoke test (relu operation):
4646
export ANTHROPIC_API_KEY=your_api_key_here
4747
python scripts/main.py --suite smoke --backend llm
4848
```
## KernelAgent-Based Triton Kernel Generation

Generate and evaluate PyTorch kernels using KernelAgent's advanced system with parallel workers and iterative refinement:

**Prerequisites**: Initialize the KernelAgent submodule:
```bash
git submodule update --init --recursive
```

Run KernelAgent evaluation on smoke test (relu operation):
```bash
export OPENAI_API_KEY=your_api_key_here
python scripts/main.py --suite smoke --backend kernel_agent
```

Run KernelAgent with custom configuration:
```bash
export OPENAI_API_KEY=your_api_key_here
python scripts/main.py --suite smoke --backend kernel_agent --kernel-agent-workers 6 --kernel-agent-max-rounds 15
```

Run KernelAgent on opinfo tests with a specific operation:
```bash
export OPENAI_API_KEY=your_api_key_here
python scripts/main.py --suite opinfo --backend kernel_agent --ops "add"
```

0 commit comments

Comments
 (0)