Skip to content

Commit 6dfe2be

Browse files
authored
Refactor backends to one file per backend (#47)
1 parent 4620487 commit 6dfe2be

File tree

8 files changed

+979
-926
lines changed

8 files changed

+979
-926
lines changed

BackendBench/backends.py

Lines changed: 0 additions & 926 deletions
This file was deleted.

BackendBench/backends/__init__.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
"""
2+
BackendBench backends submodule.
3+
4+
This module provides various backend implementations for PyTorch operations.
5+
Each backend implements a different strategy for mapping PyTorch operations
6+
to alternative implementations.
7+
"""
8+
9+
from .aten import AtenBackend
10+
from .base import Backend
11+
from .directory import DirectoryBackend
12+
from .flag_gems import FlagGemsBackend
13+
from .kernel_agent import KernelAgentBackend
14+
from .llm import LLMBackend
15+
16+
__all__ = [
17+
"Backend",
18+
"DirectoryBackend",
19+
"AtenBackend",
20+
"FlagGemsBackend",
21+
"LLMBackend",
22+
"KernelAgentBackend",
23+
]

BackendBench/backends/aten.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from .base import Backend
2+
3+
4+
class AtenBackend(Backend):
    """Identity backend: every operator maps back to itself.

    Serves as the pass-through baseline — lookups return the queried op
    unchanged and membership tests always succeed, so callers run the
    original ATen implementation for every operator.
    """

    def __init__(self) -> None:
        super().__init__("aten")

    def __getitem__(self, key):
        # The "implementation" of an op is the original op itself.
        return key

    def __contains__(self, key) -> bool:
        # Claims every op so callers never need a separate fallback path.
        return True

BackendBench/backends/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
class Backend:
    """Minimal base class for all BackendBench backends.

    Concrete subclasses are expected to behave like read-only mappings from
    a PyTorch operator to its replacement implementation by providing
    ``__getitem__`` and ``__contains__``.

    Args:
        name: Human-readable backend identifier (e.g. ``"aten"``).
    """

    def __init__(self, name: str) -> None:
        self.name = name

    def __repr__(self) -> str:
        # Debug-friendly representation; uses the dynamic class name so
        # subclasses render as e.g. AtenBackend('aten') for free.
        return f"{type(self).__name__}({self.name!r})"

BackendBench/backends/directory.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import importlib.util
2+
import logging
3+
import os
4+
from typing import Callable, Dict
5+
6+
import torch
7+
8+
from .base import Backend
9+
10+
logger = logging.getLogger(__name__)
11+
12+
13+
class DirectoryBackend(Backend):
    """Backend that loads kernel implementations from a directory tree.

    Expected layout: ``<ops_dir>/<op_name>/<impl>.py`` where the module
    defines a callable named ``<op_name>_kernel_impl``.  Ops that cannot be
    loaded or mapped to a torch operator fall back to the original PyTorch
    operation via ``__getitem__``.
    """

    def __init__(self, ops_dir="generated_kernels"):
        super().__init__("directory")
        self.ops_dir = ops_dir
        # Maps resolved torch op overload -> loaded kernel callable.
        self.compiled_kernels: Dict[str, Callable] = {}
        self._load_kernels()

    def _load_kernels(self):
        """Scan ops_dir and register one implementation per op subdirectory."""
        if not os.path.exists(self.ops_dir):
            logger.warning(f"ops directory {self.ops_dir} does not exist")
            return

        loaded_count = 0
        # Sorted for determinism: os.listdir order is filesystem-dependent.
        for op_name in sorted(os.listdir(self.ops_dir)):
            op_dir = os.path.join(self.ops_dir, op_name)
            if not os.path.isdir(op_dir):
                continue

            impl_files = sorted(f for f in os.listdir(op_dir) if f.endswith(".py"))
            if not impl_files:
                logger.warning(f"No Python files found in {op_dir}")
                continue

            # Use the first implementation file (alphabetically first, so the
            # choice is stable across runs and machines).
            impl_file = impl_files[0]
            impl_path = os.path.join(op_dir, impl_file)

            try:
                # Load the implementation and map it to a PyTorch operation.
                kernel_func = self._load_kernel_from_file(impl_path, op_name)
                pytorch_op = self._find_pytorch_op(op_name)
                if pytorch_op:
                    self.compiled_kernels[pytorch_op] = kernel_func
                    logger.info(f"Loaded {op_name} from {impl_file}")
                    loaded_count += 1
                else:
                    logger.warning(f"Could not map {op_name} to PyTorch operation")

            except Exception as e:
                logger.error(f"Error loading {op_name} from {impl_file}: {e}")

        logger.info(f"DirectoryBackend loaded {loaded_count} kernels from {self.ops_dir}/")

    def _load_kernel_from_file(self, file_path: str, op_name: str) -> Callable:
        """Import file_path as a module and return its ``<op_name>_kernel_impl``.

        Raises:
            ValueError: if the file cannot be imported or does not define the
                expected entry-point function.
        """
        spec = importlib.util.spec_from_file_location(f"op_{op_name}", file_path)
        # spec_from_file_location can return None (and spec.loader can be
        # None) for unrecognized files; guard to avoid an AttributeError.
        if spec is None or spec.loader is None:
            raise ValueError(f"Cannot create import spec for {file_path}")
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)

        kernel_func_name = f"{op_name}_kernel_impl"
        if hasattr(module, kernel_func_name):
            return getattr(module, kernel_func_name)
        # Name the function we were looking for — the module may well define
        # other callables, so the old generic message was misleading.
        raise ValueError(f"No function named {kernel_func_name} found in {file_path}")

    def _find_pytorch_op(self, op_name: str):
        """Map an operation name to a torch.ops.aten overload, or None."""
        # Try common overload patterns in order of likelihood.
        try:
            return getattr(torch.ops.aten, op_name).default
        except AttributeError:
            pass

        try:
            return getattr(torch.ops.aten, op_name).Tensor
        except AttributeError:
            pass

        # NOTE(review): other overloads (.Scalar, .out, ...) are not tried;
        # an exhaustive mapping will need to iterate over all ops.
        return None

    def __getitem__(self, key):
        # Fall back to the original operation if no kernel was loaded for it.
        return self.compiled_kernels.get(key, key)

    def __contains__(self, key):
        # Always claim to contain ops so callers take the __getitem__
        # fallback path instead of short-circuiting on a membership test.
        return True

BackendBench/backends/flag_gems.py

Lines changed: 292 additions & 0 deletions
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)