@@ -223,8 +223,7 @@ def test_single_model(args):
223223 else :
224224 result_data ["configuration" ]["compiler_version" ] = "unknown"
225225
226- execution_failure = False
227- correctness_failure = False
226+ failure = False
228227
229228 try :
230229 eager_model_call = lambda : model (** input_dict )
@@ -267,31 +266,19 @@ def test_single_model(args):
267266 file = sys .stderr ,
268267 )
269268 if not type_match :
270- correctness_failure = True
269+ failure = True
271270 else :
272- # correctness check according to datatype
273- correctness_failure = compare_correctness (
274- expected_out , compiled_out , result_data , args
275- )
271+ compare_correctness (expected_out , compiled_out , result_data , args )
276272 except (TypeError , RuntimeError ) as e :
277273 print (f"Model execution failed: { str (e )} " , file = sys .stderr )
278- execution_failure = True
274+ failure = True
279275
280- penalty = 5
281276 e2e_speedup = 0
282277 gpu_speedup = 0
283- if execution_failure :
284- e2e_speedup = 1 / (2 ** penalty )
285- result_data ["performance" ]["speedup" ]["e2e" ] = e2e_speedup
278+ if failure :
279+ result_data ["performance" ]["fail" ] = "True"
286280 print (
287- f"{ args .log_prompt } [Execution Fail][Panelty Speedup] e2e_speedup:{ e2e_speedup :.4f} " ,
288- file = sys .stderr ,
289- )
290- elif correctness_failure :
291- e2e_speedup = 1 / (2 ** penalty )
292- result_data ["performance" ]["speedup" ]["e2e" ] = e2e_speedup
293- print (
294- f"{ args .log_prompt } [Correctness Fail][Panelty Speedup] e2e_speedup:{ e2e_speedup :.4f} " ,
281+ f"{ args .log_prompt } [Fail due to compile error or datatype do not match." ,
295282 file = sys .stderr ,
296283 )
297284 else :
@@ -307,8 +294,7 @@ def test_single_model(args):
307294 f"eager_e2e:{ eager_e2e_time_ms :.4f} compiled_e2e:{ compiled_e2e_time_ms :.4f} "
308295 )
309296 speedup_log = (
310- f"{ args .log_prompt } [Success][Speedup] "
311- f"e2e_speedup:{ e2e_speedup :.4f} "
297+ f"{ args .log_prompt } [Speedup] " f"e2e_speedup:{ e2e_speedup :.4f} "
312298 )
313299
314300 if "cuda" in args .device :
@@ -378,34 +364,6 @@ def compare_correctness(expected_out, compiled_out, result_data, args):
378364 eager_types = result_data ["performance" ]["datatype" ]["eager" ]
379365 compiled_types = result_data ["performance" ]["datatype" ]["compiled" ]
380366
381- def _pick_key (dtype ):
382- if dtype in ("torch.float64" , "torch.double" ):
383- return "[all_close_atol8_rtol5]"
384- if dtype in ("torch.float32" , "torch.float" ):
385- return "[all_close_atol8_rtol5]"
386- if dtype in ("torch.float16" , "torch.bfloat16" ):
387- return "[all_close_atol3_rtol2]"
388- # float8
389- if dtype in ("torch.float8_e5m2" , "torch.float8_e4m3fn" ):
390- return "[all_close_atol2_rtol1]"
391- # int / bool
392- if "int" in dtype or dtype == "torch.bool" :
393- return "[equal]"
394- # complex
395- if dtype in ("torch.complex64" , "torch.complex128" ):
396- return "[all_close_atol8_rtol5]"
397- # default
398- return "[all_close_atol8_rtol5]"
399-
400- for idx in range (len (compiled_out )):
401- dtype = compiled_types [idx ]
402- cmp_str = result_data ["correctness" ].get (_pick_key (dtype ), "" )
403- tokens = cmp_str .split ()
404- if idx >= len (tokens ) or tokens [idx ] != "1" :
405- return True
406-
407- return False
408-
409367
410368def get_cmp_equal (expected_out , compiled_out ):
411369 return " " .join (
0 commit comments