1313from contextlib import contextmanager
1414import time
1515import json
16+ import numpy as np
1617
1718"""
1819Acknowledgement: We adopt the evaluation method from https://github.com/ScalingIntelligence/KernelBench to enhance this module's functionality.
@@ -53,7 +54,6 @@ def synchronize(self):
5354registry_backend = {
5455 "inductor" : InductorBackend (),
5556 "tensorrt" : TensorRTBackend (),
56- "default" : InductorBackend (),
5757}
5858
5959
@@ -115,7 +115,7 @@ def naive_timer(duration_box, synchronizer_func):
115115
116116
117117def time_execution_naive (
118- model_call , synchronizer_func , num_warmup : int = 3 , num_trials : int = 100
118+ model_call , synchronizer_func , num_warmup : int = 3 , num_trials : int = 10
119119):
120120 print (f"[Profiling] Using device: CPU, warm up { num_warmup } , trials { num_trials } " )
121121 for _ in range (num_warmup ):
@@ -137,7 +137,6 @@ def get_timing_stats_cpu(elapsed_times: list[float]):
137137 "std" : float (f"{ np .std (elapsed_times ):.3g} " ),
138138 "min" : float (f"{ np .min (elapsed_times ):.3g} " ),
139139 "max" : float (f"{ np .max (elapsed_times ):.3g} " ),
140- "num_trials" : len (elapsed_times ),
141140 }
142141 return stats
143142
@@ -148,8 +147,24 @@ def test_single_model(args):
148147 model = get_model (args )
149148 compiled_model = compiler (model )
150149
151- eager_time_ms = - 1
152- compiled_time_ms = - 1
150+ eager_stats = {}
151+ compiled_stats = {}
152+
153+ result_data = {
154+ "configuration" : {
155+ "model" : os .path .basename (os .path .normpath (args .model_path )),
156+ "compiler" : args .compiler ,
157+ "device" : args .device ,
158+ "warmup" : args .warmup ,
159+ "trials" : args .trials ,
160+ },
161+ "correctness" : {},
162+ "performance" : {
163+ "eager" : {},
164+ "compiled" : {},
165+ "speedup" : {},
166+ },
167+ }
153168
154169 eager_model_call = lambda : model (** input_dict )
155170 compiled_model_call = lambda : compiled_model (** input_dict )
@@ -166,7 +181,6 @@ def test_single_model(args):
166181 device = torch .device ("cuda:0" ),
167182 )
168183 eager_stats = get_timing_stats (eager_times )
169- eager_time_ms = eager_stats ["mean" ]
170184
171185 compiled_times = time_execution_with_cuda_event (
172186 compiled_model_call ,
@@ -175,7 +189,6 @@ def test_single_model(args):
175189 device = torch .device ("cuda:0" ),
176190 )
177191 compiled_stats = get_timing_stats (compiled_times )
178- compiled_time_ms = compiled_stats ["mean" ]
179192 else :
180193 eager_times = time_execution_naive (
181194 eager_model_call ,
@@ -184,7 +197,6 @@ def test_single_model(args):
184197 num_trials = args .trials ,
185198 )
186199 eager_stats = get_timing_stats_cpu (eager_times )
187- eager_time_ms = eager_stats ["mean" ]
188200
189201 compiled_times = time_execution_naive (
190202 compiled_model_call ,
@@ -193,37 +205,80 @@ def test_single_model(args):
193205 num_trials = args .trials ,
194206 )
195207 compiled_stats = get_timing_stats_cpu (compiled_times )
196- compiled_time_ms = compiled_stats ["mean" ]
197208
198209 expected_out = eager_model_call ()
199210 compiled_out = compiled_model_call ()
200211
201- def print_cmp (key , func , ** kwargs ):
212+ def print_and_store_cmp (key , func , ** kwargs ):
202213 cmp_ret = func (expected_out , compiled_out , ** kwargs )
214+ result_data ["correctness" ][key ] = cmp_ret
203215 print (
204216 f"{ args .log_prompt } { key } model_path:{ args .model_path } { cmp_ret } " ,
205217 file = sys .stderr ,
206218 )
207219
208- print_cmp ("cmp.equal" , get_cmp_equal )
209- print_cmp ("cmp.all_close_atol8_rtol8" , get_cmp_all_close , atol = 1e-8 , rtol = 1e-8 )
210- print_cmp ("cmp.all_close_atol8_rtol5" , get_cmp_all_close , atol = 1e-8 , rtol = 1e-5 )
211- print_cmp ("cmp.all_close_atol5_rtol5" , get_cmp_all_close , atol = 1e-5 , rtol = 1e-5 )
212- print_cmp ("cmp.all_close_atol3_rtol2" , get_cmp_all_close , atol = 1e-3 , rtol = 1e-2 )
213- print_cmp ("cmp.all_close_atol2_rtol1" , get_cmp_all_close , atol = 1e-2 , rtol = 1e-1 )
214- print_cmp ("cmp.max_diff" , get_cmp_max_diff )
215- print_cmp ("cmp.mean_diff" , get_cmp_mean_diff )
216- print_cmp ("cmp.diff_count_atol8_rtol8" , get_cmp_diff_count , atol = 1e-8 , rtol = 1e-8 )
217- print_cmp ("cmp.diff_count_atol8_rtol5" , get_cmp_diff_count , atol = 1e-8 , rtol = 1e-5 )
218- print_cmp ("cmp.diff_count_atol5_rtol5" , get_cmp_diff_count , atol = 1e-5 , rtol = 1e-5 )
219- print_cmp ("cmp.diff_count_atol3_rtol2" , get_cmp_diff_count , atol = 1e-3 , rtol = 1e-2 )
220- print_cmp ("cmp.diff_count_atol2_rtol1" , get_cmp_diff_count , atol = 1e-2 , rtol = 1e-1 )
220+ print_and_store_cmp ("equal" , get_cmp_equal )
221+ print_and_store_cmp (
222+ "all_close_atol8_rtol8" , get_cmp_all_close , atol = 1e-8 , rtol = 1e-8
223+ )
224+ print_and_store_cmp (
225+ "all_close_atol8_rtol5" , get_cmp_all_close , atol = 1e-8 , rtol = 1e-5
226+ )
227+ print_and_store_cmp (
228+ "all_close_atol5_rtol5" , get_cmp_all_close , atol = 1e-5 , rtol = 1e-5
229+ )
230+ print_and_store_cmp (
231+ "all_close_atol3_rtol2" , get_cmp_all_close , atol = 1e-3 , rtol = 1e-2
232+ )
233+ print_and_store_cmp (
234+ "all_close_atol2_rtol1" , get_cmp_all_close , atol = 1e-2 , rtol = 1e-1
235+ )
236+ print_and_store_cmp ("max_diff" , get_cmp_max_diff )
237+ print_and_store_cmp ("mean_diff" , get_cmp_mean_diff )
238+ print_and_store_cmp (
239+ "diff_count_atol8_rtol8" , get_cmp_diff_count , atol = 1e-8 , rtol = 1e-8
240+ )
241+ print_and_store_cmp (
242+ "diff_count_atol8_rtol5" , get_cmp_diff_count , atol = 1e-8 , rtol = 1e-5
243+ )
244+ print_and_store_cmp (
245+ "diff_count_atol5_rtol5" , get_cmp_diff_count , atol = 1e-5 , rtol = 1e-5
246+ )
247+ print_and_store_cmp (
248+ "diff_count_atol3_rtol2" , get_cmp_diff_count , atol = 1e-3 , rtol = 1e-2
249+ )
250+ print_and_store_cmp (
251+ "diff_count_atol2_rtol1" , get_cmp_diff_count , atol = 1e-2 , rtol = 1e-1
252+ )
253+
254+ eager_time_ms = eager_stats ["mean" ]
255+ compiled_time_ms = compiled_stats ["mean" ]
256+
257+ result_data ["performance" ]["eager" ]["mean" ] = eager_stats ["mean" ]
258+ result_data ["performance" ]["eager" ]["std" ] = eager_stats ["std" ]
259+ result_data ["performance" ]["eager" ]["min" ] = eager_stats ["min" ]
260+ result_data ["performance" ]["eager" ]["max" ] = eager_stats ["max" ]
261+ result_data ["performance" ]["compiled" ]["mean" ] = compiled_stats ["mean" ]
262+ result_data ["performance" ]["compiled" ]["std" ] = compiled_stats ["std" ]
263+ result_data ["performance" ]["compiled" ]["min" ] = compiled_stats ["min" ]
264+ result_data ["performance" ]["compiled" ]["max" ] = compiled_stats ["max" ]
265+ if eager_time_ms > 0 and compiled_time_ms > 0 :
266+ result_data ["performance" ]["speedup" ] = eager_time_ms / compiled_time_ms
221267
222268 print (
223269 f"{ args .log_prompt } duration model_path:{ args .model_path } eager:{ eager_time_ms :.4f} compiled:{ compiled_time_ms :.4f} " ,
224270 file = sys .stderr ,
225271 )
226272
273+ if args .output_dir :
274+ os .makedirs (args .output_dir , exist_ok = True )
275+ model_name = result_data ["configuration" ]["model" ]
276+ compiler_name = args .compiler
277+ file_path = os .path .join (args .output_dir , f"{ model_name } _{ compiler_name } .json" )
278+ with open (file_path , "w" ) as f :
279+ json .dump (result_data , f , indent = 4 )
280+ print (f"Result saved to { file_path } " , file = sys .stderr )
281+
227282
228283def get_cmp_equal (expected_out , compiled_out ):
229284 return " " .join (
@@ -261,18 +316,27 @@ def get_cmp_diff_count(expected_out, compiled_out, atol, rtol):
261316
262317def test_multi_models (args ):
263318 for model_path in get_recursively_model_path (args .model_path ):
264- cmd = "" .join (
265- [
266- sys .executable ,
267- " -m graph_net.torch.test_compiler" ,
268- f" --model-path { model_path } " ,
269- f" --compiler { args .compiler } " ,
270- f" --warmup { args .warmup } " ,
271- f" --trials { args .trials } " ,
272- f" --log-prompt { args .log_prompt } " ,
273- f" --device { args .device } " ,
274- ]
275- )
319+ cmd_list = [
320+ sys .executable ,
321+ "-m" ,
322+ "graph_net.torch.test_compiler" ,
323+ "--model-path" ,
324+ model_path ,
325+ "--compiler" ,
326+ args .compiler ,
327+ "--warmup" ,
328+ str (args .warmup ),
329+ "--trials" ,
330+ str (args .trials ),
331+ "--log-prompt" ,
332+ args .log_prompt ,
333+ "--device" ,
334+ args .device ,
335+ ]
336+ if args .output_dir :
337+ cmd_list .extend (["--output-dir" , args .output_dir ])
338+
339+ cmd = " " .join (cmd_list )
276340 cmd_ret = os .system (cmd )
277341 assert cmd_ret == 0 , f"{ cmd_ret = } , { cmd = } "
278342
@@ -318,7 +382,7 @@ def main(args):
318382 "--compiler" ,
319383 type = str ,
320384 required = False ,
321- default = "default " ,
385+ default = "inductor " ,
322386 help = "Path to customized compiler python file" ,
323387 )
324388 parser .add_argument (
@@ -341,5 +405,12 @@ def main(args):
341405 default = "graph-net-test-compiler-log" ,
342406 help = "Log prompt for performance log filtering." ,
343407 )
408+ parser .add_argument (
409+ "--output-dir" ,
410+ type = str ,
411+ required = False ,
412+ default = None ,
413+ help = "Directory to save the structured JSON result file." ,
414+ )
344415 args = parser .parse_args ()
345416 main (args = args )
0 commit comments