Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
5b57a8c
feat: Integrate KernelAgent with BackendBench test cases
Laurawly Aug 22, 2025
ca194c8
fix: Change filter logic to use exact operation name matching
Laurawly Aug 22, 2025
8ba8bd3
feat: Add enhanced KernelAgent run scripts with result organization
Laurawly Aug 22, 2025
9084d45
style: Run ruff format on core_torchbench_ops.py
Laurawly Aug 22, 2025
58ed0a7
style: Run ruff format on kernel_agent.py and data_loaders.py
Laurawly Aug 22, 2025
3f5d5a6
chore: Add license header to core_torchbench_ops.py
Laurawly Aug 22, 2025
96f2727
refactor: Address PR reviews - use PR #90 directory structure
Laurawly Aug 23, 2025
d38491c
feat: Add score tracking to run_kernel_agent.py
Laurawly Aug 23, 2025
e7050c0
feat: Add FP16/BF16 filtering and Triton-friendly operation classific…
Laurawly Aug 23, 2025
ea45857
feat: Complete TorchBench operation categorization and KernelAgent in…
Laurawly Aug 24, 2025
5dafec7
Merge main branch: Add verbose mode and untestable operators
Laurawly Aug 24, 2025
061b57c
fix: Correct syntax error after merge
Laurawly Aug 24, 2025
a7fc8dc
feat: Use BackendBench serialization format for KernelAgent test gene…
Laurawly Aug 24, 2025
8af16e7
feat: Add KernelAgent-generated Triton kernels for 43 operations
Laurawly Sep 1, 2025
c400448
refactor: Remove README.md files from generated kernel folders
Laurawly Sep 1, 2025
bca3ec4
more ops added
Laurawly Sep 2, 2025
57a3e8a
Merge fix-kernelagent-tests: Resolve conflicts in hardswish and maxim…
Laurawly Sep 3, 2025
8cd28d8
Remove multi-version kernel implementations
Laurawly Sep 3, 2025
22783ad
feat: Add 26 new KernelAgent-generated Triton kernels
Laurawly Sep 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
269 changes: 204 additions & 65 deletions BackendBench/backends/kernel_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import datetime
import importlib.util
import logging
import os
from typing import Callable, Dict

from .base import Backend
from ..scripts.setup_operator_directories import clean_op_name_for_directory

logger = logging.getLogger(__name__)

Expand All @@ -29,44 +31,11 @@ def __init__(self) -> None:
super().__init__("kernel_agent")
self.compiled_kernels: Dict[str, Callable] = {}

# Create generated_kernels directory
import datetime

timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
self.kernels_dir = f"generated_kernels/kernel_agent_run_{timestamp}"
# Use PR #90 directory structure
self.kernels_dir = "generated_kernels"
os.makedirs(self.kernels_dir, exist_ok=True)

# Create README for this run
readme_path = os.path.join(self.kernels_dir, "README.md")
with open(readme_path, "w") as f:
f.write(
f"""# Generated Kernels - KernelAgent - {timestamp}

This directory contains PyTorch/Triton kernels generated by the KernelAgent Backend.

## Run Info
- Timestamp: {timestamp}
- Backend: KernelAgent
- Features: Parallel workers, iterative refinement, conversation history

## Files
Each `<op_name>_kernel.py` file contains the complete generated kernel code for that operation.
KernelAgent session directories contain detailed logs, worker outputs, and generation artifacts.

## KernelAgent Features Used
- Parallel workers for increased success rate
- Iterative refinement with multi-turn dialogue
- Comprehensive Triton programming guidelines
- Automatic test generation and validation
- Session logging and artifact preservation

## Usage
You can inspect these files to debug kernel generation, analyze the parallel worker outputs,
or understand the sophisticated generation process used by KernelAgent.
"""
)

print(f"Saving KernelAgent generated kernels to: {self.kernels_dir}")
logger.info(f"Saving KernelAgent generated kernels to: {self.kernels_dir}")

# Initialize KernelAgent (imported lazily to avoid dependency issues)
self.kernel_agent = None
Expand All @@ -85,7 +54,9 @@ def _get_kernel_agent(self):
# Import KernelAgent from the submodule
import sys

kernel_agent_path = os.path.join(os.path.dirname(__file__), "..", "KernelAgent")
kernel_agent_path = os.path.join(
os.path.dirname(__file__), "..", "..", "KernelAgent"
)
if kernel_agent_path not in sys.path:
sys.path.insert(0, os.path.abspath(kernel_agent_path))

Expand All @@ -101,7 +72,7 @@ def _get_kernel_agent(self):
max_rounds=self.max_rounds,
)

print(f"✓ KernelAgent initialized with log directory: {agent_log_dir}")
logger.info(f"✓ KernelAgent initialized with log directory: {agent_log_dir}")

except ImportError as e:
raise ImportError(
Expand Down Expand Up @@ -203,12 +174,45 @@ def compile_kernel_from_string(
else:
full_code = self._prepare_torch_code(adapted_code)

# Save the kernel to file
kernel_file = os.path.join(self.kernels_dir, f"{op_name}_kernel.py")
# Use PR #90 directory structure
clean_name = clean_op_name_for_directory(op_name)
op_dir = os.path.join(self.kernels_dir, clean_name)
os.makedirs(op_dir, exist_ok=True)

# Determine version number
existing_versions = [
f
for f in os.listdir(op_dir)
if f.startswith(f"{clean_name}_implementation_v") and f.endswith(".py")
]
version = len(existing_versions) + 1

# Save the kernel to file with proper naming
kernel_file = os.path.join(op_dir, f"{clean_name}_implementation_v{version}.py")
with open(kernel_file, "w") as f:
f.write(full_code)

print(f"Saved KernelAgent kernel to: {kernel_file}")
logger.debug(f"Saved KernelAgent kernel to: {kernel_file}")

# Create or update README for the operation
readme_path = os.path.join(op_dir, "README.md")
readme_content = f"""# {op_name}

Generated by KernelAgent

## Implementation

- `{clean_name}_implementation_v{version}.py` - Generated on {datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}

## Usage

This kernel can be used with the DirectoryBackend:
```bash
python BackendBench/scripts/main.py --suite torchbench --backend directory --ops {op_name}
```
"""
with open(readme_path, "w") as f:
f.write(readme_content)

# Import and compile the kernel
spec = importlib.util.spec_from_file_location(f"kernel_agent_{op_name}", kernel_file)
Expand Down Expand Up @@ -259,18 +263,158 @@ def add_kernel(self, op, kernel_code: str, op_name: str):
compiled_kernel = self.compile_kernel_from_string(kernel_code, op_name, attempt=1)
self.compiled_kernels[op] = compiled_kernel

# Save the original KernelAgent code as well
original_file = os.path.join(self.kernels_dir, f"{op_name}_original_kernel_agent.py")
with open(original_file, "w") as f:
f.write(kernel_code)
def _create_test_code_from_backendbench(self, op, op_name: str, test_cases) -> str:
"""
Convert BackendBench test cases to KernelAgent-compatible test code.

def generate_kernel_with_agent(self, op, op_name: str) -> tuple[str, bool]:
Args:
op: PyTorch operation
op_name: Operation name
test_cases: BackendBench test cases

Returns:
Test code string for KernelAgent, or None if no test cases
"""
test_list = list(test_cases) if test_cases else []
if not test_list:
return None

logger.debug(f" Using {len(test_list)} BackendBench test cases")

# Use a few representative test cases (not all, to avoid overwhelming the LLM)
max_tests = min(5, len(test_list))

# Import the serialization utility
from BackendBench.utils import serialize_args

test_code = f'''import torch
import torch.nn.functional as F
import re

def _deserialize_tensor(match):
"""Convert T([shape], dtype) to appropriate torch tensor creation"""
# Parse the T(...) format
content = match.group(1)
parts = [p.strip() for p in content.split(', ')]

# Extract shape (first part)
shape_str = parts[0]

# Extract dtype (second part)
dtype_str = parts[1]

# Handle stride if present (third part)
# For now, we ignore stride and create contiguous tensors

# Convert dtype abbreviations to torch dtypes
dtype_map = {{
'bf16': 'torch.bfloat16',
'f64': 'torch.float64',
'f32': 'torch.float32',
'f16': 'torch.float16',
'c32': 'torch.complex32',
'c64': 'torch.complex64',
'c128': 'torch.complex128',
'i8': 'torch.int8',
'i16': 'torch.int16',
'i32': 'torch.int32',
'i64': 'torch.int64',
'b8': 'torch.bool',
'u8': 'torch.uint8',
}}

torch_dtype = dtype_map.get(dtype_str, 'torch.float32')

# Choose appropriate tensor creation based on dtype
if dtype_str in ['b8']: # Boolean
return f"torch.randint(0, 2, {{shape_str}}, dtype={{torch_dtype}}, device='cuda').bool()"
elif dtype_str in ['i8', 'i16', 'i32', 'i64', 'u8']: # Integer types
return f"torch.randint(0, 10, {{shape_str}}, dtype={{torch_dtype}}, device='cuda')"
elif dtype_str in ['c32', 'c64', 'c128']: # Complex types
return f"torch.randn({{shape_str}}, dtype={{torch_dtype}}, device='cuda')"
else: # Float types
return f"torch.randn({{shape_str}}, dtype={{torch_dtype}}, device='cuda')"

def deserialize_test_args(serialized_str):
"""Convert serialized args string to actual args and kwargs"""
# Replace T(...) with torch.randn(...)
pattern = r'T\(([^)]+)\)'
deserialized = re.sub(pattern, _deserialize_tensor, serialized_str)

# The serialized format is: (args_tuple, kwargs_dict)
# Evaluate to get the tuple
full_data = eval(deserialized)

# Extract args and kwargs
if isinstance(full_data, tuple) and len(full_data) == 2:
args, kwargs = full_data
return list(args), kwargs
else:
# Handle case where there's only args
return list(full_data), {{}}

def test_kernel():
"""Test the {op_name} kernel using BackendBench test cases."""
from kernel import kernel_function

all_passed = True
failed_tests = []

'''

for i, test in enumerate(test_list[:max_tests]):
# Use BackendBench's serialization format
serialized_args = serialize_args(test.args, test.kwargs)

test_code += f" # Test case {i + 1} from BackendBench\n"
test_code += " try:\n"
test_code += " # Deserialize the test arguments\n"
test_code += f' serialized = """{serialized_args}"""\n'
test_code += " args, kwargs = deserialize_test_args(serialized)\n"

# Test execution
op_str = str(op).replace("OpOverload", "").replace("OpOverloadPacket", "")
test_code += f"""
# Get reference result from PyTorch
ref_result = torch.ops.{op_str}(*args, **kwargs)

# Get result from our kernel
kernel_result = kernel_function(*args, **kwargs)

# Compare results
torch.testing.assert_close(ref_result, kernel_result, rtol=1e-2, atol=1e-2)
print(f"Test case {i + 1} passed!")

except Exception as e:
print(f"Test case {i + 1} failed: {{e}}")
failed_tests.append({i + 1})
all_passed = False
"""

test_code += """
if all_passed:
print("All BackendBench tests passed!")
else:
print(f"Failed tests: {failed_tests}")

return all_passed

if __name__ == "__main__":
import sys
success = test_kernel()
sys.exit(0 if success else 1)
"""

return test_code

def generate_kernel_with_agent(self, op, op_name: str, test_cases=None) -> tuple[str, bool]:
"""
Generate a kernel using KernelAgent's sophisticated generation system.

Args:
op: PyTorch operation
op_name: Operation name
test_cases: Optional BackendBench test cases to use for validation

Returns:
tuple: (kernel_code, success)
Expand All @@ -281,43 +425,38 @@ def generate_kernel_with_agent(self, op, op_name: str) -> tuple[str, bool]:
# Create problem description
problem_description = self._create_problem_description_from_op(op, op_name)

print(
# Create test code from BackendBench tests if provided
test_code = None
if test_cases:
test_code = self._create_test_code_from_backendbench(op, op_name, test_cases)

logger.info(
f"🚀 Generating {op_name} kernel with KernelAgent (parallel workers + refinement)"
)

# Generate kernel using KernelAgent
result = agent.generate_kernel(
problem_description=problem_description,
test_code=None, # Let KernelAgent auto-generate the test
test_code=test_code, # Use provided tests or None (auto-generate)
)

if result["success"]:
print(f"✅ KernelAgent succeeded for {op_name}!")
print(
logger.info(f"✅ KernelAgent succeeded for {op_name}!")
logger.info(
f" Worker {result['worker_id']} found solution in {result['rounds']} rounds"
)
print(f" Session: {result['session_dir']}")
logger.info(f" Session: {result['session_dir']}")

# Copy the session directory to our kernels directory for preservation
import shutil

session_name = os.path.basename(result["session_dir"])
preserved_session = os.path.join(
self.kernels_dir, f"{op_name}_session_{session_name}"
)
try:
shutil.copytree(result["session_dir"], preserved_session)
print(f" Session preserved: {preserved_session}")
except Exception as e:
print(f" Warning: Could not preserve session: {e}")
# Log session directory for reference
logger.debug(f" Session directory: {result['session_dir']}")

return result["kernel_code"], True
else:
print(f"❌ KernelAgent failed for {op_name}: {result['message']}")
logger.error(f"❌ KernelAgent failed for {op_name}: {result['message']}")
return "", False

except Exception as e:
print(f"❌ KernelAgent error for {op_name}: {e}")
logger.error(f"❌ KernelAgent error for {op_name}: {e}")
return "", False

def __getitem__(self, key):
Expand Down
Loading