Skip to content

Commit 115a865

Browse files
authored
Use LLM/KernelAgent backends for any op suite (#63)
1 parent 6dfe2be commit 115a865

File tree

1 file changed

+24
-38
lines changed

1 file changed

+24
-38
lines changed

BackendBench/scripts/main.py

Lines changed: 24 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,12 @@ def setup_logging(log_level):
8484
type=str,
8585
help="Path to TorchBench operator data",
8686
)
87+
@click.option(
88+
"--ops-directory",
89+
default="generated_kernels",
90+
type=str,
91+
help="Path to directory containing generated kernels",
92+
)
8793
def cli(
8894
log_level,
8995
suite,
@@ -94,6 +100,7 @@ def cli(
94100
kernel_agent_workers,
95101
kernel_agent_max_rounds,
96102
torchbench_data_path,
103+
ops_directory,
97104
):
98105
setup_logging(log_level)
99106
if ops:
@@ -107,17 +114,6 @@ def cli(
107114
"directory": backends.DirectoryBackend,
108115
}[backend]()
109116

110-
# For LLM backend, we need to generate kernels first
111-
if backend.name == "llm":
112-
llm_client = ClaudeKernelGenerator()
113-
backend = setup_llm_backend(backend, llm_client, suite, ops, llm_max_attempts)
114-
115-
# For KernelAgent backend, we need to generate kernels using the sophisticated agent system
116-
elif backend.name == "kernel_agent":
117-
backend = setup_kernel_agent_backend(
118-
backend, suite, ops, kernel_agent_workers, kernel_agent_max_rounds
119-
)
120-
121117
suite = {
122118
"smoke": lambda: SmokeTestSuite,
123119
"opinfo": lambda: OpInfoTestSuite(
@@ -134,6 +130,21 @@ def cli(
134130
),
135131
}[suite]()
136132

133+
# For LLM backend, we need to generate kernels first
134+
if backend.name == "llm":
135+
llm_client = ClaudeKernelGenerator()
136+
backend = setup_llm_backend(backend, llm_client, suite, llm_max_attempts)
137+
138+
# For KernelAgent backend, we need to generate kernels using the sophisticated agent system
139+
elif backend.name == "kernel_agent":
140+
backend = setup_kernel_agent_backend(
141+
backend, suite, kernel_agent_workers, kernel_agent_max_rounds
142+
)
143+
144+
# For Directory backend, we need to load existing kernels from a directory
145+
elif backend.name == "directory":
146+
backend = backends.DirectoryBackend(ops_directory)
147+
137148
overall_correctness = []
138149
overall_performance = []
139150

@@ -160,21 +171,9 @@ def cli(
160171
print(f"performance score (geomean speedup over all operators): {geomean_perf:.2f}")
161172

162173

163-
def setup_llm_backend(llm_backend, llm_client, suite_name, ops_filter, max_attempts=5):
174+
def setup_llm_backend(llm_backend, llm_client, suite, max_attempts=5):
164175
"""Setup LLM backend by generating kernels for all operations in the suite."""
165176
try:
166-
if suite_name == "smoke":
167-
suite = SmokeTestSuite
168-
elif suite_name == "opinfo":
169-
suite = OpInfoTestSuite(
170-
"opinfo_cuda_bfloat16",
171-
"cuda",
172-
torch.bfloat16,
173-
filter=ops_filter,
174-
)
175-
else:
176-
raise ValueError(f"Unknown suite: {suite_name}")
177-
178177
successful_ops = 0
179178
total_ops = 0
180179

@@ -287,24 +286,11 @@ def feedback_callback(kernel_code: str, attempt: int) -> tuple[bool, Dict]:
287286
sys.exit(1)
288287

289288

290-
def setup_kernel_agent_backend(
291-
kernel_agent_backend, suite_name, ops_filter, num_workers=4, max_rounds=10
292-
):
289+
def setup_kernel_agent_backend(kernel_agent_backend, suite, num_workers=4, max_rounds=10):
293290
"""Setup KernelAgent backend by generating kernels using the sophisticated agent system."""
294291
try:
295292
# Configure the backend with the specified parameters
296293
kernel_agent_backend.set_config(num_workers, max_rounds)
297-
if suite_name == "smoke":
298-
suite = SmokeTestSuite
299-
elif suite_name == "opinfo":
300-
suite = OpInfoTestSuite(
301-
"opinfo_cuda_bfloat16",
302-
"cuda",
303-
torch.bfloat16,
304-
filter=ops_filter,
305-
)
306-
else:
307-
raise ValueError(f"Unknown suite: {suite_name}")
308294

309295
successful_ops = 0
310296
total_ops = 0

0 commit comments

Comments
 (0)