Skip to content

Commit 56237fc

Browse files
authored
Fixes to kernel agent backend tests (#46)
1 parent bd4f808 commit 56237fc

File tree

2 files changed

+43
-28
lines changed

2 files changed

+43
-28
lines changed

BackendBench/backends.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
import os
21
import importlib.util
32
import logging
4-
from typing import Dict, Callable, List
3+
import os
4+
from typing import Callable, Dict, List
5+
56
import torch
67

78
logger = logging.getLogger(__name__)
@@ -397,7 +398,8 @@ def __init__(self) -> None:
397398
# Create README for this run
398399
readme_path = os.path.join(self.kernels_dir, "README.md")
399400
with open(readme_path, "w") as f:
400-
f.write(f"""# Generated Kernels - {timestamp}
401+
f.write(
402+
f"""# Generated Kernels - {timestamp}
401403
402404
This directory contains PyTorch/Triton kernels generated by the LLM Backend.
403405
@@ -413,7 +415,8 @@ def __init__(self) -> None:
413415
414416
## Usage
415417
You can inspect these files to debug kernel generation, manually test implementations, or understand what the LLM produced.
416-
""")
418+
"""
419+
)
417420

418421
print(f"Saving generated kernels to: {self.kernels_dir}")
419422

@@ -521,8 +524,8 @@ def test_kernel_correctness(
521524
f.write(full_code)
522525
print(f"Saved kernel to: {kernel_file}")
523526

524-
import sys
525527
import importlib.util
528+
import sys
526529

527530
spec = importlib.util.spec_from_file_location(
528531
f"test_kernel_{op_name}_{attempt}", kernel_file
@@ -633,7 +636,8 @@ def __init__(self) -> None:
633636
# Create README for this run
634637
readme_path = os.path.join(self.kernels_dir, "README.md")
635638
with open(readme_path, "w") as f:
636-
f.write(f"""# Generated Kernels - KernelAgent - {timestamp}
639+
f.write(
640+
f"""# Generated Kernels - KernelAgent - {timestamp}
637641
638642
This directory contains PyTorch/Triton kernels generated by the KernelAgent Backend.
639643
@@ -656,7 +660,8 @@ def __init__(self) -> None:
656660
## Usage
657661
You can inspect these files to debug kernel generation, analyze the parallel worker outputs,
658662
or understand the sophisticated generation process used by KernelAgent.
659-
""")
663+
"""
664+
)
660665

661666
print(f"Saving KernelAgent generated kernels to: {self.kernels_dir}")
662667

@@ -688,7 +693,9 @@ def _get_kernel_agent(self):
688693
os.makedirs(agent_log_dir, exist_ok=True)
689694

690695
self.kernel_agent = TritonKernelAgent(
691-
log_dir=agent_log_dir, num_workers=self.num_workers, max_rounds=self.max_rounds
696+
log_dir=agent_log_dir,
697+
num_workers=self.num_workers,
698+
max_rounds=self.max_rounds,
692699
)
693700

694701
print(f"✓ KernelAgent initialized with log directory: {agent_log_dir}")

test/test_backends.py

Lines changed: 28 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
1+
from unittest.mock import Mock, patch
2+
13
import pytest
24
import torch
3-
from unittest.mock import Mock, patch
4-
from BackendBench.backends import AtenBackend, FlagGemsBackend, LLMBackend, KernelAgentBackend
5+
from BackendBench.backends import (
6+
AtenBackend,
7+
FlagGemsBackend,
8+
KernelAgentBackend,
9+
LLMBackend,
10+
)
511

612
try:
713
import importlib.util
@@ -11,9 +17,9 @@
1117
HAS_FLAG_GEMS = False
1218

1319
try:
14-
import sys
15-
import os
1620
import importlib.util
21+
import os
22+
import sys
1723

1824
kernel_agent_path = os.path.join(os.path.dirname(__file__), "..", "KernelAgent")
1925
sys.path.insert(0, os.path.abspath(kernel_agent_path))
@@ -159,35 +165,38 @@ def generated_kernel(x):
159165
class TestKernelAgentBackend:
160166
@pytest.mark.skipif(not HAS_KERNEL_AGENT, reason="KernelAgent not available")
161167
def test_kernel_agent_backend_initialization(self):
162-
with patch("os.makedirs"):
163-
backend = KernelAgentBackend()
164-
assert backend.name == "kernel_agent"
165-
assert "kernel_agent_run_" in backend.kernels_dir
166-
assert backend.num_workers == 4 # default value
167-
assert backend.max_rounds == 10 # default value
168+
backend = KernelAgentBackend()
169+
assert backend.name == "kernel_agent"
170+
assert "kernel_agent_run_" in backend.kernels_dir
171+
assert backend.num_workers == 4 # default value
172+
assert backend.max_rounds == 10 # default value
168173

169174
@pytest.mark.skipif(not HAS_KERNEL_AGENT, reason="KernelAgent not available")
170175
def test_kernel_agent_backend_set_config(self):
171-
with patch("os.makedirs"):
172-
backend = KernelAgentBackend()
176+
backend = KernelAgentBackend()
173177

174-
backend.set_config(num_workers=8, max_rounds=20)
178+
backend.set_config(num_workers=8, max_rounds=20)
175179

176-
assert backend.num_workers == 8
177-
assert backend.max_rounds == 20
180+
assert backend.num_workers == 8
181+
assert backend.max_rounds == 20
178182

179183
@pytest.mark.skipif(not HAS_KERNEL_AGENT, reason="KernelAgent not available")
180184
def test_kernel_agent_backend_generate_kernel(self):
181185
with (
182-
patch("os.makedirs"),
183186
patch("triton_kernel_agent.TritonKernelAgent") as mock_kernel_agent_class,
184187
):
185188
backend = KernelAgentBackend()
186189

187190
mock_agent = Mock()
188191
mock_kernel_agent_class.return_value = mock_agent
189192

190-
mock_agent.generate_kernel.return_value = (True, "def kernel(): pass")
193+
mock_agent.generate_kernel.return_value = {
194+
"success": True,
195+
"kernel_code": "def kernel(): pass",
196+
"rounds": 1,
197+
"session_dir": "test_session_dir",
198+
"worker_id": 0,
199+
}
191200

192201
mock_op = Mock()
193202
mock_op.__str__ = Mock(return_value="test_op")
@@ -205,9 +214,8 @@ def test_backend_polymorphism(self):
205214
backends.append(AtenBackend())
206215
with patch("BackendBench.backends.flag_gems"):
207216
backends.append(FlagGemsBackend())
208-
with patch("os.makedirs"):
209-
backends.append(LLMBackend())
210-
backends.append(KernelAgentBackend())
217+
backends.append(LLMBackend())
218+
backends.append(KernelAgentBackend())
211219
for backend in backends:
212220
assert hasattr(backend, "name")
213221
assert hasattr(backend, "__contains__")

0 commit comments

Comments
 (0)