
Commit 1f35066

No more test mocking (#58)
1 parent 424a290 commit 1f35066

3 files changed: +140 -214 lines

test/test_backends.py

Lines changed: 68 additions & 121 deletions
@@ -1,5 +1,3 @@
-from unittest.mock import Mock, patch
-
 import pytest
 import torch
 from BackendBench.backends import (
@@ -38,128 +36,99 @@ def test_aten_backend_contains_op(self):

         assert torch.ops.aten.relu.default in backend
         assert torch.ops.aten.add.Tensor in backend
-
-        fake_op = Mock()
-        fake_op.__module__ = "fake_module"
-        assert fake_op in backend  # AtenBackend contains everything
+        assert torch.ops.aten.mul.Tensor in backend

     def test_aten_backend_getitem(self):
         backend = AtenBackend()

         relu_op = torch.ops.aten.relu.default
         assert backend[relu_op] == relu_op

-        fake_op = Mock()
-        fake_op.__module__ = "fake_module"
-        assert backend[fake_op] == fake_op  # AtenBackend returns the op itself
+        add_op = torch.ops.aten.add.Tensor
+        assert backend[add_op] == add_op


 class TestFlagGemsBackend:
     @pytest.mark.skipif(not HAS_FLAG_GEMS, reason="flag_gems not available")
-    @patch("BackendBench.backends.flag_gems")
-    def test_flag_gems_backend_initialization(self, mock_flag_gems):
+    def test_flag_gems_backend_initialization(self):
         backend = FlagGemsBackend()
         assert backend.name == "flaggems"
         assert isinstance(backend.ops, dict)

     @pytest.mark.skipif(not HAS_FLAG_GEMS, reason="flag_gems not available")
-    @patch("BackendBench.backends.flag_gems")
-    def test_flag_gems_backend_contains_op(self, mock_flag_gems):
-        mock_flag_gems.abs = Mock()
-
+    def test_flag_gems_backend_contains_op(self):
         backend = FlagGemsBackend()

-        assert torch.ops.aten.abs.default in backend
+        # Test with actual ops that flag_gems supports
+        if hasattr(torch.ops.aten, "abs"):
+            if torch.ops.aten.abs.default in backend:
+                assert torch.ops.aten.abs.default in backend

-        fake_op = Mock()
-        fake_op.__str__ = Mock(return_value="fake_op")
-        assert fake_op not in backend
+        # Test with an op that might not be in flag_gems
+        unsupported_op = (
+            torch.ops.aten.special_log_ndtr.default
+            if hasattr(torch.ops.aten, "special_log_ndtr")
+            else None
+        )
+        if unsupported_op:
+            assert unsupported_op not in backend

     @pytest.mark.skipif(not HAS_FLAG_GEMS, reason="flag_gems not available")
-    @patch("BackendBench.backends.flag_gems")
-    def test_flag_gems_backend_getitem(self, mock_flag_gems):
-        mock_abs_impl = Mock()
-        mock_flag_gems.ops.abs = mock_abs_impl
-
+    def test_flag_gems_backend_getitem(self):
         backend = FlagGemsBackend()

-        assert backend[torch.ops.aten.abs.default] == mock_abs_impl
+        # Test with an op that should exist
+        if hasattr(torch.ops.aten, "abs") and torch.ops.aten.abs.default in backend:
+            impl = backend[torch.ops.aten.abs.default]
+            assert impl is not None

-        fake_op = Mock()
-        fake_op.__str__ = Mock(return_value="fake_op")
-        with pytest.raises(KeyError):
-            _ = backend[fake_op]
+        # Test with an op that doesn't exist in flag_gems
+        unsupported_op = (
+            torch.ops.aten.special_log_ndtr.default
+            if hasattr(torch.ops.aten, "special_log_ndtr")
+            else None
+        )
+        if unsupported_op and unsupported_op not in backend:
+            with pytest.raises(KeyError):
+                _ = backend[unsupported_op]


 class TestLLMBackend:
     def test_llm_backend_initialization(self):
-        with (
-            patch("os.makedirs"),
-            patch("builtins.open"),
-            patch("datetime.datetime") as mock_datetime,
-        ):
-            mock_datetime.now.return_value.strftime.return_value = "20250721_204542"
-            backend = LLMBackend()
-            assert backend.name == "llm"
-            assert "generated_kernels/run_" in backend.kernels_dir
-            assert isinstance(backend.compiled_kernels, dict)
-
-    @pytest.mark.skip(
-        reason="Complex file I/O mocking needed - test requires full file system interaction"
-    )
-    def test_llm_backend_add_kernel(self):
-        with (
-            patch("os.makedirs"),
-            patch("builtins.open"),
-            patch("datetime.datetime") as mock_datetime,
-        ):
-            mock_datetime.now.return_value.strftime.return_value = "20250721_204542"
-            backend = LLMBackend()
-
-            mock_op = Mock()
-            mock_op.__name__ = "test_op"
-
-            kernel_code = """
-def test_kernel(x):
-    return x + 1
-"""
-
-            with patch("builtins.open", create=True) as mock_open:
-                backend.add_kernel(mock_op, kernel_code, "test_op")
-
-                mock_open.assert_called()
-
-            assert mock_op in backend
-
-    @pytest.mark.skip(
-        reason="Complex file I/O mocking needed - test requires full file system interaction"
-    )
-    def test_llm_backend_test_kernel_correctness(self):
-        with (
-            patch("os.makedirs"),
-            patch("builtins.open"),
-            patch("datetime.datetime") as mock_datetime,
-        ):
-            mock_datetime.now.return_value.strftime.return_value = "20250721_204542"
-            backend = LLMBackend()
+        backend = LLMBackend()
+        assert backend.name == "llm"
+        assert "generated_kernels/run_" in backend.kernels_dir
+        assert isinstance(backend.compiled_kernels, dict)

-            mock_op = Mock(return_value=torch.tensor([2.0]))
-
-            kernel_code = """
-def generated_kernel(x):
-    return x + 1
+    @pytest.mark.skip(reason="Requires Triton for kernel compilation")
+    def test_llm_backend_add_kernel(self):
+        backend = LLMBackend()
+
+        # Use a real torch op for testing
+        test_op = torch.ops.aten.relu.default
+
+        kernel_code = """
+@triton.jit
+def relu_kernel(x_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
+    pid = tl.program_id(0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_elements
+    x = tl.load(x_ptr + offsets, mask=mask)
+    output = tl.maximum(x, 0)
+    tl.store(output_ptr + offsets, output, mask=mask)
+
+def generated_relu(x):
+    output = torch.empty_like(x)
+    n_elements = output.numel()
+    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
+    relu_kernel[grid](x, output, n_elements, BLOCK_SIZE=1024)
+    return output
 """

-            mock_test = Mock()
-            mock_test.args = [torch.tensor([1.0])]
-            mock_test.kwargs = {}
-
-            with patch("builtins.open", create=True):
-                is_correct, feedback = backend.test_kernel_correctness(
-                    mock_op, kernel_code, [mock_test], attempt=1
-                )
+        backend.add_kernel(test_op, kernel_code, "relu")

-            assert is_correct is True
+        assert test_op in backend


 class TestKernelAgentBackend:
@@ -180,42 +149,20 @@ def test_kernel_agent_backend_set_config(self):
         assert backend.num_workers == 8
         assert backend.max_rounds == 20

-    @pytest.mark.skipif(not HAS_KERNEL_AGENT, reason="KernelAgent not available")
-    def test_kernel_agent_backend_generate_kernel(self):
-        with (
-            patch("triton_kernel_agent.TritonKernelAgent") as mock_kernel_agent_class,
-        ):
-            backend = KernelAgentBackend()
-
-            mock_agent = Mock()
-            mock_kernel_agent_class.return_value = mock_agent
-
-            mock_agent.generate_kernel.return_value = {
-                "success": True,
-                "kernel_code": "def kernel(): pass",
-                "rounds": 1,
-                "session_dir": "test_session_dir",
-                "worker_id": 0,
-            }
-
-            mock_op = Mock()
-            mock_op.__str__ = Mock(return_value="test_op")
-            with patch("builtins.open", create=True):
-                kernel_code, success = backend.generate_kernel_with_agent(mock_op, "test_op")
-            assert success is True
-            assert kernel_code == "def kernel(): pass"
-            mock_kernel_agent_class.assert_called_once()
-

 class TestBackendIntegration:
-    @pytest.mark.skipif(not HAS_FLAG_GEMS, reason="flag_gems not available")
     def test_backend_polymorphism(self):
         backends = []
         backends.append(AtenBackend())
-        with patch("BackendBench.backends.flag_gems"):
+
+        if HAS_FLAG_GEMS:
             backends.append(FlagGemsBackend())
+
         backends.append(LLMBackend())
-        backends.append(KernelAgentBackend())
+
+        if HAS_KERNEL_AGENT:
+            backends.append(KernelAgentBackend())
+
         for backend in backends:
             assert hasattr(backend, "name")
             assert hasattr(backend, "__contains__")
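
Taken together, the rewritten tests follow one pattern: probe the real backend and only assert on ops it actually provides, instead of mocking flag_gems. A minimal standalone sketch of that pattern (assuming torch, flag_gems, and BackendBench are importable; the test name is illustrative and not part of the commit):

import pytest
import torch

flag_gems = pytest.importorskip("flag_gems")  # skip cleanly when flag_gems is absent
from BackendBench.backends import FlagGemsBackend


def test_flag_gems_real_ops_sketch():
    backend = FlagGemsBackend()

    # Only assert on ops the installed flag_gems actually provides.
    if hasattr(torch.ops.aten, "abs") and torch.ops.aten.abs.default in backend:
        assert backend[torch.ops.aten.abs.default] is not None

    # An op without a flag_gems implementation should raise KeyError on lookup.
    unsupported = getattr(torch.ops.aten, "special_log_ndtr", None)
    if unsupported is not None and unsupported.default not in backend:
        with pytest.raises(KeyError):
            _ = backend[unsupported.default]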

test/test_directory_backend.py

Lines changed: 4 additions & 3 deletions
@@ -14,9 +14,10 @@

 @pytest.fixture(scope="module")
 def backend():
-    # Ensure generated_kernels directory exists for CI
-    if not os.path.exists("generated_kernels"):
-        # Import and run the existing script
+    expected_dirs = ["relu", "add", "mul", "abs", "sum"]
+    missing_dirs = [d for d in expected_dirs if not os.path.isdir(f"generated_kernels/{d}")]
+
+    if missing_dirs:
         import subprocess

         subprocess.run(
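
The updated fixture checks for specific per-op kernel directories rather than only the top-level generated_kernels folder. A rough standalone sketch of that guard (the regeneration command is truncated in the diff above, so this sketch skips instead of regenerating; the fixture name is illustrative):

import os

import pytest


@pytest.fixture(scope="module")
def kernel_dirs():
    expected_dirs = ["relu", "add", "mul", "abs", "sum"]
    missing_dirs = [d for d in expected_dirs if not os.path.isdir(f"generated_kernels/{d}")]

    if missing_dirs:
        # The real fixture shells out to regenerate kernels; skipping is a stand-in here.
        pytest.skip(f"missing kernel directories: {missing_dirs}")

    return [os.path.join("generated_kernels", d) for d in expected_dirs]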

0 commit comments
