Skip to content

Commit 0a037b2

Browse files
author
Sanggyu Lee
committed
add decode.py to export LlamaModel decode phase
1 parent ae0b0c3 commit 0a037b2

File tree

1 file changed

+69
-0
lines changed
  • test/modules/model/LlamaDecoderLayerWithKVCacheAndFusedAttention

1 file changed

+69
-0
lines changed
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# User input
2+
prompt = "Lily picked up a flower."
3+
model_name = "Maykeye/TinyLLama-v0"
4+
5+
# Tokenizer
6+
from transformers import AutoTokenizer
7+
8+
tokenizer = AutoTokenizer.from_pretrained(model_name)
9+
tokenizer.pad_token = tokenizer.eos_token
10+
tokenizer.padding_side = "right"
11+
inputs = tokenizer(
12+
prompt,
13+
return_tensors="pt",
14+
padding="max_length",
15+
max_length=30,
16+
truncation=True,
17+
)
18+
19+
# Generator
20+
import torch
21+
22+
from transformers import AutoModelForCausalLM
23+
24+
model = AutoModelForCausalLM.from_pretrained(model_name)
25+
model.eval()
26+
27+
from tico.utils.record_input import RecordingInput
28+
29+
# past_key_values
30+
# ---------------
31+
# During prefill, "past_key_values" not None, but an empty Cache instance.
32+
# Passing None makes torch.export happy.
33+
34+
35+
input_to_remove = [
36+
"attention_mask",
37+
# For left pad, [0, ⋯, 0, 1, ⋯, 1]
38+
# For right right pad, [1, ⋯, 1, 0, ⋯, 0]
39+
# ( 0 is pad-token )
40+
# This script uses right pad and pass all-1 attention mask (including pad).
41+
# Npu computes all positions whether it is pad or not.
42+
]
43+
condition_fn = lambda args_dict: args_dict["past_key_values"].get_seq_length() != 0
44+
45+
with torch.no_grad(), RecordingInput(
46+
model, condition_fn, input_to_remove=input_to_remove
47+
) as rec:
48+
outputs = model.generate(
49+
**inputs,
50+
max_new_tokens=32,
51+
do_sample=False,
52+
pad_token_id=tokenizer.eos_token_id,
53+
)
54+
captured_input = rec.captured_input
55+
56+
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
57+
print(generated_text)
58+
59+
# Tico
60+
import tico
61+
from tico.serialize.operators.onert.op_attention import llama_attention_forward_adapter
62+
from transformers.models.llama.modeling_llama import LlamaAttention
63+
64+
LlamaAttention.forward = llama_attention_forward_adapter
65+
66+
model = AutoModelForCausalLM.from_pretrained(model_name)
67+
model.eval()
68+
circle_model = tico.convert(model, captured_input)
69+
circle_model.save(f"tinyllama.decode.circle")

0 commit comments

Comments
 (0)