Commit 514eb77

refactor
1 parent 85de4ca commit 514eb77

5 files changed: +177 -0


examples/batch_generation.py

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
import sys
sys.path.append("..")
from models.llama import LLM
import argparse
import torch
from transformers import AutoTokenizer
import jsonlines

parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default="meta-llama/Meta-Llama-3.1-8B-Instruct", help='model')
parser.add_argument('--T', type=int, default=2000, help='repeat times')
parser.add_argument('--B', type=int, default=2, help='batch size')
parser.add_argument('--M', type=int, default=4096, help='max length')
parser.add_argument('--D', type=int, default=1, help='dec length')
parser.add_argument('--G', type=int, default=32, help='generation length')
parser.add_argument('--K', type=int, default=10, help='K')
parser.add_argument('--L', type=int, default=150, help='L')
args = parser.parse_args()
print(args)

MAX_LEN = args.M
DEC_LEN = args.D
GEN_LEN = args.G
BATCH_SIZE = args.B
MODEL_NAME = args.model
DTYPE = torch.bfloat16
DEVICE = "cuda:0"
T = args.T
WARM_UP = 10

# Read the first record from the sample data file.
with open("../data/data4k.jsonl") as f:
    d = jsonlines.Reader(f)
    for idx, item in enumerate(d):
        data = item
        break

llm = LLM(K=args.K, L=args.L, max_length=MAX_LEN, model_name=args.model, batch_size=BATCH_SIZE, device=DEVICE, dtype=DTYPE)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
text = data["input"]
input_ids = tokenizer.encode(text=text, return_tensors="pt").to(device=DEVICE)
PREFIX_LEN = input_ids.shape[1]

position_ids = torch.arange(MAX_LEN, device=DEVICE).unsqueeze(0).repeat(BATCH_SIZE, 1)

# Prefill the same prompt into every slot of the batch.
batch_logits = []
for i in range(BATCH_SIZE):
    logits = llm.prefill(input_ids, i)
    batch_logits.append(logits)

logits = torch.cat(batch_logits, dim=0)
generated_tokens = []
prefix_len = input_ids.shape[1]
# Greedy decoding; stop early if the first sequence emits a Llama-3 special/stop token.
for k in range(GEN_LEN):
    input_ids = logits.argmax(dim=-1)
    logits = llm.inference(input_ids=input_ids, position_ids=position_ids[:, prefix_len + k:prefix_len + k + 1])
    generated_tokens.append(input_ids)
    if input_ids[0].item() in [128000, 128001, 128008, 128009]:
        break

generated_tokens = torch.cat(generated_tokens, dim=1).to(device="cpu")
decoded_texts = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
print(decoded_texts)
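For reference, a hypothetical invocation spelling out the defaults above. It assumes the script is run from examples/, that ../data/data4k.jsonl exists, and that cuda:0 is available:

python batch_generation.py --model meta-llama/Meta-Llama-3.1-8B-Instruct --B 2 --M 4096 --G 32 --K 10 --L 150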

examples/bench.py

Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
import sys
sys.path.append("..")
from models.llama import LLM
import argparse
import torch
from transformers import AutoTokenizer
import jsonlines
import time

parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default="meta-llama/Meta-Llama-3.1-8B-Instruct", help='model')
parser.add_argument('--B', type=int, default=1, help='batch size')
parser.add_argument('--M', type=int, default=98304, help='max length')
parser.add_argument('--D', type=int, default=1, help='dec length')
parser.add_argument('--P', type=int, default=98000, help='prefill length')
parser.add_argument('--G', type=int, default=128, help='generation length')
parser.add_argument('--K', type=int, default=10, help='K')
parser.add_argument('--L', type=int, default=150, help='L')
args = parser.parse_args()
print(args)

MAX_LEN = args.M
DEC_LEN = args.D
GEN_LEN = args.G
B = args.B
MODEL_NAME = args.model
DTYPE = torch.bfloat16
PREFIX_LEN = args.P
DEVICE = "cuda:0"
WARM_UP = 32

# Read the first record from the sample data file.
with open("../data/data.jsonl") as f:
    d = jsonlines.Reader(f)
    for idx, item in enumerate(d):
        data = item
        break

llm = LLM(K=args.K, L=args.L, max_length=MAX_LEN, model_name=args.model, batch_size=B, device=DEVICE, dtype=DTYPE)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
text = data["input"]
input_ids = tokenizer.encode(text=text, return_tensors="pt").to(device=DEVICE)
input_ids = input_ids[:, :PREFIX_LEN].repeat(B, 1)
position_ids = torch.arange(MAX_LEN, device=DEVICE).unsqueeze(0).repeat(B, 1)

# Prefill each batch slot with its own copy of the prompt.
for i in range(B):
    logits = llm.prefill(input_ids=input_ids[i:i+1], request_id=i)

generated = input_ids[0].tolist()
# Warm-up decode steps (tokens are taken from the prompt; outputs are discarded).
for k in range(WARM_UP):
    logits = llm.inference(input_ids=input_ids[:, 128+k:128+k+1], position_ids=position_ids[:, PREFIX_LEN + k:PREFIX_LEN + k + 1])

torch.cuda.synchronize()
t1 = time.time()
# Timed decode steps (again feeding prompt tokens; only per-step latency is measured).
for k in range(GEN_LEN):
    logits = llm.inference(input_ids=input_ids[:, WARM_UP+k:WARM_UP+k+1], position_ids=position_ids[:, WARM_UP + PREFIX_LEN + k:WARM_UP + PREFIX_LEN + k + 1])

torch.cuda.synchronize()
t2 = time.time()

print("Decoding Latency {:.2f} ms/token".format(1000 * (t2 - t1) / GEN_LEN))
print("Decoding Throughput {:.2f} token/s".format(B * GEN_LEN / (t2 - t1)))

examples/bench.sh

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
# Bind the benchmark to CPU cores 0-31,52-83 and NUMA memory nodes 0,1 (adjust for your machine).
numactl -C 0-31,52-83 -m 0,1 python bench.py --B 1 --K 0 --L 150 --model codellama/CodeLlama-7b-Instruct-hf --M 16384 --P 16000
numactl -C 0-31,52-83 -m 0,1 python bench.py --B 4 --K 0 --L 150 --model codellama/CodeLlama-7b-Instruct-hf --M 16384 --P 16000
numactl -C 0-31,52-83 -m 0,1 python bench.py --B 8 --K 0 --L 150 --model codellama/CodeLlama-7b-Instruct-hf --M 16384 --P 16000
# numactl -C 0-31,52-83 -m 0,1 python bench.py --B 1 --K 10 --L 170 --model codellama/CodeLlama-7b-Instruct-hf --M 131072 --P 128000

# numactl -C 0-31,52-83 -m 0,1 python bench.py --B 12 --K 9 --L 120 --model codellama/CodeLlama-7b-Instruct-hf --M 131072 --P 128000

# numactl -C 0-31,52-83 -m 0,1 python bench.py --B 12 --K 8 --L 75 --model codellama/CodeLlama-7b-Instruct-hf --M 131072 --P 128000

examples/generation.py

Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
import sys
sys.path.append("..")
from models.llama import LLM
import argparse
import torch
from transformers import AutoTokenizer
import jsonlines
from models.template import Templates

parser = argparse.ArgumentParser()
parser.add_argument('--model', type=str, default="meta-llama/Meta-Llama-3.1-8B-Instruct", help='model')
parser.add_argument('--M', type=int, default=8192, help='max length')
parser.add_argument('--D', type=int, default=1, help='dec length')
parser.add_argument('--G', type=int, default=256, help='generation length')
parser.add_argument('--K', type=int, default=10, help='K')
parser.add_argument('--L', type=int, default=150, help='L')
parser.add_argument('--data', type=str, default="../data/story.txt", help='source data file')
parser.add_argument('--template', type=str, default="meta-llama3", help='chat template')
args = parser.parse_args()
print(args)

MAX_LEN = args.M
DEC_LEN = args.D
GEN_LEN = args.G
MODEL_NAME = args.model
DTYPE = torch.bfloat16
DEVICE = "cuda:0"
chat_template = Templates[args.template]

llm = LLM(K=args.K, L=args.L, max_length=MAX_LEN, model_name=args.model, batch_size=1, device=DEVICE, dtype=DTYPE, generation_buffer=args.G + 32)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Wrap the source text in the chat template, tokenize it, and echo the prompt.
with open(args.data, "r", encoding="utf-8") as file:
    content = file.read()
content = chat_template.format(content)
input_ids = tokenizer.encode(text=content, return_tensors="pt")
context = tokenizer.decode(input_ids[0], skip_special_tokens=True)
print(context)
input_ids = input_ids.to(DEVICE)
PREFIX_LEN = input_ids.shape[1]
position_ids = torch.arange(MAX_LEN, device=DEVICE).unsqueeze(0)

# Generate and print the completion in green.
generated = llm.generate(input_ids, max_tokens=args.G)
text = tokenizer.decode(generated, skip_special_tokens=True)
print("\033[32m" + text + "\033[0m")
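Likewise, a hypothetical invocation built from the defaults in this file. It assumes ../data/story.txt exists and that "meta-llama3" is a key in models.template.Templates:

python generation.py --model meta-llama/Meta-Llama-3.1-8B-Instruct --M 8192 --G 256 --K 10 --L 150 --data ../data/story.txt --template meta-llama3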

models/__init__.py

Whitespace-only changes.
