Changes from 6 commits
Commits
19 commits
5b57a8c
feat: Integrate KernelAgent with BackendBench test cases
Laurawly Aug 22, 2025
ca194c8
fix: Change filter logic to use exact operation name matching
Laurawly Aug 22, 2025
8ba8bd3
feat: Add enhanced KernelAgent run scripts with result organization
Laurawly Aug 22, 2025
9084d45
style: Run ruff format on core_torchbench_ops.py
Laurawly Aug 22, 2025
58ed0a7
style: Run ruff format on kernel_agent.py and data_loaders.py
Laurawly Aug 22, 2025
3f5d5a6
chore: Add license header to core_torchbench_ops.py
Laurawly Aug 22, 2025
96f2727
refactor: Address PR reviews - use PR #90 directory structure
Laurawly Aug 23, 2025
d38491c
feat: Add score tracking to run_kernel_agent.py
Laurawly Aug 23, 2025
e7050c0
feat: Add FP16/BF16 filtering and Triton-friendly operation classific…
Laurawly Aug 23, 2025
ea45857
feat: Complete TorchBench operation categorization and KernelAgent in…
Laurawly Aug 24, 2025
5dafec7
Merge main branch: Add verbose mode and untestable operators
Laurawly Aug 24, 2025
061b57c
fix: Correct syntax error after merge
Laurawly Aug 24, 2025
a7fc8dc
feat: Use BackendBench serialization format for KernelAgent test gene…
Laurawly Aug 24, 2025
8af16e7
feat: Add KernelAgent-generated Triton kernels for 43 operations
Laurawly Sep 1, 2025
c400448
refactor: Remove README.md files from generated kernel folders
Laurawly Sep 1, 2025
bca3ec4
more ops added
Laurawly Sep 2, 2025
57a3e8a
Merge fix-kernelagent-tests: Resolve conflicts in hardswish and maxim…
Laurawly Sep 3, 2025
8cd28d8
Remove multi-version kernel implementations
Laurawly Sep 3, 2025
22783ad
feat: Add 26 new KernelAgent-generated Triton kernels
Laurawly Sep 3, 2025
102 changes: 99 additions & 3 deletions BackendBench/backends/kernel_agent.py
@@ -85,7 +85,9 @@ def _get_kernel_agent(self):
        # Import KernelAgent from the submodule
        import sys

        kernel_agent_path = os.path.join(os.path.dirname(__file__), "..", "KernelAgent")
        kernel_agent_path = os.path.join(
            os.path.dirname(__file__), "..", "..", "KernelAgent"
        )
        if kernel_agent_path not in sys.path:
            sys.path.insert(0, os.path.abspath(kernel_agent_path))
@@ -264,13 +266,102 @@ def add_kernel(self, op, kernel_code: str, op_name: str):
        with open(original_file, "w") as f:
            f.write(kernel_code)

    def generate_kernel_with_agent(self, op, op_name: str) -> tuple[str, bool]:
    def _create_test_code_from_backendbench(self, op, op_name: str, test_cases) -> str:
        """
        Convert BackendBench test cases to KernelAgent-compatible test code.

        Args:
            op: PyTorch operation
            op_name: Operation name
            test_cases: BackendBench test cases

        Returns:
            Test code string for KernelAgent, or None if no test cases
        """
        test_list = list(test_cases) if test_cases else []
        if not test_list:
            return None

        print(f" Using {len(test_list)} BackendBench test cases")

        # Use a few representative test cases (not all, to avoid overwhelming the LLM)
        max_tests = min(5, len(test_list))

        test_code = f'''import torch
import torch.nn.functional as F

def test_kernel():
    """Test the {op_name} kernel using BackendBench test cases."""
    from kernel import kernel_function

    all_passed = True
    failed_tests = []

'''

        for i, test in enumerate(test_list[:max_tests]):
            test_code += f"    # Test case {i + 1} from BackendBench\n"
            test_code += "    try:\n"

            # Build args
            test_code += "        args = [\n"
            for arg in test.args:
                if hasattr(arg, "shape") and hasattr(arg, "dtype") and hasattr(arg, "device"):
                    # Recreate tensor with same properties
                    test_code += f"            torch.randn({list(arg.shape)}, dtype={arg.dtype}, device='{arg.device}'),\n"
                else:
                    test_code += f"            {repr(arg)},\n"
            test_code += "        ]\n"

            # Add kwargs
            if test.kwargs:
                test_code += f"        kwargs = {repr(test.kwargs)}\n"
            else:
                test_code += "        kwargs = {}\n"

            # Test execution
            op_str = str(op).replace("OpOverload", "").replace("OpOverloadPacket", "")
            test_code += f"""
        # Get reference result from PyTorch
        ref_result = torch.ops.{op_str}(*args, **kwargs)

        # Get result from our kernel
        kernel_result = kernel_function(*args, **kwargs)

        # Compare results
        torch.testing.assert_close(ref_result, kernel_result, rtol=1e-2, atol=1e-2)
        print(f"Test case {i + 1} passed!")

    except Exception as e:
        print(f"Test case {i + 1} failed: {{e}}")
        failed_tests.append({i + 1})
        all_passed = False
"""

        test_code += """
    if all_passed:
        print("All BackendBench tests passed!")
    else:
        print(f"Failed tests: {failed_tests}")

    return all_passed

if __name__ == "__main__":
    import sys
    success = test_kernel()
    sys.exit(0 if success else 1)
"""

        return test_code

    def generate_kernel_with_agent(self, op, op_name: str, test_cases=None) -> tuple[str, bool]:
        """
        Generate a kernel using KernelAgent's sophisticated generation system.

        Args:
            op: PyTorch operation
            op_name: Operation name
            test_cases: Optional BackendBench test cases to use for validation

        Returns:
            tuple: (kernel_code, success)
@@ -281,14 +372,19 @@ def generate_kernel_with_agent(self, op, op_name: str) -> tuple[str, bool]:
            # Create problem description
            problem_description = self._create_problem_description_from_op(op, op_name)

            # Create test code from BackendBench tests if provided
            test_code = None
            if test_cases:
                test_code = self._create_test_code_from_backendbench(op, op_name, test_cases)

            print(
                f"🚀 Generating {op_name} kernel with KernelAgent (parallel workers + refinement)"
            )

            # Generate kernel using KernelAgent
            result = agent.generate_kernel(
                problem_description=problem_description,
                test_code=None,  # Let KernelAgent auto-generate the test
                test_code=test_code,  # Use provided tests or None (auto-generate)
            )

            if result["success"]:
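For orientation, the sketch below shows roughly the kind of test file this template produces, assuming a hypothetical single-tensor op such as relu with one float32 CUDA input and a single BackendBench test case; the shape, dtype, and the aten.relu.default op string are illustrative assumptions rather than captured output.

# Hypothetical sketch of a generated KernelAgent test file for a single-tensor
# op (relu used as an example); values below are assumptions, not real output.
import torch
import torch.nn.functional as F

def test_kernel():
    """Test the relu kernel using BackendBench test cases."""
    from kernel import kernel_function  # KernelAgent's generated kernel module

    all_passed = True
    failed_tests = []

    # Test case 1 from BackendBench
    try:
        args = [
            torch.randn([32, 128], dtype=torch.float32, device='cuda:0'),
        ]
        kwargs = {}

        # Get reference result from PyTorch
        ref_result = torch.ops.aten.relu.default(*args, **kwargs)

        # Get result from our kernel
        kernel_result = kernel_function(*args, **kwargs)

        # Compare results
        torch.testing.assert_close(ref_result, kernel_result, rtol=1e-2, atol=1e-2)
        print("Test case 1 passed!")

    except Exception as e:
        print(f"Test case 1 failed: {e}")
        failed_tests.append(1)
        all_passed = False

    if all_passed:
        print("All BackendBench tests passed!")
    else:
        print(f"Failed tests: {failed_tests}")

    return all_passed

if __name__ == "__main__":
    import sys
    success = test_kernel()
    sys.exit(0 if success else 1)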
13 changes: 11 additions & 2 deletions BackendBench/data_loaders.py
@@ -20,6 +20,7 @@
import requests
import torch
from BackendBench.utils import cleanup_memory_and_gpu, deserialize_args
from BackendBench.scripts.pytorch_operators import extract_operator_name
from tqdm import tqdm


@@ -63,7 +64,7 @@ def _parse_trace_file(filename: str, filter: Optional[List[str]] = None) -> List
            args_str = m.group(1)
            cnt = int(m.group(0).split(",")[0].split(":")[1])

            if filter is None or any(f in op for f in filter):
            if filter is None or extract_operator_name(op) in filter:
                args, kwargs = deserialize_args(args_str)
                size = _args_size(args) + _args_size(list(kwargs.values()))
                size = size / (1024 * 1024)  # Convert to MB
@@ -212,7 +213,15 @@ def _load_from_parquet(

    # Apply filter if provided
    if filter:
        mask = df["op_name"].apply(lambda op: any(f in op for f in filter))
        # Import the function to extract operation names
        from BackendBench.scripts.pytorch_operators import extract_operator_name

        # Extract operation names and do exact matching
        def matches_filter(op_full_name):
            op_name = extract_operator_name(op_full_name)
            return op_name in filter

        mask = df["op_name"].apply(matches_filter)
        df = df[mask]

    return df.to_dict("records")
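The sketch below illustrates the intent of the exact-matching change, assuming trace-style names such as "aten.add.Tensor" and that extract_operator_name reduces them to a base name like "add"; both the sample names and that behavior are assumptions for illustration.

# Minimal sketch of the filter change; sample names and the assumed behavior of
# extract_operator_name are illustrative, not taken from a real trace.
from BackendBench.scripts.pytorch_operators import extract_operator_name

op_filter = ["add"]  # the diff's parameter is named `filter`
trace_ops = ["aten.add.Tensor", "aten.addmm.default"]  # hypothetical trace entries

# Old behavior: substring matching lets "addmm" slip through an "add" filter.
old_matches = [op for op in trace_ops if any(f in op for f in op_filter)]

# New behavior: exact matching on the extracted base name keeps only "add".
new_matches = [op for op in trace_ops if extract_operator_name(op) in op_filter]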
4 changes: 3 additions & 1 deletion BackendBench/scripts/main.py
@@ -517,7 +517,9 @@ def setup_kernel_agent_backend(kernel_agent_backend, suite, num_workers=4, max_r
        print(f" Using {num_workers} parallel workers with up to {max_rounds} rounds each")

        # Generate kernel using KernelAgent's sophisticated system
        kernel_code, success = kernel_agent_backend.generate_kernel_with_agent(op, op_name)
        kernel_code, success = kernel_agent_backend.generate_kernel_with_agent(
            op, op_name, test_cases=op_test.correctness_tests
        )

        if success:
            try:
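A condensed sketch of the per-operator flow this call site implies, assuming the surrounding code iterates over the suite's op tests and registers successful kernels via the backend's add_kernel method (shown in the kernel_agent.py diff above); the loop structure and attribute names other than correctness_tests are assumptions, not the literal body of setup_kernel_agent_backend.

# Hypothetical per-operator flow around the updated call; loop structure and
# attribute names other than correctness_tests are assumptions.
for op_test in suite:
    op = op_test.op
    op_name = str(op)

    kernel_code, success = kernel_agent_backend.generate_kernel_with_agent(
        op, op_name, test_cases=op_test.correctness_tests
    )
    if success:
        # Register the generated Triton kernel with the backend.
        kernel_agent_backend.add_kernel(op, kernel_code, op_name)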
100 changes: 100 additions & 0 deletions core_torchbench_ops.py
@@ -0,0 +1,100 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

"""
The 77 core PyTorch operators that appear in TorchBench traces.
These are the high-priority operations for KernelAgent's first release.
"""

CORE_TORCHBENCH_OPS = [
"abs",
"_adaptive_avg_pool2d",
"_adaptive_avg_pool2d_backward",
"add",
"addmm",
"any",
"avg_pool2d",
"avg_pool2d_backward",
"bitwise_and",
"bitwise_not",
"bitwise_xor",
"bmm",
"cat",
"clamp",
"clone",
"col2im",
"constant_pad_nd",
"convolution",
"convolution_backward",
"cos",
"cumsum",
"div",
"elu",
"eq",
"erf",
"exp",
"flip",
"floor",
"fmod",
"ge",
"gelu",
"grid_sampler_2d",
"gt",
"hardtanh",
"isinf",
"isnan",
"le",
"leaky_relu",
"log2",
"_log_softmax",
"lt",
"max",
"maximum",
"max_pool2d_with_indices",
"max_pool2d_with_indices_backward",
"mean",
"min",
"minimum",
"mm",
"mul",
"native_group_norm",
"native_group_norm_backward",
"native_layer_norm",
"ne",
"neg",
"nonzero",
"pow",
"reciprocal",
"reflection_pad2d",
"relu",
"remainder",
"repeat",
"round",
"rsqrt",
"sigmoid",
"sin",
"_softmax",
"split_with_sizes",
"sqrt",
"sub",
"sum",
"tanh",
"_to_copy",
"topk",
"upsample_bilinear2d",
"upsample_nearest2d",
"where",
]

# Some of these ops might have variants or different names in the actual op registry
# This mapping helps handle common variations
OP_NAME_VARIATIONS = {
"_adaptive_avg_pool2d": ["adaptive_avg_pool2d"],
"_adaptive_avg_pool2d_backward": ["adaptive_avg_pool2d_backward"],
"_log_softmax": ["log_softmax"],
"_softmax": ["softmax"],
"_to_copy": ["to_copy", "to"],
}
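
The comment above OP_NAME_VARIATIONS says the mapping helps handle naming variations; the helper below is a hypothetical sketch of one way it could be applied and is not part of this PR.

# Hypothetical helper illustrating the mapping's intent; not part of this PR.
def resolve_core_op(op_name: str) -> str | None:
    """Map a registry-style name onto the canonical core-op name, if any."""
    if op_name in CORE_TORCHBENCH_OPS:
        return op_name
    for canonical, variants in OP_NAME_VARIATIONS.items():
        if op_name in variants:
            return canonical
    return None

# e.g. resolve_core_op("softmax") -> "_softmax"; resolve_core_op("foo") -> None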