Commit 52f5f6d: "update test" (1 parent: e032bf7)

1 file changed: +8, -6 lines


test/model/model_infer_mtp.py (8 additions, 6 deletions)
@@ -75,7 +75,7 @@ def test_model_inference_mtp(args):
         "graph_max_batch_size": args.graph_max_batch_size,
         "mem_faction": args.mem_fraction,
         "max_req_num": 2000,
-        "batch_max_tokens": 16384,
+        "batch_max_tokens": 2048,
         "run_mode": "normal",
         "max_seq_length": args.max_req_total_len,
         "spec_algo": args.spec_algo,
@@ -110,7 +110,7 @@ def torch_profile(fn, log_dir=None):
     print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
 
 
-def run_forward_once(input_len, output_len, batch_size, main_model, draft_models, warmup=False):
+def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_models, warmup=False):
     import time
 
     torch.cuda.synchronize()
@@ -166,7 +166,9 @@ def run_forward_once(input_len, output_len, batch_size, main_model, draft_models
     prefill_end_time = time.time()
     if get_current_rank_in_dp() == 0 and not warmup:
         print("prefill time cost:", (prefill_end_time - prefill_start_time) * 1000)
-        print(f"Prefill throughput: {batch_size * input_len / (prefill_end_time - prefill_start_time)} tokens/s")
+        print(
+            f"Prefill throughput: {batch_size * input_len * args.dp / (prefill_end_time - prefill_start_time)} tokens/s"
+        )
 
     torch.cuda.synchronize()
 
@@ -240,7 +242,7 @@ def run_forward_once(input_len, output_len, batch_size, main_model, draft_models
         if get_current_rank_in_dp() == 0 and not warmup:
             step_time = step_end_time - step_start_time
             print(i, " step cost time:", step_time * 1000)
-            print(f"Decode throughput: {batch_size * (len(draft_models) + 1) / step_time} tokens/s")
+            print(f"Decode throughput: {batch_size * (len(draft_models) + 1) * args.dp / step_time} tokens/s")
 
     main_model.mem_manager.free_all()
     main_model.req_manager.free_all()
@@ -273,9 +275,9 @@ def tppart_model_infer(args, model_kvargs, batch_sizes, input_len, output_len, a
 
     for batch_size in batch_sizes:
         # warm up
-        run_forward_once(input_len, output_len, batch_size, main_model, draft_models, warmup=True)
+        run_forward_once(args, input_len, output_len, batch_size, main_model, draft_models, warmup=True)
         torch.cuda.synchronize()
-        run_forward_once(input_len, output_len, batch_size, main_model, draft_models, warmup=False)
+        run_forward_once(args, input_len, output_len, batch_size, main_model, draft_models, warmup=False)
         dist.barrier()
 
     ans_queue.put(True)
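
The two throughput prints are the substance of the change: the per-rank token counts are now multiplied by args.dp, so the reported numbers cover the whole data-parallel group rather than a single rank. A minimal sketch of that accounting, assuming args.dp is the data-parallel replica count and that every replica runs an identical batch (the helper names below are illustrative, not part of the repository):

# Sketch of the corrected throughput accounting (illustrative helpers,
# not from the repository). Each DP replica processes batch_size
# requests, so cluster-wide tokens are per-rank tokens times dp.

def prefill_throughput(batch_size, input_len, dp, elapsed_s):
    # Tokens prefilled per second across all data-parallel replicas.
    return batch_size * input_len * dp / elapsed_s

def decode_throughput(batch_size, num_draft_models, dp, step_s):
    # With MTP speculative decoding, one decode step yields one token
    # from the main model plus one per draft model, per request.
    return batch_size * (num_draft_models + 1) * dp / step_s

# Example: batch_size=8, input_len=1024, dp=2, 0.5 s prefill
# -> 8 * 1024 * 2 / 0.5 = 32768 tokens/s.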
