Skip to content

Commit 0b95787

Browse files
authored
update test (#871)
Co-authored-by: baishihao <baishihao@sensetime.com>
1 parent e023d03 commit 0b95787

File tree

2 files changed

+43
-16
lines changed

2 files changed

+43
-16
lines changed

test/benchmark_qps.py

Lines changed: 40 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -34,7 +34,7 @@ def get_tokenizer(
3434
def get_random_length(reqs_num: int, length: int, range_ratio: float) -> List[int]:
3535
lens = []
3636
lens = np.random.randint(
37-
int(length * range_ratio),
37+
max(int(length * range_ratio), 1),
3838
length + 1,
3939
size=reqs_num,
4040
)
@@ -64,6 +64,21 @@ def gen_random_data(
6464
return prompts, output_lens
6565

6666

67+
def get_custom_input_data(data_path, output_len, tokenizer, range_ratio):
68+
prompts = []
69+
with open(data_path, "r") as f:
70+
for line in f.readlines():
71+
data_line = json.loads(line)
72+
input_data = tokenizer.apply_chat_template(
73+
data_line["messages"], add_generation_prompt=True, tokenize=False
74+
)
75+
input_len = len(tokenizer.encode(input_data))
76+
prompts.append([input_data, input_len])
77+
output_lens = get_random_length(len(prompts), output_len, range_ratio)
78+
print("Load random data finish.")
79+
return prompts, output_lens
80+
81+
6782
model_name = []
6883

6984

@@ -115,13 +130,13 @@ async def async_post_stream_lightllm(url, prompt, max_new_tokens, session):
115130
"do_sample": False,
116131
"ignore_eos": True,
117132
"max_new_tokens": max_new_tokens,
133+
"add_special_tokens": False,
118134
},
119135
}
120136
headers = {"Content-Type": "application/json"}
121137
used_time = []
122138
start_time = time.time()
123139
last_time = start_time
124-
125140
async with session.post(url, headers=headers, json=data) as response:
126141
if response.status != 200:
127142
return []
@@ -189,12 +204,11 @@ async def response_collector(
189204
result, input_len = await task
190205
request_queue.task_done()
191206
assert result is not None
192-
if len(result) > 1 and not stop_send.is_set():
207+
if len(result) >= 1 and not stop_send.is_set():
193208
results.append((result, input_len))
194209
current_count = counter[0] + 1
195210
counter[0] = current_count
196211
print(f"\rfinished_reqs:{current_count} / target_reqs:{reqs_num} / sent_reqs:{sent_count[0]}", end="")
197-
198212
if len(results) >= reqs_num and not stop_send.is_set():
199213
end_time[0] = time.time()
200214
print("\nReached target number of responses")
@@ -292,6 +306,7 @@ def main():
292306
)
293307
parser.add_argument("--num_clients", type=int, default=100)
294308
parser.add_argument("--tokenizer_path", type=str, default=None)
309+
parser.add_argument("--data_path", type=str, default=None)
295310
parser.add_argument("--input_num", type=int, default=2000)
296311
parser.add_argument("--input_qps", type=float, default=30.0)
297312
parser.add_argument("--input_len", type=int, default=1024)
@@ -323,17 +338,21 @@ def main():
323338

324339
assert args.tokenizer_path is not None
325340
model_name.append(args.tokenizer_path)
326-
seed_all(args.seed)
341+
# seed_all(args.seed)
327342
url = args.url
328343
tokenizer = get_tokenizer(args.tokenizer_path)
329-
# qps发送模式发送请求的数量不固定,这里暂定为input_num的10倍
330-
prompts, max_new_tokens = gen_random_data(
331-
args.input_len,
332-
args.output_len,
333-
args.input_num if not args.continuous_send else 10 * args.input_num,
334-
tokenizer,
335-
args.range_ratio,
336-
)
344+
if args.data_path is not None:
345+
prompts, max_new_tokens = get_custom_input_data(args.data_path, args.output_len, tokenizer, args.range_ratio)
346+
args.input_num = len(prompts)
347+
else:
348+
# qps发送模式发送请求的数量不固定,这里暂定为input_num的10倍
349+
prompts, max_new_tokens = gen_random_data(
350+
args.input_len,
351+
args.output_len,
352+
args.input_num if not args.continuous_send else 10 * args.input_num,
353+
tokenizer,
354+
args.range_ratio,
355+
)
337356

338357
percentiles = [25, 50, 75, 90, 95, 99, 100]
339358
if args.server_api == "lightllm":
@@ -364,7 +383,7 @@ def main():
364383
)
365384
)
366385
loop.close()
367-
386+
print(len(results))
368387
first_token_time = []
369388
decode_token_time = []
370389
request_time = []
@@ -379,6 +398,13 @@ def main():
379398
final_output_lens.append(len(result))
380399
input_lens.append(input_len)
381400
valid_num += 1
401+
else:
402+
first_token_time.append(result[0])
403+
decode_token_time.append(0) # no decode
404+
request_time.append(sum(result))
405+
final_output_lens.append(len(result))
406+
input_lens.append(input_len)
407+
valid_num += 1
382408

383409
print(
384410
f"\n\nvalid num = {valid_num}; all data num = {len(results)}; valid ratio = {valid_num * 1.0 / len(results)}\n"

test/model/model_infer.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def test_model_inference(args, model_class):
2020
workers = []
2121
dp_size = args.get("dp", 1)
2222

23-
for rank_id in range(args.tp):
23+
for rank_id in range(args.node_rank * 8, (args.node_rank + 1) * 8):
2424
model_kvargs = {
2525
"args": args,
2626
"nccl_host": args.nccl_host,
@@ -88,6 +88,7 @@ def overlap_prefill(
8888
_0_b_start_loc,
8989
_0_b_seq_len,
9090
_o_b_ready_cache_len,
91+
{},
9192
)
9293

9394
_1_batch_size = batch_size - batch_size // 2
@@ -110,6 +111,7 @@ def overlap_prefill(
110111
_1_b_start_loc,
111112
_1_b_seq_len,
112113
_1_b_ready_cache_len,
114+
{},
113115
)
114116

115117
logits, logits1 = model_part.microbatch_overlap_prefill(micro_batch1, micro_batch2)
@@ -213,7 +215,6 @@ def tppart_model_infer(args, model_class, model_kvargs, batch_size, input_len, o
213215

214216
if model_class == Deepseek2TpPartModel:
215217
model_cfg, _ = PretrainedConfig.get_config_dict(model_kvargs["weight_dir"])
216-
dist_group_manager.new_deepep_group(model_cfg["n_routed_experts"])
217218
dist.barrier()
218219

219220
torch.cuda.empty_cache()

0 commit comments

Comments (0)