diff --git a/examples/quickstart/hf_llm.py b/examples/quickstart/hf_llm.py
index 987a8784d1..1c1058b19c 100644
--- a/examples/quickstart/hf_llm.py
+++ b/examples/quickstart/hf_llm.py
@@ -32,7 +32,7 @@ def generate(model, inp, cache=None):
     print(tokenizer.decode(out[0].tolist()))
 
 print("\nGenerating with PyTorch eager:")
-eager_time = benchmark_n(2, generate, model, inp)
+eager_time = benchmark_n(2, generate, model, inp, device=device)
 
 thunder_model = thunder.compile(
     model,
@@ -40,7 +40,7 @@ def generate(model, inp, cache=None):
 )
 
 print("\nGenerating with Thunder:")
-thunder_time = benchmark_n(2, generate, thunder_model, inp, cache="static")
+thunder_time = benchmark_n(2, generate, thunder_model, inp, cache="static", device=device)
 
 print(f"\nEager: {eager_time:.2f}ms")
 print(f"Thunder: {thunder_time:.2f}ms")
diff --git a/thunder/benchmarks/benchmark_hf.py b/thunder/benchmarks/benchmark_hf.py
index 08f3cd3c34..68bd56dc81 100644
--- a/thunder/benchmarks/benchmark_hf.py
+++ b/thunder/benchmarks/benchmark_hf.py
@@ -35,13 +35,13 @@ def setup_config(self):
 
 def run_and_profile(tag: str, fn, model, inp, compiled_models: dict[str, torch.nn.Module], cache=None):
     print(f"[{tag}] running PyTorch eager")
-    eager_time = benchmark_n(10, fn, model, inp)
+    eager_time = benchmark_n(10, fn, model, inp, device=device)
 
     timings = [f"Eager: {eager_time:.2f}ms"]
 
     for name, compiled_model in compiled_models.items():
         print(f"[{tag}] running Thunder ({name})")
-        thunder_time = benchmark_n(10, fn, compiled_model, inp, cache=cache)
+        thunder_time = benchmark_n(10, fn, compiled_model, inp, cache=cache, device=device)
         timings.append(f"Thunder ({name}): {thunder_time:.2f}ms")
 
     if save_traces:
diff --git a/thunder/dev_utils/benchmark.py b/thunder/dev_utils/benchmark.py
index 8fb3a9ef8c..920362f5fe 100644
--- a/thunder/dev_utils/benchmark.py
+++ b/thunder/dev_utils/benchmark.py
@@ -1,20 +1,31 @@
 from functools import partial
+import time
 
 import torch
 
 
-def benchmark_n(n, model_or_fn, /, *args, **kwargs):
+def benchmark_n(n, model_or_fn, /, *args, device: str = "cuda:0", **kwargs):
     for _ in range(n):
         _ = model_or_fn(*args, **kwargs)
-    start_event = torch.cuda.Event(enable_timing=True)
-    end_event = torch.cuda.Event(enable_timing=True)
-    torch.cuda.synchronize()
-    start_event.record()
-    for _ in range(n):
-        _ = model_or_fn(*args, **kwargs)
-    end_event.record()
-    torch.cuda.synchronize()
-    return start_event.elapsed_time(end_event) / n
+
+    use_cuda_events = device.startswith("cuda") and torch.cuda.is_available()
+
+    if use_cuda_events:
+        start_event = torch.cuda.Event(enable_timing=True)
+        end_event = torch.cuda.Event(enable_timing=True)
+        torch.cuda.synchronize()
+        start_event.record()
+        for _ in range(n):
+            _ = model_or_fn(*args, **kwargs)
+        end_event.record()
+        torch.cuda.synchronize()
+        return start_event.elapsed_time(end_event) / n
+    else:
+        start = time.perf_counter()
+        for _ in range(n):
+            _ = model_or_fn(*args, **kwargs)
+        end = time.perf_counter()
+        return (end - start) * 1000.0 / n
 
 
 benchmark = partial(benchmark_n, 10)