@@ -84,6 +84,12 @@ def setup_logging(log_level):
84
84
type = str ,
85
85
help = "Path to TorchBench operator data" ,
86
86
)
87
+ @click .option (
88
+ "--ops-directory" ,
89
+ default = "generated_kernels" ,
90
+ type = str ,
91
+ help = "Path to directory containing generated kernels" ,
92
+ )
87
93
def cli (
88
94
log_level ,
89
95
suite ,
@@ -94,6 +100,7 @@ def cli(
94
100
kernel_agent_workers ,
95
101
kernel_agent_max_rounds ,
96
102
torchbench_data_path ,
103
+ ops_directory ,
97
104
):
98
105
setup_logging (log_level )
99
106
if ops :
@@ -107,17 +114,6 @@ def cli(
107
114
"directory" : backends .DirectoryBackend ,
108
115
}[backend ]()
109
116
110
- # For LLM backend, we need to generate kernels first
111
- if backend .name == "llm" :
112
- llm_client = ClaudeKernelGenerator ()
113
- backend = setup_llm_backend (backend , llm_client , suite , ops , llm_max_attempts )
114
-
115
- # For KernelAgent backend, we need to generate kernels using the sophisticated agent system
116
- elif backend .name == "kernel_agent" :
117
- backend = setup_kernel_agent_backend (
118
- backend , suite , ops , kernel_agent_workers , kernel_agent_max_rounds
119
- )
120
-
121
117
suite = {
122
118
"smoke" : lambda : SmokeTestSuite ,
123
119
"opinfo" : lambda : OpInfoTestSuite (
@@ -134,6 +130,21 @@ def cli(
134
130
),
135
131
}[suite ]()
136
132
133
+ # For LLM backend, we need to generate kernels first
134
+ if backend .name == "llm" :
135
+ llm_client = ClaudeKernelGenerator ()
136
+ backend = setup_llm_backend (backend , llm_client , suite , llm_max_attempts )
137
+
138
+ # For KernelAgent backend, we need to generate kernels using the sophisticated agent system
139
+ elif backend .name == "kernel_agent" :
140
+ backend = setup_kernel_agent_backend (
141
+ backend , suite , kernel_agent_workers , kernel_agent_max_rounds
142
+ )
143
+
144
+ # For Directory backend, we need to load existing kernels from a directory
145
+ elif backend .name == "directory" :
146
+ backend = backends .DirectoryBackend (ops_directory )
147
+
137
148
overall_correctness = []
138
149
overall_performance = []
139
150
@@ -160,21 +171,9 @@ def cli(
160
171
print (f"performance score (geomean speedup over all operators): { geomean_perf :.2f} " )
161
172
162
173
163
- def setup_llm_backend (llm_backend , llm_client , suite_name , ops_filter , max_attempts = 5 ):
174
+ def setup_llm_backend (llm_backend , llm_client , suite , max_attempts = 5 ):
164
175
"""Setup LLM backend by generating kernels for all operations in the suite."""
165
176
try :
166
- if suite_name == "smoke" :
167
- suite = SmokeTestSuite
168
- elif suite_name == "opinfo" :
169
- suite = OpInfoTestSuite (
170
- "opinfo_cuda_bfloat16" ,
171
- "cuda" ,
172
- torch .bfloat16 ,
173
- filter = ops_filter ,
174
- )
175
- else :
176
- raise ValueError (f"Unknown suite: { suite_name } " )
177
-
178
177
successful_ops = 0
179
178
total_ops = 0
180
179
@@ -287,24 +286,11 @@ def feedback_callback(kernel_code: str, attempt: int) -> tuple[bool, Dict]:
287
286
sys .exit (1 )
288
287
289
288
290
- def setup_kernel_agent_backend (
291
- kernel_agent_backend , suite_name , ops_filter , num_workers = 4 , max_rounds = 10
292
- ):
289
+ def setup_kernel_agent_backend (kernel_agent_backend , suite , num_workers = 4 , max_rounds = 10 ):
293
290
"""Setup KernelAgent backend by generating kernels using the sophisticated agent system."""
294
291
try :
295
292
# Configure the backend with the specified parameters
296
293
kernel_agent_backend .set_config (num_workers , max_rounds )
297
- if suite_name == "smoke" :
298
- suite = SmokeTestSuite
299
- elif suite_name == "opinfo" :
300
- suite = OpInfoTestSuite (
301
- "opinfo_cuda_bfloat16" ,
302
- "cuda" ,
303
- torch .bfloat16 ,
304
- filter = ops_filter ,
305
- )
306
- else :
307
- raise ValueError (f"Unknown suite: { suite_name } " )
308
294
309
295
successful_ops = 0
310
296
total_ops = 0
0 commit comments