Skip to content

Commit e3051fa

Browse files
add benchmarking script
1 parent 0ac1452 commit e3051fa

File tree

1 file changed

+70
-0
lines changed

1 file changed

+70
-0
lines changed
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
"""
2+
Basic benchmark for text generation.
3+
4+
Usage: python benchmarking/int8/int8_benchmark.py
5+
"""
6+
7+
import time
8+
9+
import torch
10+
from torch.profiler import ProfilerActivity, profile
11+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
12+
13+
# Number of new tokens requested from each `generate` call in the warmup
# and timed loops.
MAX_NEW_TOKENS = 128
model_name = "meta-llama/Llama-3.1-8B"

# The same prompt is repeated 8 times to form the benchmark batch; identical
# rows tokenize to the same length, so they stack into one tensor.
text = "Below is a question. I need an answer.\n\nExplain machine learning: "
tokenizer = AutoTokenizer.from_pretrained(model_name)
# `.to(0)` moves the prompt ids to CUDA device index 0.
input_ids = tokenizer([text] * 8, return_tensors="pt").input_ids.to(0)
21+
22+
# 8-bit quantization settings handed to transformers; 6.0 is passed as the
# LLM.int8 threshold.
int8_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_threshold=6.0)

# Load the causal LM with automatic device placement, SDPA attention and
# float16 as the requested torch dtype.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=int8_config,
    attn_implementation="sdpa",
    torch_dtype=torch.float16,
)

# Print the loaded model's module structure.
print(model)
34+
35+
# Warmup: run a few untimed full-length generations before profiling and
# timing.
print("Warmup...")
warmup_rounds = 3
for _ in range(warmup_rounds):
    generated_ids = model.generate(input_ids, max_new_tokens=MAX_NEW_TOKENS)
39+
40+
print("Profiler starting...")
profiled_activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
with profile(
    activities=profiled_activities,
    with_modules=True,
    with_stack=True,
) as prof:
    # Only a single new token is generated while the profiler is recording.
    model.generate(input_ids, max_new_tokens=1)

# Summarise the recorded events: top-level events only, sorted by total CPU
# time, capped at 50 rows.
summary_table = prof.key_averages().table(
    sort_by="cpu_time_total",
    max_name_column_width=50,
    top_level_events_only=True,
    row_limit=50,
)
print(summary_table)
56+
57+
# Wait for any GPU work still queued from the profiling run so it is not
# charged to the timed section.
torch.cuda.synchronize()


print("Generating...")
num_runs = 5
# `generate` returns prompt + continuation for each row, so the prompt length
# must be subtracted to count only newly generated tokens.
prompt_len = input_ids.shape[1]
num = 0
# perf_counter is monotonic and higher resolution than time.time().
time_1 = time.perf_counter()
for _ in range(num_runs):
    generated_ids = model.generate(input_ids, max_new_tokens=MAX_NEW_TOKENS)
    batch_size, total_len = generated_ids.shape
    # Count new tokens for every sequence in the batch, not just row 0.
    # NOTE(review): rows that stop early are padded to `total_len`, so this is
    # an upper bound if sequences finish at different lengths.
    num += batch_size * (total_len - prompt_len)
# Make sure all GPU work has finished before stopping the clock, and stop it
# before decoding/printing so that work is not included in the measurement.
torch.cuda.synchronize()
elapsed = time.perf_counter() - time_1

print("=" * 40)
print(f"Example:\n{tokenizer.decode(generated_ids[0])}")
print("=" * 40)
print(f"Speed: {num/elapsed}token/s")

0 commit comments

Comments
 (0)