Skip to content

Commit 9d53e89

Browse files
authored
Directory Backend (#30)
1 parent 87d73b4 commit 9d53e89

File tree

7 files changed

+421
-2
lines changed

7 files changed

+421
-2
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,6 @@ __pycache__/
33
.vscode/
44
.ruff_cache/
55
generated_kernels/
6+
CLAUDE.md
67
venv/
7-
CLAUDE.md
8+
ops/

BackendBench/backends.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,98 @@
11
import os
22
import importlib.util
3+
import logging
34
from typing import Dict, Callable, List
45

6+
logger = logging.getLogger(__name__)
7+
58

69
class Backend:
    """Base class for kernel backends, identified by a short name."""

    def __init__(self, name):
        # Backend identifier, e.g. "aten", "directory", "llm".
        self.name = name
912

1013

14+
class DirectoryBackend(Backend):
    """Backend that loads kernel implementations from a directory tree.

    Expected layout: ``<ops_dir>/<op_name>/<impl>.py`` where each file
    defines a callable named ``<op_name>_kernel_impl``.  Loaded kernels are
    keyed by the resolved ``torch.ops.aten`` overload; lookups for ops with
    no implementation fall back to the original op (see ``__getitem__``).
    """

    def __init__(self, ops_dir="generated_kernels"):
        super().__init__("directory")
        self.ops_dir = ops_dir
        self.compiled_kernels: Dict[str, Callable] = {}
        self._load_kernels()

    def _load_kernels(self):
        """Scan ``self.ops_dir`` and register one kernel per op subdirectory."""
        if not os.path.exists(self.ops_dir):
            logger.warning(f"ops directory {self.ops_dir} does not exist")
            return

        loaded_count = 0
        # Sort directory listings: os.listdir order is unspecified, and
        # "use the first implementation file" must be deterministic.
        for op_name in sorted(os.listdir(self.ops_dir)):
            op_dir = os.path.join(self.ops_dir, op_name)
            if not os.path.isdir(op_dir):
                continue

            impl_files = sorted(f for f in os.listdir(op_dir) if f.endswith(".py"))
            if not impl_files:
                logger.warning(f"No Python files found in {op_dir}")
                continue

            # Use the first implementation file
            impl_file = impl_files[0]
            impl_path = os.path.join(op_dir, impl_file)

            try:
                # Load the implementation and map to PyTorch operation
                kernel_func = self._load_kernel_from_file(impl_path, op_name)
                pytorch_op = self._find_pytorch_op(op_name)
                if pytorch_op:
                    self.compiled_kernels[pytorch_op] = kernel_func
                    logger.info(f"Loaded {op_name} from {impl_file}")
                    loaded_count += 1
                else:
                    logger.warning(f"Could not map {op_name} to PyTorch operation")

            except Exception as e:
                # Best-effort loading: a broken kernel file must not prevent
                # the remaining kernels from loading.
                logger.error(f"Error loading {op_name} from {impl_file}: {e}")

        logger.info(f"DirectoryBackend loaded {loaded_count} kernels from {self.ops_dir}/")

    def _load_kernel_from_file(self, file_path: str, op_name: str) -> Callable:
        """Import ``file_path`` as a module and return its ``<op_name>_kernel_impl``.

        Raises:
            ValueError: if the module does not define the expected function.
        """
        spec = importlib.util.spec_from_file_location(f"op_{op_name}", file_path)
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)

        kernel_func_name = f"{op_name}_kernel_impl"
        if hasattr(module, kernel_func_name):
            return getattr(module, kernel_func_name)
        raise ValueError(f"No function named {kernel_func_name} found in {file_path}")

    def _find_pytorch_op(self, op_name: str):
        """Map operation name to PyTorch operation, or None if unresolvable."""
        import torch

        # Try common overload patterns: .default covers most unary ops
        # (relu, abs, sum), .Tensor covers binary ops (add, mul, div).
        for overload in ("default", "Tensor"):
            try:
                return getattr(getattr(torch.ops.aten, op_name), overload)
            except AttributeError:
                continue

        # Not 100% sure this is right, will need to iterate over all ops
        return None

    def __getitem__(self, key):
        if key in self.compiled_kernels:
            return self.compiled_kernels[key]
        # Fallback to original operation if not implemented
        return key

    def __contains__(self, key):
        # Always claim to contain ops: __getitem__ falls back to the
        # original op, so every lookup is serviceable.  (The original
        # `key in self.compiled_kernels or True` was always True anyway.)
        return True
94+
95+
1196
class AtenBackend(Backend):
1297
def __init__(self) -> None:
1398
super().__init__("aten")

README.md

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,97 @@ Run KernelAgent on opinfo tests with a specific operation:
7373
export OPENAI_API_KEY=your_api_key_here
7474
python scripts/main.py --suite opinfo --backend kernel_agent --ops "add"
7575
```
76+
77+
## Directory-Based Kernel Development
78+
79+
BackendBench supports a simple directory structure for manually adding kernel implementations. This is perfect for researchers who want to contribute optimized kernels without dealing with complex generation systems.
80+
81+
### Directory Structure
82+
83+
Create kernels in the following structure:
84+
```
85+
generated_kernels/
86+
├── relu/
87+
│ └── relu_implementation_1.py
88+
├── add/
89+
│ └── add_implementation_1.py
90+
├── mul/
91+
│ └── mul_implementation_1.py
92+
└── ...
93+
```
94+
95+
### How to Add Your Kernels
96+
97+
1. **Create the operation directory:**
98+
```bash
99+
mkdir generated_kernels/{op_name}
100+
```
101+
102+
2. **Create your implementation file:**
103+
```bash
104+
# Example: generated_kernels/relu/relu_implementation_1.py
105+
```
106+
107+
3. **Write your kernel following this template:**
108+
```python
109+
import torch
110+
111+
def {op_name}_kernel_impl(*args, **kwargs):
112+
"""
113+
Your kernel implementation.
114+
Must match the PyTorch operation signature exactly.
115+
"""
116+
# Your implementation here
117+
return result
118+
119+
# Optional: Add a test
120+
if __name__ == "__main__":
121+
pass
122+
```
123+
124+
### Operation Name Mapping
125+
126+
Use these exact directory names for common operations:
127+
- `relu` → `torch.ops.aten.relu.default`
- `add` → `torch.ops.aten.add.Tensor`
- `mul` → `torch.ops.aten.mul.Tensor`
- `div` → `torch.ops.aten.div.Tensor`
131+
132+
To find the correct name for other operations:
133+
```python
134+
# Find operation name
135+
import torch
136+
op = torch.ops.aten.some_op.some_variant
137+
print(str(op).split('aten.')[-1].split('.')[0]) # Use this as directory name
138+
```
139+
140+
### Example Implementation
141+
142+
Here's a complete example for ReLU:
143+
144+
```python
145+
# generated_kernels/relu/relu_implementation_1.py
146+
import torch
147+
148+
def relu_kernel_impl(input_tensor):
149+
return torch.maximum(input_tensor, torch.zeros_like(input_tensor))
150+
151+
if __name__ == "__main__":
152+
# Test on CPU
153+
x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
154+
result = relu_kernel_impl(x)
155+
expected = torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0])
156+
print(f"Test passed: {torch.allclose(result, expected)}")
157+
```
158+
159+
### Testing Your Kernels
160+
161+
Test individual implementations:
162+
```bash
163+
python generated_kernels/relu/relu_implementation_1.py
164+
```
165+
166+
Test with BackendBench:
167+
```bash
168+
python scripts/main.py --suite smoke --backend directory
169+
```

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ click
33
numpy
44
expecttest
55
anthropic>=0.34.0
6+
pytest

scripts/create_simple_test_ops.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Create simple kernel implementations for 5 common operations.
4+
Each just calls the original PyTorch function.
5+
"""
6+
7+
import os
8+
import logging
9+
10+
logger = logging.getLogger(__name__)
11+
12+
13+
def create_relu():
    """Write a trivial ReLU kernel implementation under generated_kernels/relu/."""
    os.makedirs("generated_kernels/relu", exist_ok=True)
    impl_path = "generated_kernels/relu/relu_implementation_1.py"
    with open(impl_path, "w") as fh:
        fh.write('''import torch

def relu_kernel_impl(input):
    """Simple ReLU implementation."""
    return torch.ops.aten.relu.default(input)

if __name__ == "__main__":
    x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
    result = relu_kernel_impl(x)
    expected = torch.tensor([0.0, 0.0, 0.0, 1.0, 2.0])
    print(f"ReLU test passed: {torch.allclose(result, expected)}")
''')
    logger.info("Created relu implementation")
29+
30+
31+
def create_add():
    """Write a trivial elementwise-add kernel under generated_kernels/add/."""
    os.makedirs("generated_kernels/add", exist_ok=True)
    impl_path = "generated_kernels/add/add_implementation_1.py"
    with open(impl_path, "w") as fh:
        fh.write('''import torch

def add_kernel_impl(input, other):
    """Simple add implementation."""
    return torch.ops.aten.add.Tensor(input, other)

if __name__ == "__main__":
    a = torch.tensor([1.0, 2.0, 3.0])
    b = torch.tensor([4.0, 5.0, 6.0])
    result = add_kernel_impl(a, b)
    expected = torch.tensor([5.0, 7.0, 9.0])
    print(f"Add test passed: {torch.allclose(result, expected)}")
''')
    logger.info("Created add implementation")
48+
49+
50+
def create_mul():
    """Write a trivial elementwise-mul kernel under generated_kernels/mul/."""
    os.makedirs("generated_kernels/mul", exist_ok=True)
    impl_path = "generated_kernels/mul/mul_implementation_1.py"
    with open(impl_path, "w") as fh:
        fh.write('''import torch

def mul_kernel_impl(input, other):
    """Simple mul implementation."""
    return torch.ops.aten.mul.Tensor(input, other)

if __name__ == "__main__":
    a = torch.tensor([1.0, 2.0, 3.0])
    b = torch.tensor([4.0, 5.0, 6.0])
    result = mul_kernel_impl(a, b)
    expected = torch.tensor([4.0, 10.0, 18.0])
    print(f"Mul test passed: {torch.allclose(result, expected)}")
''')
    logger.info("Created mul implementation")
67+
68+
69+
def create_abs():
    """Write a trivial abs kernel under generated_kernels/abs/."""
    os.makedirs("generated_kernels/abs", exist_ok=True)
    impl_path = "generated_kernels/abs/abs_implementation_1.py"
    with open(impl_path, "w") as fh:
        fh.write('''import torch

def abs_kernel_impl(input):
    """Simple abs implementation."""
    return torch.ops.aten.abs.default(input)

if __name__ == "__main__":
    x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])
    result = abs_kernel_impl(x)
    expected = torch.tensor([2.0, 1.0, 0.0, 1.0, 2.0])
    print(f"Abs test passed: {torch.allclose(result, expected)}")
''')
    logger.info("Created abs implementation")
85+
86+
87+
def create_sum():
    """Write a trivial sum-reduction kernel under generated_kernels/sum/."""
    os.makedirs("generated_kernels/sum", exist_ok=True)
    impl_path = "generated_kernels/sum/sum_implementation_1.py"
    with open(impl_path, "w") as fh:
        fh.write('''import torch

def sum_kernel_impl(input, *args, **kwargs):
    """Simple sum implementation."""
    return torch.ops.aten.sum.default(input, *args, **kwargs)

if __name__ == "__main__":
    x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
    result = sum_kernel_impl(x)
    expected = torch.tensor(10.0)
    print(f"Sum test passed: {torch.allclose(result, expected)}")
''')
    logger.info("Created sum implementation")
103+
104+
105+
def main():
    """Create 5 simple test operations."""
    logging.basicConfig(level=logging.INFO, format="%(message)s")
    logger.info("Creating simple test implementations...")

    # Build each sample op directory in turn.
    for build in (create_relu, create_add, create_mul, create_abs, create_sum):
        build()

    logger.info("Created 5 simple kernel implementations in generated_kernels/")
    logger.info("Test them individually:")
    logger.info(" python generated_kernels/relu/relu_implementation_1.py")
    logger.info(" python generated_kernels/add/add_implementation_1.py")
    logger.info(" etc.")
    logger.info("Or test all with the backend:")
    logger.info(" python test/test_simple_directory_backend.py")
123+
124+
125+
# Script entry point: build all sample kernel directories.
if __name__ == "__main__":
    main()

scripts/main.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def setup_logging(log_level):
4444
@click.option(
4545
"--backend",
4646
default="aten",
47-
type=click.Choice(["aten", "flag_gems", "llm", "kernel_agent"]),
47+
type=click.Choice(["aten", "flag_gems", "llm", "kernel_agent", "directory"]),
4848
help="Which backend to run",
4949
)
5050
@click.option(
@@ -96,6 +96,7 @@ def cli(
9696
"flag_gems": backends.FlagGemsBackend,
9797
"llm": backends.LLMBackend,
9898
"kernel_agent": backends.KernelAgentBackend,
99+
"directory": backends.DirectoryBackend,
99100
}[backend]()
100101

101102
# For LLM backend, we need to generate kernels first

0 commit comments

Comments
 (0)