Add Directory Structure for BackendBench #90

Draft · wants to merge 13 commits into main
4 changes: 3 additions & 1 deletion .gitignore
@@ -2,11 +2,13 @@ __pycache__/
 .claude/
 .vscode/
 .ruff_cache/
-generated_kernels/
 backendbench.egg-info/
 CLAUDE.md
 venv/
 ops/
 uv.lock
 pytorch_operator_coverage.csv
+.pre-commit-cache/
+generated_kernels/
+internal_operators.csv
 torchbench_operator_folder_mapping.csv
122 changes: 2 additions & 120 deletions BackendBench/__init__.py
@@ -5,125 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 """
-BackendBench: A PyTorch backend evaluation framework with monkey patching support.
-
-Import this module to automatically monkey patch PyTorch operations with custom backends.
+BackendBench: A PyTorch backend evaluation framework.
 """

-import os
-
-from .backends import AtenBackend, FlagGemsBackend
-
-
-class BackendRegistry:
-    """Registry for managing different PyTorch backends."""
-
-    def __init__(self):
-        self._current_backend = None
-        self._original_ops = {}
-        self._patched = False
-
-    def register_backend(self, backend_name: str, backend_instance=None):
-        """Register and activate a backend."""
-        if backend_instance is None:
-            backend_instance = self._create_backend(backend_name)
-
-        if self._patched:
-            self.unpatch()
-
-        self._current_backend = backend_instance
-        self._patch_torch_ops()
-
-    def _create_backend(self, backend_name: str):
-        """Create a backend instance."""
-        backends = {"aten": AtenBackend, "flag_gems": FlagGemsBackend}
-
-        if backend_name not in backends:
-            raise ValueError(f"Unknown backend: {backend_name}. Available: {list(backends.keys())}")
-
-        return backends[backend_name]()
-
-    def _patch_torch_ops(self):
-        """Monkey patch torch operations with current backend."""
-        if self._current_backend is None:
-            return
-
-        # Get all torch ops that the backend supports
-        if hasattr(self._current_backend, "ops"):
-            for torch_op, backend_impl in self._current_backend.ops.items():
-                if torch_op not in self._original_ops:
-                    self._original_ops[torch_op] = torch_op.default
-                torch_op.default = backend_impl
-
-        self._patched = True
-        print(
-            f"BackendBench: Monkey patched {len(self._original_ops)} operations with {self._current_backend.name} backend"
-        )
-
-    def unpatch(self):
-        """Restore original torch operations."""
-        if not self._patched:
-            return
-
-        for torch_op, original_impl in self._original_ops.items():
-            torch_op.default = original_impl
-
-        self._original_ops.clear()
-        self._patched = False
-        print("BackendBench: Restored original PyTorch operations")
-
-    def get_current_backend(self):
-        """Get the currently active backend."""
-        return self._current_backend
-
-    def is_patched(self):
-        """Check if operations are currently patched."""
-        return self._patched
-
-
-# Global registry instance
-_registry = BackendRegistry()
-
-
-def use_backend(backend_name: str, backend_instance=None):
-    """
-    Switch to a different backend.
-
-    Args:
-        backend_name: Name of the backend ('aten', 'flag_gems')
-        backend_instance: Optional pre-configured backend instance
-    """
-    _registry.register_backend(backend_name, backend_instance)
-
-
-def get_backend():
-    """Get the currently active backend."""
-    return _registry.get_current_backend()
-
-
-def restore_pytorch():
-    """Restore original PyTorch operations."""
-    _registry.unpatch()
-
-
-def is_patched():
-    """Check if BackendBench is currently patching operations."""
-    return _registry.is_patched()
-
-
-# Auto-configuration based on environment variables
-def _auto_configure():
-    """Auto-configure backend based on environment variables."""
-    backend_name = os.getenv("BACKENDBENCH_BACKEND", "aten")
-
-    try:
-        use_backend(backend_name)
-    except Exception as e:
-        print(f"Warning: Failed to initialize {backend_name} backend: {e}")
-        print("Falling back to aten backend")
-        use_backend("aten")
-
-
-# Auto-configure on import unless explicitly disabled
-if os.getenv("BACKENDBENCH_NO_AUTO_PATCH", "").lower() not in ("1", "true", "yes"):
-    _auto_configure()
+__version__ = "0.1.0"
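
Context for the deletion above: importing BackendBench used to monkey patch torch ops automatically (driven by the BACKENDBENCH_BACKEND and BACKENDBENCH_NO_AUTO_PATCH environment variables). A minimal sketch of the removed module-level API in use, reconstructed only from the deleted code; none of these calls exist after this change:

    import BackendBench  # old behavior: importing auto-patched aten ops at import time

    BackendBench.use_backend("flag_gems")    # swapped every supported overload's .default
    assert BackendBench.is_patched()
    print(BackendBench.get_backend().name)
    BackendBench.restore_pytorch()           # put the original implementations back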
77 changes: 52 additions & 25 deletions BackendBench/backends/directory.py
@@ -34,22 +34,28 @@ def _load_kernels(self):
             if not os.path.isdir(op_dir):
                 continue

-            impl_files = [f for f in os.listdir(op_dir) if f.endswith(".py")]
+            impl_files = [
+                f
+                for f in os.listdir(op_dir)
+                if f.endswith(".py") and f.startswith(f"{op_name}_implementation")
+            ]
             if not impl_files:
-                logger.warning(f"No Python files found in {op_dir}")
+                logger.debug(f"No implementation files found in {op_dir}")
                 continue

             # Use the first implementation file
-            impl_file = impl_files[0]
+            impl_file = sorted(impl_files)[0]  # Sort to ensure consistent selection
             impl_path = os.path.join(op_dir, impl_file)

             try:
                 # Load the implementation and map to PyTorch operation
                 kernel_func = self._load_kernel_from_file(impl_path, op_name)
-                pytorch_op = self._find_pytorch_op(op_name)
-                if pytorch_op:
-                    self.compiled_kernels[pytorch_op] = kernel_func
-                    logger.info(f"Loaded {op_name} from {impl_file}")
+                pytorch_ops = self._find_pytorch_ops(op_name)
+
+                if pytorch_ops:
+                    for pytorch_op in pytorch_ops:
+                        self.compiled_kernels[pytorch_op] = kernel_func
+                        logger.info(f"Loaded {op_name} from {impl_file} -> {pytorch_op}")
                     loaded_count += 1
                 else:
                     logger.warning(f"Could not map {op_name} to PyTorch operation")
@@ -68,23 +74,44 @@ def _load_kernel_from_file(self, file_path: str, op_name: str) -> Callable:
         if hasattr(module, kernel_func_name):
             return getattr(module, kernel_func_name)
         else:
-            raise ValueError(f"No callable function found in {file_path}")
-
-    def _find_pytorch_op(self, op_name: str):
-        """Map operation name to PyTorch operation."""
-        # Try common patterns
-        try:
-            return getattr(torch.ops.aten, op_name).default
-        except AttributeError:
-            pass
-
-        try:
-            return getattr(torch.ops.aten, op_name).Tensor
-        except AttributeError:
-            pass
-
-        # Not 100% sure this is right, will need to iterate over all ops
-        return None
+            raise ValueError(f"No function named {kernel_func_name} found in {file_path}")
+
+    def _find_pytorch_ops(self, op_name: str):
+        """Map operation name to PyTorch operations.
+
+        Returns a list of PyTorch operations that match the directory name.
+        This handles the common case where a directory name like 'add' should map
+        to multiple overloads like add.default, add.Tensor, etc.
+        """
+        matched_ops = []
+
+        # Handle suffixed directory names (e.g., add_out -> add.out)
+        base_name = op_name
+        suffix = None
+        if "_" in op_name:
+            parts = op_name.rsplit("_", 1)
+            if parts[1] in ["out", "inplace", "scalar"]:
+                base_name = parts[0]
+                suffix = parts[1]
+
+        # Try to find the operation in torch.ops.aten
+        if hasattr(torch.ops.aten, base_name):
+            aten_op = getattr(torch.ops.aten, base_name)
+
+            # If we have a specific suffix, try to get that overload
+            if suffix and hasattr(aten_op, suffix):
+                matched_ops.append(getattr(aten_op, suffix))
+            else:
+                # Otherwise, try common overloads
+                for overload in ["default", "Tensor", "Scalar", "int", "float"]:
+                    if hasattr(aten_op, overload):
+                        op = getattr(aten_op, overload)
+                        matched_ops.append(op)
+
+        # Also check for operations that might be in other namespaces
+        # This could be extended based on actual usage patterns
+
+        return matched_ops

     def __getitem__(self, key):
         if key in self.compiled_kernels:
@@ -93,4 +120,4 @@ def __getitem__(self, key):
         return key

     def __contains__(self, key):
-        return key in self.compiled_kernels or True  # Always claim to contain ops for fallback
+        return key in self.compiled_kernels
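
To make the new mapping concrete, here is a hypothetical session against the rewritten loader. It assumes the class in directory.py is named DirectoryBackend, that it constructs with no arguments, and that the listed aten overloads exist in the local PyTorch build:

    import torch
    from BackendBench.backends.directory import DirectoryBackend

    backend = DirectoryBackend()

    # Plain directory name: collect whichever common overloads exist.
    backend._find_pytorch_ops("relu")      # e.g. [torch.ops.aten.relu.default]

    # Recognized suffix: "add_out" splits into base "add" + overload "out".
    backend._find_pytorch_ops("add_out")   # e.g. [torch.ops.aten.add.out]

    # __contains__ now reports real coverage instead of always claiming every op.
    torch.ops.aten.relu.default in backend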
14 changes: 7 additions & 7 deletions BackendBench/scripts/create_simple_test_ops.py
@@ -19,7 +19,7 @@

 def create_relu():
     os.makedirs("generated_kernels/relu", exist_ok=True)
-    with open("generated_kernels/relu/relu_implementation_1.py", "w") as f:
+    with open("generated_kernels/relu/relu_implementation_v1.py", "w") as f:
         f.write('''import torch

 def relu_kernel_impl(input):
@@ -37,7 +37,7 @@ def relu_kernel_impl(input):

 def create_add():
     os.makedirs("generated_kernels/add", exist_ok=True)
-    with open("generated_kernels/add/add_implementation_1.py", "w") as f:
+    with open("generated_kernels/add/add_implementation_v1.py", "w") as f:
         f.write('''import torch

 def add_kernel_impl(input, other):
@@ -56,7 +56,7 @@ def add_kernel_impl(input, other):

 def create_mul():
     os.makedirs("generated_kernels/mul", exist_ok=True)
-    with open("generated_kernels/mul/mul_implementation_1.py", "w") as f:
+    with open("generated_kernels/mul/mul_implementation_v1.py", "w") as f:
         f.write('''import torch

 def mul_kernel_impl(input, other):
@@ -75,7 +75,7 @@ def mul_kernel_impl(input, other):

 def create_abs():
     os.makedirs("generated_kernels/abs", exist_ok=True)
-    with open("generated_kernels/abs/abs_implementation_1.py", "w") as f:
+    with open("generated_kernels/abs/abs_implementation_v1.py", "w") as f:
         f.write('''import torch

 def abs_kernel_impl(input):
@@ -93,7 +93,7 @@ def abs_kernel_impl(input):

 def create_sum():
     os.makedirs("generated_kernels/sum", exist_ok=True)
-    with open("generated_kernels/sum/sum_implementation_1.py", "w") as f:
+    with open("generated_kernels/sum/sum_implementation_v1.py", "w") as f:
         f.write('''import torch

 def sum_kernel_impl(input, *args, **kwargs):
@@ -122,8 +122,8 @@ def main():

     logger.info("Created 5 simple kernel implementations in generated_kernels/")
     logger.info("Test them individually:")
-    logger.info(" python generated_kernels/relu/relu_implementation_1.py")
-    logger.info(" python generated_kernels/add/add_implementation_1.py")
+    logger.info(" python generated_kernels/relu/relu_implementation_v1.py")
+    logger.info(" python generated_kernels/add/add_implementation_v1.py")
     logger.info(" etc.")
     logger.info("Or test all with the backend:")
     logger.info(" python test/test_simple_directory_backend.py")