Commit 52b78b6

[fix]batch infer test
1 parent 3fc4d72 commit 52b78b6

2 files changed: 97 additions, 52 deletions

test/model/test_settings/model_infer_batchs.py

Lines changed: 30 additions & 4 deletions
@@ -9,13 +9,18 @@ def test_model_inference(world_size, model_dir, model_class, batch_sizes, input_
     workers = []
     for rank_id in range(world_size):
         model_kvargs = {
+            "run_mode": "normal",
             "tp_rank": rank_id,
             "world_size": world_size,
             "weight_dir": model_dir,
-            "max_total_token_num": 0 * (input_len + output_len),
+            "max_total_token_num": None,
+            "mem_faction": 0.8,
             "load_way": "HF",
+            "batch_max_tokens": (input_len + output_len),
             "mode": mode,
             "max_req_num": max(batch_sizes),
+            "graph_max_batch_size": max(batch_sizes),
+            "graph_max_len_in_batch": (input_len + output_len),
             "max_seq_length": (input_len + output_len),
         }

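For reference, a hypothetical single-GPU instantiation of the updated kvargs (input_len = output_len = 128, batch_sizes = [1, 2]). The keys mirror the diff above; the reading that max_total_token_num=None lets the runtime size the KV cache from the "mem_faction" share of GPU memory is an assumption, not something the diff itself states, and the weight path is a placeholder.

# Hypothetical model_kvargs for rank 0 of a 1-GPU run (sketch only)
input_len, output_len = 128, 128
batch_sizes = [1, 2]

model_kvargs = {
    "run_mode": "normal",
    "tp_rank": 0,
    "world_size": 1,
    "weight_dir": "/path/to/model",                  # placeholder path
    "max_total_token_num": None,                     # assumption: None defers KV-cache sizing to the runtime,
    "mem_faction": 0.8,                              # which then uses this fraction of free GPU memory
    "load_way": "HF",
    "batch_max_tokens": input_len + output_len,
    "mode": [],                                      # [] means plain fp16
    "max_req_num": max(batch_sizes),
    "graph_max_batch_size": max(batch_sizes),        # CUDA-graph capture limits follow the largest batch tested
    "graph_max_len_in_batch": input_len + output_len,
    "max_seq_length": input_len + output_len,
}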
@@ -96,8 +101,17 @@ def tppart_model_infer(model_class, model_kvargs, batch_sizes, input_len, output
         b_seq_len[i] = input_len
 
     total_token_num = input_len * batch_size
+    mem_indexes = model_part.req_manager.mem_manager.alloc(test_data.shape[0])
     logics = model_part.forward(
-        batch_size, total_token_num, input_len, test_data, b_req_idx, b_start_loc, b_seq_len, is_prefill=True
+        batch_size,
+        total_token_num,
+        input_len,
+        test_data,
+        mem_indexes,
+        b_req_idx,
+        b_start_loc,
+        b_seq_len,
+        is_prefill=True,
     )
     prob_out = torch.softmax(logics, dim=-1)
     predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)

@@ -107,11 +121,13 @@ def tppart_model_infer(model_class, model_kvargs, batch_sizes, input_len, output
         b_start_loc = b_start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
         total_token_num += batch_size
         b_seq_len += 1
+        mem_indexes = model_part.req_manager.mem_manager.alloc(predict_ids.shape[0])
         logics = model_part.forward(
             batch_size,
             total_token_num,
             input_len + i + 1,
             torch.from_numpy(predict_ids).cuda().reshape(-1),
+            mem_indexes,
             b_req_idx,
             b_start_loc,
             b_seq_len,

@@ -152,8 +168,17 @@ def tppart_model_infer(model_class, model_kvargs, batch_sizes, input_len, output
         b_seq_len[i] = input_len
 
     total_token_num = batch_size * input_len
+    mem_indexes = model_part.req_manager.mem_manager.alloc(test_data.shape[0])
     logics = model_part.forward(
-        batch_size, total_token_num, input_len, test_data, b_req_idx, b_start_loc, b_seq_len, is_prefill=True
+        batch_size,
+        total_token_num,
+        input_len,
+        test_data,
+        mem_indexes,
+        b_req_idx,
+        b_start_loc,
+        b_seq_len,
+        is_prefill=True,
     )
     prob_out = torch.softmax(logics, dim=-1)
     predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)

@@ -169,12 +194,13 @@ def tppart_model_infer(model_class, model_kvargs, batch_sizes, input_len, output
         b_start_loc = b_start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
         total_token_num += batch_size
         b_seq_len += 1
-
+        mem_indexes = model_part.req_manager.mem_manager.alloc(predict_ids.shape[0])
         logics = model_part.forward(
             batch_size,
             total_token_num,
             input_len + i + 1,
             torch.from_numpy(predict_ids).cuda().reshape(-1),
+            mem_indexes,
             b_req_idx,
             b_start_loc,
             b_seq_len,
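
A minimal sketch of the call pattern after this change, assuming a model_part built from kvargs like the ones above. The key difference is that KV-cache slots are now reserved explicitly through req_manager.mem_manager.alloc() and passed to forward() as mem_indexes, one slot per prompt token at prefill and one per request at each decode step. Names mirror the test code; the combined helper, the is_prefill=False kwarg on the decode call (the hunk above is cut off before it), and keeping predict_ids as a tensor instead of round-tripping through numpy are assumptions of this sketch.

import torch

def prefill_and_decode_once(model_part, batch_size, input_len, test_data,
                            b_req_idx, b_start_loc, b_seq_len):
    # Prefill: test_data holds batch_size * input_len prompt token ids.
    total_token_num = batch_size * input_len
    # New in this commit: reserve one KV-cache slot per prompt token up front.
    mem_indexes = model_part.req_manager.mem_manager.alloc(test_data.shape[0])
    logits = model_part.forward(
        batch_size, total_token_num, input_len, test_data,
        mem_indexes, b_req_idx, b_start_loc, b_seq_len, is_prefill=True,
    )
    predict_ids = torch.argmax(torch.softmax(logits, dim=-1), dim=1, keepdim=True)

    # One decode step: one new token per request, so one new slot per request.
    b_start_loc = b_start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
    total_token_num += batch_size
    b_seq_len += 1
    mem_indexes = model_part.req_manager.mem_manager.alloc(predict_ids.shape[0])
    logits = model_part.forward(
        batch_size, total_token_num, input_len + 1,
        predict_ids.reshape(-1), mem_indexes,
        b_req_idx, b_start_loc, b_seq_len, is_prefill=False,  # assumed decode flag
    )
    return logits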

test/model/test_settings/test_settings.py

Lines changed: 67 additions & 48 deletions
@@ -2,79 +2,78 @@
 import sys
 from model_infer_batchs import test_model_inference
 from process_utils import kill_gpu_processes
+
 sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))))
 from datetime import datetime
 
 
 from lightllm.models.bloom.model import BloomTpPartModel
 from lightllm.models.llama.model import LlamaTpPartModel
-from lightllm.models.llama_wquant.model import LlamaTpPartModelWQuant
 from lightllm.models.starcoder.model import StarcoderTpPartModel
-from lightllm.models.starcoder_wquant.model import StarcoderTpPartModelWQuant
 from lightllm.models.qwen.model import QWenTpPartModel
-from lightllm.models.qwen_wquant.model import QWenTpPartModelWQuant
-from lightllm.models.baichuan7b.model import Baichuan7bTpPartModel
-from lightllm.models.baichuan13b.model import Baichuan13bTpPartModel
-from lightllm.models.baichuan2_7b.model import Baichuan2_7bTpPartModel
 from lightllm.models.chatglm2.model import ChatGlm2TpPartModel
 from lightllm.models.internlm.model import InternlmTpPartModel
-from lightllm.models.yi.model import YiTpPartModel
 
 
 base_dir = "/nvme/"
 
 model_to_class_and_path = {
-    "llama-7b" : (LlamaTpPartModel, os.path.join(base_dir, "llama-7b")),
-    "llama-13b" :(LlamaTpPartModel, os.path.join(base_dir, "")),
-    "internal-20b" : (InternlmTpPartModel, os.path.join(base_dir, "")),
-    "llama-65b" : (LlamaTpPartModel, os.path.join(base_dir, "")),
-    "llama2-70b" : (LlamaTpPartModel, os.path.join(base_dir, "")),
-    "qwen-7b" : (QWenTpPartModelWQuant, os.path.join(base_dir, "")),
-    "qwen-14b" : (QWenTpPartModelWQuant, os.path.join(base_dir, "")),
-    "chatglm2-6b" : (ChatGlm2TpPartModel, os.path.join(base_dir, ""))
+    "llama-7b": (LlamaTpPartModel, os.path.join(base_dir, "llama-7b")),
+    "llama-13b": (LlamaTpPartModel, os.path.join(base_dir, "")),
+    "internal-20b": (InternlmTpPartModel, os.path.join(base_dir, "")),
+    "llama-65b": (LlamaTpPartModel, os.path.join(base_dir, "")),
+    "llama2-70b": (LlamaTpPartModel, os.path.join(base_dir, "")),
+    "chatglm2-6b": (ChatGlm2TpPartModel, os.path.join(base_dir, "")),
+    "llama3-8b": (LlamaTpPartModel, "/data/models/Meta-Llama-3-8B-Instruct"),
 }
 
+
 def test_all_setting(gpu_name, model_name, mode, log_dir, world_sizes, in_out_lens, batch_sizes):
     log_dir = os.path.join(log_dir, gpu_name, str(model_name))
     os.makedirs(log_dir, exist_ok=True)
 
     model_class, model_path = model_to_class_and_path[model_name]
     kill_gpu_processes()
     for world_size in world_sizes:
-        for in_len, out_len in in_out_lens:
+        for in_len, out_len in in_out_lens:
             kill_gpu_processes()
             mode_str = "_".join(mode)
             log_file_name = f"{model_name}##{mode_str}##{world_size}##{in_len}##{out_len}##batch_size##.log"
             log_path = os.path.join(log_dir, log_file_name)
             print(log_path)
-            test_model_inference(world_size,
-                                 model_path,
-                                 model_class,
-                                 batch_sizes,
-                                 in_len,
-                                 out_len,
-                                 mode,
-                                 log_path)
+            test_model_inference(world_size, model_path, model_class, batch_sizes, in_len, out_len, mode, log_path)
     log_md_file = log_dir + ".md"
     md_file = open(log_md_file, "w")
     # write head
-    heads = ['mode', 'world_size', 'batch_size', 'input_len', 'output_len', 'prefill_cost', 'first_step_latency', 'last_step_latency', 'mean_latency',
-             'prefill_throughput', 'decode_throughput', 'total_throughput',
-             'card_num_per_qps']
+    heads = [
+        "mode",
+        "world_size",
+        "batch_size",
+        "input_len",
+        "output_len",
+        "prefill_cost",
+        "first_step_latency",
+        "last_step_latency",
+        "mean_latency",
+        "prefill_throughput",
+        "decode_throughput",
+        "total_throughput",
+        "card_num_per_qps",
+    ]
     md_file.write(f"test model: {model_name} \r\n")
-    md_file.write('|')
+    md_file.write("|")
     for head in heads:
         md_file.write(head + "|")
-    md_file.write('\r\n')
-    md_file.write('|')
+    md_file.write("\r\n")
+    md_file.write("|")
     for _ in range(len(heads)):
-        md_file.write('------|')
-    md_file.write('\r\n')
+        md_file.write("------|")
+    md_file.write("\r\n")
     log_files = list(os.listdir(log_dir))
     sorted(log_files, key=lambda x: tuple(map(int, x.split("##")[2:6])))
     for log_file in log_files:
         _, mode, world_size, input_len, output_len, batch_size, _ = log_file.split("##")
-        fp_file = open(os.path.join(log_dir, log_file), "r")
+        fp_file = open(os.path.join(log_dir, log_file), "r")
         all_lines = fp_file.readlines()
         fp_file.close()
         if len(all_lines) <= 2:

@@ -84,30 +83,50 @@ def test_all_setting(gpu_name, model_name, mode, log_dir, world_sizes, in_out_le
         laststep_cost = float(all_lines[-2].split(":")[1].strip())
         all_step_cost = float(all_lines[-1].split(":")[1].strip())
         mean_step_cost = (all_step_cost - prefill_cost) / float(output_len)
-        card_num_per_qps = float(world_size) / (float(batch_size) / (all_step_cost / 1000))
+        card_num_per_qps = float(world_size) / (float(batch_size) / (all_step_cost / 1000))
         prefill_throughput = float(batch_size) * float(input_len) / (prefill_cost / 1000)
         decode_throughput = float(batch_size) * float(output_len) / ((all_step_cost - prefill_cost) / 1000)
         total_throughput = float(batch_size) * (float(input_len) + float(output_len)) / (all_step_cost / 1000)
-        md_file.write('|')
-        infos = [mode, world_size, batch_size, input_len, output_len, prefill_cost, firststep_cost, laststep_cost, mean_step_cost,
-                 prefill_throughput, decode_throughput, total_throughput,
-                 card_num_per_qps]
+        md_file.write("|")
+        infos = [
+            mode,
+            world_size,
+            batch_size,
+            input_len,
+            output_len,
+            prefill_cost,
+            firststep_cost,
+            laststep_cost,
+            mean_step_cost,
+            prefill_throughput,
+            decode_throughput,
+            total_throughput,
+            card_num_per_qps,
+        ]
         for info in infos:
             md_file.write(str(format(info, ".4f")) if isinstance(info, float) else str(info))
             md_file.write("|")
-        md_file.write('\r\n')
+        md_file.write("\r\n")
     md_file.close()
 
 
 gpu_name = "A800"
-in_out_lens = [(128, 128), (256, 256)]  # Entries in in_out_lens must be ordered from shortest to longest, otherwise problems may occur.
-batch_sizes = [1, 2]  # The values in batch_sizes must also be in ascending order.
+in_out_lens = [(128, 128), (256, 256)]  # Entries in in_out_lens must be ordered from shortest to longest, otherwise problems may occur.
+batch_sizes = [1, 2]  # The values in batch_sizes must also be in ascending order.
+
+
+if __name__ == "__main__":
+    import torch
 
+    torch.multiprocessing.set_start_method("spawn")
 
-test_all_setting(gpu_name,
-                 "llama-7b",
-                 mode=["triton_int8weight", "ppl_fp16_flashdecoding"],  # mode=[] means the plain fp16 format.
-                 log_dir="./",
-                 world_sizes=[1],
-                 in_out_lens=in_out_lens,
-                 batch_sizes=batch_sizes)
+    test_all_setting(
+        gpu_name,
+        "llama3-8b",
+        # mode=["triton_int8weight", "ppl_fp16_flashdecoding"],  # mode=[] means the plain fp16 format.
+        mode=["triton_gqa_flashdecoding"],
+        log_dir="./",
+        world_sizes=[1],
+        in_out_lens=in_out_lens,
+        batch_sizes=batch_sizes,
+    )
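
The report columns are filled from each per-setting log using the formulas in the loop above (costs are read in milliseconds). A small worked example with made-up timings shows the arithmetic for one row:

# Hypothetical numbers for one setting: world_size=1, batch_size=2,
# input_len=output_len=128; costs in milliseconds as parsed from the log.
world_size, batch_size = 1, 2
input_len, output_len = 128, 128
prefill_cost = 40.0       # ms for the prefill step
all_step_cost = 3240.0    # ms for prefill plus all decode steps

mean_step_cost = (all_step_cost - prefill_cost) / output_len                            # 25.0 ms per decode step
prefill_throughput = batch_size * input_len / (prefill_cost / 1000)                     # 6400 tokens/s
decode_throughput = batch_size * output_len / ((all_step_cost - prefill_cost) / 1000)   # 80 tokens/s
total_throughput = batch_size * (input_len + output_len) / (all_step_cost / 1000)       # ~158 tokens/s
card_num_per_qps = world_size / (batch_size / (all_step_cost / 1000))                   # 1.62 GPUs per request/s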
