
Commit 6092ffe

Record 2 types of speedup
1 parent 7431a22 commit 6092ffe

File tree

1 file changed: +41 −101 lines changed

graph_net/torch/test_compiler.py

Lines changed: 41 additions & 101 deletions
@@ -85,85 +85,6 @@ def naive_timer(duration_box, synchronizer_func):
     duration_box.value = (end - start) * 1000  # Store in milliseconds
 
 
-def time_execution_with_cuda_event(
-    kernel_fn: Callable,
-    *args,
-    num_warmup: int = 3,
-    num_trials: int = 10,
-    verbose: bool = True,
-    device: torch.device = None,
-) -> List[float]:
-    """
-    Acknowledgement: We introduce evaluation method in https://github.com/ScalingIntelligence/KernelBench to enhance function.
-
-    Time a CUDA kernel function over multiple trials using torch.cuda.Event
-
-    Args:
-        kernel_fn: Function to time
-        *args: Arguments to pass to kernel_fn
-        num_trials: Number of timing trials to run
-        verbose: Whether to print per-trial timing info
-        device: CUDA device to use, if None, use current device
-
-    Returns:
-        List of elapsed times in milliseconds
-    """
-    if device is None:
-        if verbose:
-            print(f"Using current device: {torch.cuda.current_device()}")
-        device = torch.cuda.current_device()
-
-    # Warm ups
-    for _ in range(num_warmup):
-        kernel_fn(*args)
-        torch.cuda.synchronize(device=device)
-
-    print(
-        f"[Profiling] Using device: {device} {torch.cuda.get_device_name(device)}, warm up {num_warmup}, trials {num_trials}"
-    )
-    elapsed_times = []
-
-    # Actual trials
-    for trial in range(num_trials):
-        # create event marker default is not interprocess
-        start_event = torch.cuda.Event(enable_timing=True)
-        end_event = torch.cuda.Event(enable_timing=True)
-
-        start_event.record()
-        kernel_fn(*args)
-        end_event.record()
-
-        # Synchronize to ensure the events have completed
-        torch.cuda.synchronize(device=device)
-
-        # Calculate the elapsed time in milliseconds
-        elapsed_time_ms = start_event.elapsed_time(end_event)
-        if verbose:
-            print(f"Trial {trial + 1}: {elapsed_time_ms:.3g} ms")
-        elapsed_times.append(elapsed_time_ms)
-
-    return elapsed_times
-
-
-def time_execution_naive(
-    model_call, synchronizer_func, num_warmup: int = 3, num_trials: int = 10
-):
-    print(
-        f"[Profiling] Using device: {args.device} {platform.processor()}, warm up {num_warmup}, trials {num_trials}"
-    )
-    for _ in range(num_warmup):
-        model_call()
-
-    times = []
-    for i in range(num_trials):
-        duration_box = DurationBox(-1)
-        with naive_timer(duration_box, synchronizer_func):
-            model_call()
-        print(f"Trial {i + 1}: {duration_box.value:.2f} ms")
-        times.append(duration_box.value)
-    return times
-
-
 def get_timing_stats(elapsed_times: List[float]):
     stats = {
         "mean": float(f"{np.mean(elapsed_times):.3g}"),
@@ -189,7 +110,7 @@ def measure_performance(model_call, args, compiler):
     device = torch.device(args.device)
     hardware_name = torch.cuda.get_device_name(device)
     print(
-        f"[Profiling] Using device: {args.device} {hardware_name}, warm up {args.warmup}, trials {args.trials}"
+        f"{args.log_prompt} [Profiling] Using device: {args.device} {hardware_name}, warm up {args.warmup}, trials {args.trials}"
     )
 
     e2e_times = []
@@ -292,11 +213,6 @@ def test_single_model(args):
     result_data["performance"]["eager"] = eager_stats
     result_data["performance"]["compiled"] = compiled_stats
 
-    eager_time_ms = eager_stats.get("e2e", {}).get("mean", 0)
-    compiled_time_ms = compiled_stats.get("e2e", {}).get("mean", 0)
-    # Using e2e time to calculate speedup
-    result_data["performance"]["speedup"] = eager_time_ms / compiled_time_ms
-
     expected_out = eager_model_call()
     compiled_out = compiled_model_call()
 
@@ -313,44 +229,68 @@ def print_and_store_cmp(key, func, **kwargs):
             file=sys.stderr,
         )
 
-    print_and_store_cmp("equal", get_cmp_equal)
+    print_and_store_cmp("[equal]", get_cmp_equal)
     print_and_store_cmp(
-        "all_close_atol8_rtol8", get_cmp_all_close, atol=1e-8, rtol=1e-8
+        "[all_close_atol8_rtol8]", get_cmp_all_close, atol=1e-8, rtol=1e-8
     )
     print_and_store_cmp(
-        "all_close_atol8_rtol5", get_cmp_all_close, atol=1e-8, rtol=1e-5
+        "[all_close_atol8_rtol5]", get_cmp_all_close, atol=1e-8, rtol=1e-5
     )
     print_and_store_cmp(
-        "all_close_atol5_rtol5", get_cmp_all_close, atol=1e-5, rtol=1e-5
+        "[all_close_atol5_rtol5]", get_cmp_all_close, atol=1e-5, rtol=1e-5
     )
     print_and_store_cmp(
-        "all_close_atol3_rtol2", get_cmp_all_close, atol=1e-3, rtol=1e-2
+        "[all_close_atol3_rtol2]", get_cmp_all_close, atol=1e-3, rtol=1e-2
     )
     print_and_store_cmp(
-        "all_close_atol2_rtol1", get_cmp_all_close, atol=1e-2, rtol=1e-1
+        "[all_close_atol2_rtol1]", get_cmp_all_close, atol=1e-2, rtol=1e-1
     )
-    print_and_store_cmp("max_diff", get_cmp_max_diff)
-    print_and_store_cmp("mean_diff", get_cmp_mean_diff)
+    print_and_store_cmp("[max_diff]", get_cmp_max_diff)
+    print_and_store_cmp("[mean_diff]", get_cmp_mean_diff)
     print_and_store_cmp(
-        "diff_count_atol8_rtol8", get_cmp_diff_count, atol=1e-8, rtol=1e-8
+        "[diff_count_atol8_rtol8]", get_cmp_diff_count, atol=1e-8, rtol=1e-8
     )
     print_and_store_cmp(
-        "diff_count_atol8_rtol5", get_cmp_diff_count, atol=1e-8, rtol=1e-5
+        "[diff_count_atol8_rtol5]", get_cmp_diff_count, atol=1e-8, rtol=1e-5
     )
     print_and_store_cmp(
-        "diff_count_atol5_rtol5", get_cmp_diff_count, atol=1e-5, rtol=1e-5
+        "[diff_count_atol5_rtol5]", get_cmp_diff_count, atol=1e-5, rtol=1e-5
     )
     print_and_store_cmp(
-        "diff_count_atol3_rtol2", get_cmp_diff_count, atol=1e-3, rtol=1e-2
+        "[diff_count_atol3_rtol2]", get_cmp_diff_count, atol=1e-3, rtol=1e-2
     )
     print_and_store_cmp(
-        "diff_count_atol2_rtol1", get_cmp_diff_count, atol=1e-2, rtol=1e-1
+        "[diff_count_atol2_rtol1]", get_cmp_diff_count, atol=1e-2, rtol=1e-1
     )
 
-    print(
-        f"{args.log_prompt} duration model_path:{args.model_path} eager:{eager_time_ms:.4f} compiled:{compiled_time_ms:.4f}",
-        file=sys.stderr,
+    eager_e2e_time_ms = eager_stats.get("e2e", {}).get("mean", 0)
+    compiled_e2e_time_ms = compiled_stats.get("e2e", {}).get("mean", 0)
+
+    e2e_speedup = 0
+    if eager_e2e_time_ms > 0 and compiled_e2e_time_ms > 0:
+        e2e_speedup = eager_e2e_time_ms / compiled_e2e_time_ms
+    result_data["performance"]["speedup"]["e2e"] = e2e_speedup
+
+    duration_log = (
+        f"{args.log_prompt} [Duration] "
+        f"eager_e2e:{eager_e2e_time_ms:.4f} compiled_e2e:{compiled_e2e_time_ms:.4f}"
     )
+    speedup_log = f"{args.log_prompt} [Speedup] " f"e2e_speedup:{e2e_speedup:.4f}"
+
+    if "cuda" in args.device:
+        eager_gpu_time_ms = eager_stats.get("gpu", {}).get("mean", 0)
+        compiled_gpu_time_ms = compiled_stats.get("gpu", {}).get("mean", 0)
+
+        gpu_speedup = 0
+        if eager_gpu_time_ms > 0 and compiled_gpu_time_ms > 0:
+            gpu_speedup = eager_gpu_time_ms / compiled_gpu_time_ms
+        result_data["performance"]["speedup"]["gpu"] = gpu_speedup
+
+        duration_log += f" eager_gpu:{eager_gpu_time_ms:.4f} compiled_gpu:{compiled_gpu_time_ms:.4f}"
+        speedup_log += f" gpu_speedup:{gpu_speedup:.4f}"
+
+    print(duration_log, file=sys.stderr)
+    print(speedup_log, file=sys.stderr)
 
     if args.output_dir:
         os.makedirs(args.output_dir, exist_ok=True)
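The commit replaces the single scalar speedup with two entries: an end-to-end speedup, and a GPU-kernel speedup that is recorded only on CUDA devices. A minimal sketch of that logic, assuming result_data["performance"]["speedup"] is already initialized as a dict elsewhere in the script and that eager_stats/compiled_stats follow the {"e2e": {"mean": ...}, "gpu": {"mean": ...}} shape produced by get_timing_stats; the helper name record_speedups is illustrative and not part of the file:

import sys

def record_speedups(result_data, eager_stats, compiled_stats, device, log_prompt):
    # Hypothetical helper mirroring the logic added in this commit;
    # result_data["performance"]["speedup"] is assumed to already be a dict.
    eager_e2e = eager_stats.get("e2e", {}).get("mean", 0)
    compiled_e2e = compiled_stats.get("e2e", {}).get("mean", 0)
    e2e_speedup = eager_e2e / compiled_e2e if eager_e2e > 0 and compiled_e2e > 0 else 0
    result_data["performance"]["speedup"]["e2e"] = e2e_speedup

    speedup_log = f"{log_prompt} [Speedup] e2e_speedup:{e2e_speedup:.4f}"
    if "cuda" in device:
        # GPU speedup is only meaningful (and only recorded) on CUDA devices.
        eager_gpu = eager_stats.get("gpu", {}).get("mean", 0)
        compiled_gpu = compiled_stats.get("gpu", {}).get("mean", 0)
        gpu_speedup = eager_gpu / compiled_gpu if eager_gpu > 0 and compiled_gpu > 0 else 0
        result_data["performance"]["speedup"]["gpu"] = gpu_speedup
        speedup_log += f" gpu_speedup:{gpu_speedup:.4f}"
    print(speedup_log, file=sys.stderr)

# Illustrative result shape after the call (values are made up):
# result_data["performance"]["speedup"] == {"e2e": 1.23, "gpu": 1.45}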
