Skip to content

Commit 56fd4e3

Browse files
committed
Fix hf_llm example run on CPU
Pass device to `benchmark_n` function in order to be able to run on CPU and GPU. So far, `torch.cuda.*` calls are used in `benchmark_n` unconditionally
1 parent 6720e82 commit 56fd4e3

File tree

3 files changed

+25
-14
lines changed

3 files changed

+25
-14
lines changed

examples/quickstart/hf_llm.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,15 +32,15 @@ def generate(model, inp, cache=None):
3232
print(tokenizer.decode(out[0].tolist()))
3333

3434
print("\nGenerating with PyTorch eager:")
35-
eager_time = benchmark_n(2, generate, model, inp)
35+
eager_time = benchmark_n(2, generate, model, inp, device=device)
3636

3737
thunder_model = thunder.compile(
3838
model,
3939
recipe="hf-transformers",
4040
)
4141

4242
print("\nGenerating with Thunder:")
43-
thunder_time = benchmark_n(2, generate, thunder_model, inp, cache="static")
43+
thunder_time = benchmark_n(2, generate, thunder_model, inp, cache="static", device=device)
4444

4545
print(f"\nEager: {eager_time:.2f}ms")
4646
print(f"Thunder: {thunder_time:.2f}ms")

thunder/benchmarks/benchmark_hf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,13 @@ def setup_config(self):
3535

3636
def run_and_profile(tag: str, fn, model, inp, compiled_models: dict[str, torch.nn.Module], cache=None):
3737
print(f"[{tag}] running PyTorch eager")
38-
eager_time = benchmark_n(10, fn, model, inp)
38+
eager_time = benchmark_n(10, fn, model, inp, device=device)
3939

4040
timings = [f"Eager: {eager_time:.2f}ms"]
4141

4242
for name, compiled_model in compiled_models.items():
4343
print(f"[{tag}] running Thunder ({name})")
44-
thunder_time = benchmark_n(10, fn, compiled_model, inp, cache=cache)
44+
thunder_time = benchmark_n(10, fn, compiled_model, inp, cache=cache, device=device)
4545
timings.append(f"Thunder ({name}): {thunder_time:.2f}ms")
4646

4747
if save_traces:

thunder/dev_utils/benchmark.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,31 @@
11
from functools import partial
2+
import time
23

34
import torch
45

56

6-
def benchmark_n(n, model_or_fn, /, *args, **kwargs):
7+
def benchmark_n(n, model_or_fn, /, *args, device: str = "cuda:0", **kwargs):
78
for _ in range(n):
89
_ = model_or_fn(*args, **kwargs)
9-
start_event = torch.cuda.Event(enable_timing=True)
10-
end_event = torch.cuda.Event(enable_timing=True)
11-
torch.cuda.synchronize()
12-
start_event.record()
13-
for _ in range(n):
14-
_ = model_or_fn(*args, **kwargs)
15-
end_event.record()
16-
torch.cuda.synchronize()
17-
return start_event.elapsed_time(end_event) / n
10+
11+
use_cuda_events = device.startswith("cuda") and torch.cuda.is_available()
12+
13+
if use_cuda_events:
14+
start_event = torch.cuda.Event(enable_timing=True)
15+
end_event = torch.cuda.Event(enable_timing=True)
16+
torch.cuda.synchronize()
17+
start_event.record()
18+
for _ in range(n):
19+
_ = model_or_fn(*args, **kwargs)
20+
end_event.record()
21+
torch.cuda.synchronize()
22+
return start_event.elapsed_time(end_event) / n
23+
else:
24+
start = time.perf_counter()
25+
for _ in range(n):
26+
_ = model_or_fn(*args, **kwargs)
27+
end = time.perf_counter()
28+
return (end - start) * 1000.0 / n
1829

1930

2031
benchmark = partial(benchmark_n, 10)

0 commit comments

Comments
 (0)