Commit 2015127

Add example for torch.compile e2e inference

1 parent fa188f6

1 file changed: +32 -0 lines changed

examples/compile_inference.py

Lines changed: 32 additions & 0 deletions
import torch
import torch._dynamo
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# torch._dynamo.config.suppress_errors = True

# Allow TF32-class matmul kernels on CUDA for better throughput.
torch.set_float32_matmul_precision("high")

# Load the model weights in 8-bit via bitsandbytes.
quantization_config = BitsAndBytesConfig(load_in_8bit=True)

# torch._dynamo.config.capture_dynamic_output_shape_ops = True

model_id = "google/gemma-2-2b-it"
# model_id = "Qwen/Qwen2.5-7B"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

input_text = "Write me a poem about Machine Learning."
input_ids = tokenizer(input_text, return_tensors="pt").to(model.device)

# Alternative: compile only the forward pass with full-graph capture.
# model.forward = torch.compile(model.forward, fullgraph=True)

# Compile the whole model; the first generate() call triggers compilation.
model = torch.compile(model)

outputs = model.generate(**input_ids, max_new_tokens=32)
print(tokenizer.decode(outputs[0]))
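
A rough usage note, not part of the example file itself: the first generate() call through the compiled model pays the Dynamo tracing and compilation cost, while later calls with the same shapes should reuse the compiled graphs. A minimal timing sketch (assuming a CUDA device and the same model and inputs as above):

import time

# Hypothetical timing loop: run 0 includes the compilation warm-up,
# runs 1-2 should reuse the already-compiled graphs.
for i in range(3):
    torch.cuda.synchronize()
    start = time.perf_counter()
    outputs = model.generate(**input_ids, max_new_tokens=32)
    torch.cuda.synchronize()
    print(f"run {i}: {time.perf_counter() - start:.2f} s")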
