Skip to content

Commit 9348d76

Browse files
committed
Update
1 parent 3245cf9 commit 9348d76

File tree

2 files changed

+12
-8
lines changed

2 files changed

+12
-8
lines changed

graph_net/benchmark_demo.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ echo "[$(date)] Script started in background (PID: $$)" | tee -a "$global_log"
3434
} >> "$global_log" 2>&1
3535

3636

37-
# nohup bash /work/GraphNet/graph_net/benchmark_demo.sh > /dev/ 2>&1 &
37+
# nohup bash /work/GraphNet/graph_net/benchmark_demo.sh &
3838

39-
# python3 -m graph_net.analysis --test-compiler-log-file /work/GraphNet/torchvision_cuda.log
39+
# python3 -m graph_net.analysis --test-compiler-log-file /work/GraphNet/global.log
4040

4141
# python3 -m graph_net.analysis --benchmark-path "${benchmark_dir}"

graph_net/torch/test_compiler.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -348,23 +348,27 @@ def get_cmp_all_close(expected_out, compiled_out, atol, rtol):
348348

349349
def get_cmp_max_diff(expected_out, compiled_out):
350350
return " ".join(
351-
str(torch.max(torch.abs(a - b)).item())
351+
str(torch.max(torch.abs(a.float() - b.float())).item())
352352
for a, b in zip(expected_out, compiled_out)
353353
)
354354

355355

356356
def get_cmp_mean_diff(expected_out, compiled_out):
357357
return " ".join(
358-
str(torch.mean(torch.abs(a - b)).item())
358+
str(torch.mean(torch.abs(a.float() - b.float())).item())
359359
for a, b in zip(expected_out, compiled_out)
360360
)
361361

362362

363363
def get_cmp_diff_count(expected_out, compiled_out, atol, rtol):
364-
return " ".join(
365-
str(torch.sum(~torch.isclose(a, b, atol=atol, rtol=rtol)).item())
366-
for a, b in zip(expected_out, compiled_out)
367-
)
364+
results = []
365+
for a, b in zip(expected_out, compiled_out):
366+
if a.is_floating_point() and b.is_floating_point():
367+
diff_count = torch.sum(~torch.isclose(a, b, atol=atol, rtol=rtol)).item()
368+
else:
369+
diff_count = torch.sum(a != b).item()
370+
results.append(str(diff_count))
371+
return " ".join(results)
368372

369373

370374
def test_multi_models(args):

0 commit comments

Comments
 (0)