@@ -43,6 +43,24 @@ def set_seed(random_seed):
4343 torch .cuda .manual_seed_all (random_seed )
4444
4545
46+ def get_hardward_name (args ):
47+ hardware_name = "unknown"
48+ if "cuda" in args .device :
49+ hardware_name = torch .cuda .get_device_name (args .device )
50+ elif args .device == "cpu" :
51+ hardware_name = platform .processor ()
52+ return hardware_name
53+
54+
55+ def get_compile_framework_version (args ):
56+ if args .compiler in ["inductor" , "nope" ]:
57+ return torch .__version__
58+ elif args .compiler in ["tvm" , "xla" , "tensorrt" , "bladedisc" ]:
59+ # Assuming compiler object has a version attribute
60+ return f"{ args .compiler .capitalize ()} { get_compiler_backend (args ).version } "
61+ return "unknown"
62+
63+
4664def load_class_from_file (
4765 args : argparse .Namespace , class_name : str , device : str
4866) -> Type [torch .nn .Module ]:
@@ -87,31 +105,6 @@ def get_input_dict(args):
87105 }
88106
89107
90- @dataclass
91- class DurationBox :
92- value : float
93-
94-
95- @contextmanager
96- def naive_timer (duration_box , synchronizer_func ):
97- synchronizer_func ()
98- start = time .time ()
99- yield
100- synchronizer_func ()
101- end = time .time ()
102- duration_box .value = (end - start ) * 1000 # Store in milliseconds
103-
104-
105- def get_timing_stats (elapsed_times : List [float ]):
106- stats = {
107- "mean" : float (f"{ np .mean (elapsed_times ):.3g} " ),
108- "std" : float (f"{ np .std (elapsed_times ):.3g} " ),
109- "min" : float (f"{ np .min (elapsed_times ):.3g} " ),
110- "max" : float (f"{ np .max (elapsed_times ):.3g} " ),
111- }
112- return stats
113-
114-
115108def measure_performance (model_call , args , compiler ):
116109 stats = {}
117110
@@ -120,27 +113,26 @@ def measure_performance(model_call, args, compiler):
120113 model_call ()
121114 compiler .synchronize ()
122115
116+ hardware_name = get_hardward_name (args )
117+ print (
118+ f"[Profiling] Using device: { args .device } { hardware_name } , warm up { args .warmup } , trials { args .trials } " ,
119+ file = sys .stderr ,
120+ flush = True ,
121+ )
122+
123123 if "cuda" in args .device :
124124 """
125125 Acknowledgement: We evaluate the performance on both end-to-end and GPU-only timings,
126126 With reference to methods only based on CUDA events from KernelBench in https://github.com/ScalingIntelligence/KernelBench
127127 """
128128
129- device = torch .device (args .device )
130- hardware_name = torch .cuda .get_device_name (device )
131- print (
132- f"{ args .log_prompt } [Profiling] Using device: { args .device } { hardware_name } , warm up { args .warmup } , trials { args .trials } " ,
133- file = sys .stderr ,
134- flush = True ,
135- )
136-
137129 e2e_times = []
138130 gpu_times = []
139131
140132 for i in range (args .trials ):
141133 # End-to-end timing (naive_timer)
142- duration_box = DurationBox (- 1 )
143- with naive_timer (duration_box , compiler .synchronize ):
134+ duration_box = test_compiler_util . DurationBox (- 1 )
135+ with test_compiler_util . naive_timer (duration_box , compiler .synchronize ):
144136 # GPU-only timing (CUDA Events)
145137 start_event = torch .cuda .Event (enable_timing = True )
146138 end_event = torch .cuda .Event (enable_timing = True )
@@ -149,7 +141,7 @@ def measure_performance(model_call, args, compiler):
149141 model_call ()
150142
151143 end_event .record ()
152- torch . cuda . synchronize (device = device )
144+ compiler . synchronize ()
153145
154146 gpu_time_ms = start_event .elapsed_time (end_event )
155147 e2e_times .append (duration_box .value )
@@ -160,29 +152,22 @@ def measure_performance(model_call, args, compiler):
160152 flush = True ,
161153 )
162154
163- stats ["e2e" ] = get_timing_stats (e2e_times )
164- stats ["gpu" ] = get_timing_stats (gpu_times )
155+ stats ["e2e" ] = test_compiler_util . get_timing_stats (e2e_times )
156+ stats ["gpu" ] = test_compiler_util . get_timing_stats (gpu_times )
165157
166158 else : # CPU or other devices
167- hardware_name = platform .processor ()
168- print (
169- f"[Profiling] Using device: { args .device } { hardware_name } , warm up { args .warmup } , trials { args .trials } " ,
170- file = sys .stderr ,
171- flush = True ,
172- )
173-
174159 e2e_times = []
175160 for i in range (args .trials ):
176- duration_box = DurationBox (- 1 )
177- with naive_timer (duration_box , compiler .synchronize ):
161+ duration_box = test_compiler_util . DurationBox (- 1 )
162+ with test_compiler_util . naive_timer (duration_box , compiler .synchronize ):
178163 model_call ()
179164 print (
180165 f"Trial { i + 1 } : e2e={ duration_box .value :.5f} ms" ,
181166 file = sys .stderr ,
182167 flush = True ,
183168 )
184169 e2e_times .append (duration_box .value )
185- stats ["e2e" ] = get_timing_stats (e2e_times )
185- stats ["e2e" ] = get_timing_stats (e2e_times )
170+ stats ["e2e" ] = test_compiler_util .get_timing_stats (e2e_times )
186171
187172 return stats
188173
@@ -191,49 +176,9 @@ def test_single_model(args):
191176 compiler = get_compiler_backend (args )
192177 input_dict = get_input_dict (args )
193178 model = get_model (args , args .device )
194- model_path = os .path .normpath (args .model_path )
195- print (f"{ args .log_prompt } [Processing] { model_path } " , file = sys .stderr , flush = True )
196- model_name = os .path .basename (model_path )
197- print (
198- f"{ args .log_prompt } [Config] model: { model_name } " , file = sys .stderr , flush = True
199- )
200- print (
201- f"{ args .log_prompt } [Config] device: { args .device } " , file = sys .stderr , flush = True
202- )
203-
204- hardware_name = "unknown"
205- if "cuda" in args .device :
206- hardware_name = torch .cuda .get_device_name (args .device )
207- elif args .device == "cpu" :
208- hardware_name = platform .processor ()
209- print (
210- f"{ args .log_prompt } [Config] hardware: { hardware_name } " ,
211- file = sys .stderr ,
212- flush = True ,
213- )
214-
215- print (
216- f"{ args .log_prompt } [Config] compiler: { args .compiler } " ,
217- file = sys .stderr ,
218- flush = True ,
219- )
220- print (
221- f"{ args .log_prompt } [Config] warmup: { args .warmup } " , file = sys .stderr , flush = True
222- )
223- print (
224- f"{ args .log_prompt } [Config] trials: { args .trials } " , file = sys .stderr , flush = True
225- )
226179
227- version_str = "unknown"
228- if args .compiler == "inductor" :
229- version_str = torch .__version__
230- elif args .compiler in ["tvm" , "xla" , "tensorrt" , "bladedisc" ]:
231- # Assuming compiler object has a version attribute
232- version_str = f"{ args .compiler .capitalize ()} { compiler .version } "
233- print (
234- f"{ args .log_prompt } [Config] compile_framework_version: { version_str } " ,
235- file = sys .stderr ,
236- flush = True ,
180+ test_compiler_util .print_basic_config (
181+ args , get_hardward_name (args ), get_compile_framework_version (args )
237182 )
238183
239184 runtime_seed = 1024
@@ -245,28 +190,11 @@ def test_single_model(args):
245190 try :
246191 eager_model_call = lambda : model (** input_dict )
247192 eager_stats = measure_performance (eager_model_call , args , compiler )
248- print (
249- f"{ args .log_prompt } [Performance][eager]: { json .dumps (eager_stats )} " ,
250- file = sys .stderr ,
251- flush = True ,
252- )
253193
254194 torch .manual_seed (runtime_seed )
255195 expected_out = eager_model_call ()
256196 if not isinstance (expected_out , tuple ):
257197 expected_out = (expected_out ,)
258-
259- eager_types = [
260- str (x .dtype ).replace ("torch." , "" )
261- if isinstance (x , torch .Tensor )
262- else type (x ).__name__
263- for x in expected_out
264- ]
265- print (
266- f"{ args .log_prompt } [Datatype][eager]: { ' ' .join (eager_types )} " ,
267- file = sys .stderr ,
268- flush = True ,
269- )
270198 except (TypeError , RuntimeError ) as e :
271199 print (f"Eager model execution failed: { str (e )} " , file = sys .stderr )
272200 eager_failure = True
@@ -286,43 +214,12 @@ def test_single_model(args):
286214 torch .manual_seed (runtime_seed )
287215 compiled_model_call = lambda : compiled_model (** input_dict )
288216 compiled_stats = measure_performance (compiled_model_call , args , compiler )
289- print (
290- f"{ args .log_prompt } [Performance][compiled]: { json .dumps (compiled_stats )} " ,
291- file = sys .stderr ,
292- flush = True ,
293- )
294217
295218 compiled_out = compiled_model_call ()
296219 if not isinstance (compiled_out , tuple ):
297220 compiled_out = (compiled_out ,)
298221 if args .compiler == "xla" :
299222 compiled_out = tuple (item .to ("cpu" ).to ("cuda" ) for item in compiled_out )
300-
301- compiled_types = [
302- str (x .dtype ).replace ("torch." , "" )
303- if isinstance (x , torch .Tensor )
304- else type (x ).__name__
305- for x in compiled_out
306- ]
307- print (
308- f"{ args .log_prompt } [Datatype][compiled]: { ' ' .join (compiled_types )} " ,
309- file = sys .stderr ,
310- flush = True ,
311- )
312-
313- # datatype check
314- type_match = all (
315- eager == compiled for eager , compiled in zip (eager_types , compiled_types )
316- )
317- print (
318- f"{ args .log_prompt } [DataType] eager:{ eager_types } compiled:{ compiled_types } match:{ type_match } " ,
319- file = sys .stderr ,
320- )
321- # "datatype not match" is recognized as a large loss in analysis process later,
322- # and is not recognized as a failure here.
323-
324- compare_correctness (expected_out , compiled_out , args )
325-
326223 except (TypeError , RuntimeError ) as e :
327224 print (f"Compiled model execution failed: { str (e )} " , file = sys .stderr )
328225 compiled_failure = True
@@ -342,39 +239,13 @@ def test_single_model(args):
342239 flush = True ,
343240 )
344241 else :
242+ compare_correctness (expected_out , compiled_out , args )
243+
345244 print (
346245 f"{ args .log_prompt } [Result] status: success" , file = sys .stderr , flush = True
347246 )
348247
349- e2e_speedup = 0
350- gpu_speedup = 0
351-
352- eager_e2e_time_ms = eager_stats .get ("e2e" , {}).get ("mean" , 0 )
353- compiled_e2e_time_ms = compiled_stats .get ("e2e" , {}).get ("mean" , 0 )
354-
355- if eager_e2e_time_ms > 0 and compiled_e2e_time_ms > 0 :
356- e2e_speedup = eager_e2e_time_ms / compiled_e2e_time_ms
357-
358- if "cuda" in args .device :
359- eager_gpu_time_ms = eager_stats .get ("gpu" , {}).get ("mean" , 0 )
360- compiled_gpu_time_ms = compiled_stats .get ("gpu" , {}).get ("mean" , 0 )
361-
362- if eager_gpu_time_ms > 0 and compiled_gpu_time_ms > 0 :
363- gpu_speedup = eager_gpu_time_ms / compiled_gpu_time_ms
364-
365- if e2e_speedup > 0 :
366- print (
367- f"{ args .log_prompt } [Speedup][e2e]: { e2e_speedup :.4f} " ,
368- file = sys .stderr ,
369- flush = True ,
370- )
371-
372- if "cuda" in args .device and gpu_speedup > 0 :
373- print (
374- f"{ args .log_prompt } [Speedup][gpu]: { gpu_speedup :.4f} " ,
375- file = sys .stderr ,
376- flush = True ,
377- )
248+ test_compiler_util .print_times_and_speedup (args , eager_stats , compiled_stats )
378249
379250
380251def print_and_store_cmp (key , cmp_func , args , expected_out , compiled_out , ** kwargs ):
@@ -388,22 +259,41 @@ def print_and_store_cmp(key, cmp_func, args, expected_out, compiled_out, **kwarg
388259
389260
390261def compare_correctness (expected_out , compiled_out , args ):
391- test_compiler_util .check_equal (
392- args ,
393- expected_out ,
394- compiled_out ,
395- cmp_equal_func = get_cmp_equal ,
396- )
262+ eager_dtypes = [
263+ str (x .dtype ).replace ("torch." , "" )
264+ if isinstance (x , torch .Tensor )
265+ else type (x ).__name__
266+ for x in expected_out
267+ ]
268+ compiled_dtypes = [
269+ str (x .dtype ).replace ("torch." , "" )
270+ if isinstance (x , torch .Tensor )
271+ else type (x ).__name__
272+ for x in compiled_out
273+ ]
397274
398- test_compiler_util .check_allclose (
399- args ,
400- expected_out ,
401- compiled_out ,
402- cmp_all_close_func = get_cmp_all_close ,
403- cmp_max_diff_func = get_cmp_max_diff ,
404- cmp_mean_diff_func = get_cmp_mean_diff ,
275+ # datatype check
276+ type_match = test_compiler_util .check_output_datatype (
277+ args , eager_dtypes , compiled_dtypes
405278 )
406279
280+ if type_match :
281+ test_compiler_util .check_equal (
282+ args ,
283+ expected_out ,
284+ compiled_out ,
285+ cmp_equal_func = get_cmp_equal ,
286+ )
287+
288+ test_compiler_util .check_allclose (
289+ args ,
290+ expected_out ,
291+ compiled_out ,
292+ cmp_all_close_func = get_cmp_all_close ,
293+ cmp_max_diff_func = get_cmp_max_diff ,
294+ cmp_mean_diff_func = get_cmp_mean_diff ,
295+ )
296+
407297
408298def get_cmp_equal (expected_out , compiled_out ):
409299 return " " .join (
0 commit comments