Skip to content

Commit 99d0dcf

Browse files
committed
Del penalty
1 parent ea0278b commit 99d0dcf

File tree

1 file changed

+8
-50
lines changed

1 file changed

+8
-50
lines changed

graph_net/torch/test_compiler.py

Lines changed: 8 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -223,8 +223,7 @@ def test_single_model(args):
223223
else:
224224
result_data["configuration"]["compiler_version"] = "unknown"
225225

226-
execution_failure = False
227-
correctness_failure = False
226+
failure = False
228227

229228
try:
230229
eager_model_call = lambda: model(**input_dict)
@@ -267,31 +266,19 @@ def test_single_model(args):
267266
file=sys.stderr,
268267
)
269268
if not type_match:
270-
correctness_failure = True
269+
failure = True
271270
else:
272-
# correctness check according to datatype
273-
correctness_failure = compare_correctness(
274-
expected_out, compiled_out, result_data, args
275-
)
271+
compare_correctness(expected_out, compiled_out, result_data, args)
276272
except (TypeError, RuntimeError) as e:
277273
print(f"Model execution failed: {str(e)}", file=sys.stderr)
278-
execution_failure = True
274+
failure = True
279275

280-
penalty = 5
281276
e2e_speedup = 0
282277
gpu_speedup = 0
283-
if execution_failure:
284-
e2e_speedup = 1 / (2**penalty)
285-
result_data["performance"]["speedup"]["e2e"] = e2e_speedup
278+
if failure:
279+
result_data["performance"]["fail"] = "True"
286280
print(
287-
f"{args.log_prompt} [Execution Fail][Panelty Speedup] e2e_speedup:{e2e_speedup:.4f}",
288-
file=sys.stderr,
289-
)
290-
elif correctness_failure:
291-
e2e_speedup = 1 / (2**penalty)
292-
result_data["performance"]["speedup"]["e2e"] = e2e_speedup
293-
print(
294-
f"{args.log_prompt} [Correctness Fail][Panelty Speedup] e2e_speedup:{e2e_speedup:.4f}",
281+
f"{args.log_prompt} [Fail due to compile error or datatype do not match.",
295282
file=sys.stderr,
296283
)
297284
else:
@@ -307,8 +294,7 @@ def test_single_model(args):
307294
f"eager_e2e:{eager_e2e_time_ms:.4f} compiled_e2e:{compiled_e2e_time_ms:.4f}"
308295
)
309296
speedup_log = (
310-
f"{args.log_prompt} [Success][Speedup] "
311-
f"e2e_speedup:{e2e_speedup:.4f}"
297+
f"{args.log_prompt} [Speedup] " f"e2e_speedup:{e2e_speedup:.4f}"
312298
)
313299

314300
if "cuda" in args.device:
@@ -378,34 +364,6 @@ def compare_correctness(expected_out, compiled_out, result_data, args):
378364
eager_types = result_data["performance"]["datatype"]["eager"]
379365
compiled_types = result_data["performance"]["datatype"]["compiled"]
380366

381-
def _pick_key(dtype):
382-
if dtype in ("torch.float64", "torch.double"):
383-
return "[all_close_atol8_rtol5]"
384-
if dtype in ("torch.float32", "torch.float"):
385-
return "[all_close_atol8_rtol5]"
386-
if dtype in ("torch.float16", "torch.bfloat16"):
387-
return "[all_close_atol3_rtol2]"
388-
# float8
389-
if dtype in ("torch.float8_e5m2", "torch.float8_e4m3fn"):
390-
return "[all_close_atol2_rtol1]"
391-
# int / bool
392-
if "int" in dtype or dtype == "torch.bool":
393-
return "[equal]"
394-
# complex
395-
if dtype in ("torch.complex64", "torch.complex128"):
396-
return "[all_close_atol8_rtol5]"
397-
# default
398-
return "[all_close_atol8_rtol5]"
399-
400-
for idx in range(len(compiled_out)):
401-
dtype = compiled_types[idx]
402-
cmp_str = result_data["correctness"].get(_pick_key(dtype), "")
403-
tokens = cmp_str.split()
404-
if idx >= len(tokens) or tokens[idx] != "1":
405-
return True
406-
407-
return False
408-
409367

410368
def get_cmp_equal(expected_out, compiled_out):
411369
return " ".join(

0 commit comments

Comments
 (0)