1
+ from unittest .mock import Mock , patch
2
+
1
3
import pytest
2
4
import torch
3
- from unittest .mock import Mock , patch
4
- from BackendBench .backends import AtenBackend , FlagGemsBackend , LLMBackend , KernelAgentBackend
5
+ from BackendBench .backends import (
6
+ AtenBackend ,
7
+ FlagGemsBackend ,
8
+ KernelAgentBackend ,
9
+ LLMBackend ,
10
+ )
5
11
6
12
try :
7
13
import importlib .util
11
17
HAS_FLAG_GEMS = False
12
18
13
19
try :
14
- import sys
15
- import os
16
20
import importlib .util
21
+ import os
22
+ import sys
17
23
18
24
kernel_agent_path = os .path .join (os .path .dirname (__file__ ), ".." , "KernelAgent" )
19
25
sys .path .insert (0 , os .path .abspath (kernel_agent_path ))
@@ -159,35 +165,38 @@ def generated_kernel(x):
159
165
class TestKernelAgentBackend :
160
166
@pytest .mark .skipif (not HAS_KERNEL_AGENT , reason = "KernelAgent not available" )
161
167
def test_kernel_agent_backend_initialization (self ):
162
- with patch ("os.makedirs" ):
163
- backend = KernelAgentBackend ()
164
- assert backend .name == "kernel_agent"
165
- assert "kernel_agent_run_" in backend .kernels_dir
166
- assert backend .num_workers == 4 # default value
167
- assert backend .max_rounds == 10 # default value
168
+ backend = KernelAgentBackend ()
169
+ assert backend .name == "kernel_agent"
170
+ assert "kernel_agent_run_" in backend .kernels_dir
171
+ assert backend .num_workers == 4 # default value
172
+ assert backend .max_rounds == 10 # default value
168
173
169
174
@pytest .mark .skipif (not HAS_KERNEL_AGENT , reason = "KernelAgent not available" )
170
175
def test_kernel_agent_backend_set_config (self ):
171
- with patch ("os.makedirs" ):
172
- backend = KernelAgentBackend ()
176
+ backend = KernelAgentBackend ()
173
177
174
- backend .set_config (num_workers = 8 , max_rounds = 20 )
178
+ backend .set_config (num_workers = 8 , max_rounds = 20 )
175
179
176
- assert backend .num_workers == 8
177
- assert backend .max_rounds == 20
180
+ assert backend .num_workers == 8
181
+ assert backend .max_rounds == 20
178
182
179
183
@pytest .mark .skipif (not HAS_KERNEL_AGENT , reason = "KernelAgent not available" )
180
184
def test_kernel_agent_backend_generate_kernel (self ):
181
185
with (
182
- patch ("os.makedirs" ),
183
186
patch ("triton_kernel_agent.TritonKernelAgent" ) as mock_kernel_agent_class ,
184
187
):
185
188
backend = KernelAgentBackend ()
186
189
187
190
mock_agent = Mock ()
188
191
mock_kernel_agent_class .return_value = mock_agent
189
192
190
- mock_agent .generate_kernel .return_value = (True , "def kernel(): pass" )
193
+ mock_agent .generate_kernel .return_value = {
194
+ "success" : True ,
195
+ "kernel_code" : "def kernel(): pass" ,
196
+ "rounds" : 1 ,
197
+ "session_dir" : "test_session_dir" ,
198
+ "worker_id" : 0 ,
199
+ }
191
200
192
201
mock_op = Mock ()
193
202
mock_op .__str__ = Mock (return_value = "test_op" )
@@ -205,9 +214,8 @@ def test_backend_polymorphism(self):
205
214
backends .append (AtenBackend ())
206
215
with patch ("BackendBench.backends.flag_gems" ):
207
216
backends .append (FlagGemsBackend ())
208
- with patch ("os.makedirs" ):
209
- backends .append (LLMBackend ())
210
- backends .append (KernelAgentBackend ())
217
+ backends .append (LLMBackend ())
218
+ backends .append (KernelAgentBackend ())
211
219
for backend in backends :
212
220
assert hasattr (backend , "name" )
213
221
assert hasattr (backend , "__contains__" )
0 commit comments