Skip to content

Commit 9277c60

Browse files
authored
Installable backends (#27)
1 parent d49a167 commit 9277c60

File tree

11 files changed

+191
-19
lines changed

11 files changed

+191
-19
lines changed

.github/workflows/ruff.yml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@ name: Ruff
22

33
on:
44
push:
5+
branches:
6+
- main
57
pull_request:
8+
branches:
9+
- main
610
jobs:
711
ruff:
812
runs-on: ubuntu-latest
@@ -14,8 +18,8 @@ jobs:
1418
with:
1519
python-version: '3.x'
1620

17-
- name: Install ruff
18-
run: pip install ruff==0.12.1
21+
- name: Install package with dev dependencies
22+
run: pip install -e .[dev]
1923

2024
- name: Run ruff check
2125
run: ruff check .

.github/workflows/smoke-test.yml

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@ name: Smoke Test
22

33
on:
44
push:
5+
branches:
6+
- main
57
pull_request:
8+
branches:
9+
- main
610

711
jobs:
812
smoke-test:
@@ -16,11 +20,14 @@ jobs:
1620
with:
1721
python-version: '3.x'
1822

19-
- name: Install dependencies
23+
- name: Install package and dependencies
2024
run: |
21-
pip install -r requirements.txt
22-
pip install -r requirements-dev.txt
25+
pip install -e .[dev]
2326
2427
- name: Run smoke test
2528
run: |
26-
PYTHONPATH=. pytest test/
29+
python -m BackendBench.scripts.main --suite smoke --backend aten
30+
31+
- name: Run pytest tests
32+
run: |
33+
pytest test/

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ __pycache__/
33
.vscode/
44
.ruff_cache/
55
generated_kernels/
6+
backendbench.egg-info/
67
CLAUDE.md
78
venv/
89
ops/

BackendBench/__init__.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
"""
2+
BackendBench: A PyTorch backend evaluation framework with monkey patching support.
3+
4+
Import this module to automatically monkey patch PyTorch operations with custom backends.
5+
"""
6+
7+
import os
8+
9+
from .backends import AtenBackend, FlagGemsBackend
10+
11+
12+
class BackendRegistry:
13+
"""Registry for managing different PyTorch backends."""
14+
15+
def __init__(self):
16+
self._current_backend = None
17+
self._original_ops = {}
18+
self._patched = False
19+
20+
def register_backend(self, backend_name: str, backend_instance=None):
21+
"""Register and activate a backend."""
22+
if backend_instance is None:
23+
backend_instance = self._create_backend(backend_name)
24+
25+
if self._patched:
26+
self.unpatch()
27+
28+
self._current_backend = backend_instance
29+
self._patch_torch_ops()
30+
31+
def _create_backend(self, backend_name: str):
32+
"""Create a backend instance."""
33+
backends = {"aten": AtenBackend, "flag_gems": FlagGemsBackend}
34+
35+
if backend_name not in backends:
36+
raise ValueError(f"Unknown backend: {backend_name}. Available: {list(backends.keys())}")
37+
38+
return backends[backend_name]()
39+
40+
def _patch_torch_ops(self):
41+
"""Monkey patch torch operations with current backend."""
42+
if self._current_backend is None:
43+
return
44+
45+
# Get all torch ops that the backend supports
46+
if hasattr(self._current_backend, "ops"):
47+
for torch_op, backend_impl in self._current_backend.ops.items():
48+
if torch_op not in self._original_ops:
49+
self._original_ops[torch_op] = torch_op.default
50+
torch_op.default = backend_impl
51+
52+
self._patched = True
53+
print(
54+
f"BackendBench: Monkey patched {len(self._original_ops)} operations with {self._current_backend.name} backend"
55+
)
56+
57+
def unpatch(self):
58+
"""Restore original torch operations."""
59+
if not self._patched:
60+
return
61+
62+
for torch_op, original_impl in self._original_ops.items():
63+
torch_op.default = original_impl
64+
65+
self._original_ops.clear()
66+
self._patched = False
67+
print("BackendBench: Restored original PyTorch operations")
68+
69+
def get_current_backend(self):
70+
"""Get the currently active backend."""
71+
return self._current_backend
72+
73+
def is_patched(self):
74+
"""Check if operations are currently patched."""
75+
return self._patched
76+
77+
78+
# Global registry instance
79+
_registry = BackendRegistry()
80+
81+
82+
def use_backend(backend_name: str, backend_instance=None):
83+
"""
84+
Switch to a different backend.
85+
86+
Args:
87+
backend_name: Name of the backend ('aten', 'flag_gems')
88+
backend_instance: Optional pre-configured backend instance
89+
"""
90+
_registry.register_backend(backend_name, backend_instance)
91+
92+
93+
def get_backend():
94+
"""Get the currently active backend."""
95+
return _registry.get_current_backend()
96+
97+
98+
def restore_pytorch():
99+
"""Restore original PyTorch operations."""
100+
_registry.unpatch()
101+
102+
103+
def is_patched():
104+
"""Check if BackendBench is currently patching operations."""
105+
return _registry.is_patched()
106+
107+
108+
# Auto-configuration based on environment variables
109+
def _auto_configure():
110+
"""Auto-configure backend based on environment variables."""
111+
backend_name = os.getenv("BACKENDBENCH_BACKEND", "aten")
112+
113+
try:
114+
use_backend(backend_name)
115+
except Exception as e:
116+
print(f"Warning: Failed to initialize {backend_name} backend: {e}")
117+
print("Falling back to aten backend")
118+
use_backend("aten")
119+
120+
121+
# Auto-configure on import unless explicitly disabled
122+
if os.getenv("BACKENDBENCH_NO_AUTO_PATCH", "").lower() not in ("1", "true", "yes"):
123+
_auto_configure()

BackendBench/scripts/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Scripts module for BackendBench
File renamed without changes.

pyproject.toml

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,49 @@
1+
[build-system]
2+
requires = ["setuptools>=61.0", "wheel"]
3+
build-backend = "setuptools.build_meta"
4+
5+
[project]
6+
name = "backendbench"
7+
version = "0.1.0"
8+
description = "A PyTorch backend evaluation suite"
9+
readme = "README.md"
10+
requires-python = ">=3.10"
11+
classifiers = [
12+
"Development Status :: 3 - Alpha",
13+
"Intended Audience :: Developers",
14+
"License :: OSI Approved :: MIT License",
15+
"Programming Language :: Python :: 3",
16+
"Programming Language :: Python :: 3.10",
17+
"Programming Language :: Python :: 3.11",
18+
]
19+
dependencies = [
20+
"torch",
21+
"click",
22+
"numpy",
23+
"expecttest",
24+
"anthropic>=0.34.0",
25+
"pytest",
26+
"requests",
27+
]
28+
29+
[project.optional-dependencies]
30+
dev = [
31+
"pytest",
32+
"pytest-cov",
33+
"pytest-mock",
34+
"pytest-timeout",
35+
"ruff==0.12.1",
36+
]
37+
flaggems = [
38+
"flag_gems",
39+
]
40+
41+
[project.scripts]
42+
backendbench = "BackendBench.scripts.main:cli"
43+
44+
[tool.setuptools.packages.find]
45+
include = ["BackendBench*"]
46+
147
[tool.ruff]
248
line-length = 100
349

requirements-dev.txt

Lines changed: 0 additions & 5 deletions
This file was deleted.

requirements.txt

Lines changed: 0 additions & 7 deletions
This file was deleted.

0 commit comments

Comments
 (0)