Skip to content

Commit bc7ed34

Browse files
committed
Update
1 parent 9348d76 commit bc7ed34

File tree

2 files changed

+30
-13
lines changed

2 files changed

+30
-13
lines changed

graph_net/analysis.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,20 @@ def analysis(args):
7575
# inductor_log = os.path.join(args.test_compiler_log_file)
7676
# inductor_speedup = read_speedups_from_log(inductor_log)
7777
inductor_speedup = read_speedups_from_json(args.benchmark_path)
78+
print(f"Find {len(inductor_speedup)} samples.")
7879
log2_speedups = np.log2(inductor_speedup)
79-
data["log2(speedup)"].extend(log2_speedups)
80-
data["Compiler"].extend(["torch.inductor"] * len(log2_speedups))
80+
81+
mask = log2_speedups <= 2
82+
filtered_log2_speedups = log2_speedups[mask]
83+
filtered_count = len(filtered_log2_speedups)
84+
print(
85+
f"After filtering, {filtered_count} samples remain (removed {len(log2_speedups) - filtered_count} outliers)."
86+
)
87+
88+
data["log2(speedup)"].extend(filtered_log2_speedups)
89+
data["Compiler"].extend(["torch.inductor"] * len(filtered_log2_speedups))
90+
# data["log2(speedup)"].extend(log2_speedups)
91+
# data["Compiler"].extend(["torch.inductor"] * len(log2_speedups))
8192

8293
# C: tvm (Simulate)
8394
# data["log2(speedup)"].extend(

graph_net/torch/test_compiler.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -199,12 +199,12 @@ def get_timing_stats(elapsed_times: list[float]):
199199

200200

201201
def measure_performance(model_call, args, compiler):
202-
if args.device == "cuda":
202+
if "cuda" in args.device:
203203
times = time_execution_with_cuda_event(
204204
model_call,
205205
num_warmup=args.warmup,
206206
num_trials=args.trials,
207-
device=torch.device("cuda:0"),
207+
device=torch.device(args.device),
208208
)
209209
else:
210210
times = time_execution_naive(
@@ -243,8 +243,10 @@ def test_single_model(args):
243243
},
244244
}
245245

246-
if args.device == "cuda":
247-
result_data["configuration"]["hardware"] = torch.cuda.get_device_name(0)
246+
if "cuda" in args.device:
247+
result_data["configuration"]["hardware"] = torch.cuda.get_device_name(
248+
args.device
249+
)
248250
elif args.device == "cpu":
249251
result_data["configuration"]["hardware"] = platform.processor()
250252
else:
@@ -335,38 +337,42 @@ def print_and_store_cmp(key, func, **kwargs):
335337

336338
def get_cmp_equal(expected_out, compiled_out):
337339
return " ".join(
338-
str(int(torch.equal(a, b))) for a, b in zip(expected_out, compiled_out)
340+
str(int(torch.equal(a.cpu(), b.cpu())))
341+
for a, b in zip(expected_out, compiled_out)
339342
)
340343

341344

342345
def get_cmp_all_close(expected_out, compiled_out, atol, rtol):
343346
return " ".join(
344-
str(int(torch.allclose(a, b, atol=atol, rtol=rtol)))
347+
str(int(torch.allclose(a.cpu(), b.cpu(), atol=atol, rtol=rtol)))
345348
for a, b in zip(expected_out, compiled_out)
346349
)
347350

348351

349352
def get_cmp_max_diff(expected_out, compiled_out):
350353
return " ".join(
351-
str(torch.max(torch.abs(a.float() - b.float())).item())
354+
str(torch.max(torch.abs(a.cpu().float() - b.cpu().float())).item())
352355
for a, b in zip(expected_out, compiled_out)
353356
)
354357

355358

356359
def get_cmp_mean_diff(expected_out, compiled_out):
357360
return " ".join(
358-
str(torch.mean(torch.abs(a.float() - b.float())).item())
361+
str(torch.mean(torch.abs(a.cpu().float() - b.cpu().float())).item())
359362
for a, b in zip(expected_out, compiled_out)
360363
)
361364

362365

363366
def get_cmp_diff_count(expected_out, compiled_out, atol, rtol):
364367
results = []
365368
for a, b in zip(expected_out, compiled_out):
366-
if a.is_floating_point() and b.is_floating_point():
367-
diff_count = torch.sum(~torch.isclose(a, b, atol=atol, rtol=rtol)).item()
369+
a_cpu, b_cpu = a.cpu(), b.cpu()
370+
if a_cpu.is_floating_point() and b_cpu.is_floating_point():
371+
diff_count = torch.sum(
372+
~torch.isclose(a_cpu, b_cpu, atol=atol, rtol=rtol)
373+
).item()
368374
else:
369-
diff_count = torch.sum(a != b).item()
375+
diff_count = torch.sum(a_cpu != b_cpu).item()
370376
results.append(str(diff_count))
371377
return " ".join(results)
372378

0 commit comments

Comments (0)