Skip to content

Commit 2f7f6e4

Browse files
committed
fix
1 parent 5dcea99 commit 2f7f6e4

File tree

2 files changed

+15
-11
lines changed

2 files changed

+15
-11
lines changed

lightllm/server/api_start.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -91,21 +91,25 @@ def normal_or_p_d_start(args):
91 91
if args.graph_max_len_in_batch == 0:
92 92
args.graph_max_len_in_batch = args.max_req_total_len
93 93

94-
# 这些模式不能同时设置。
95-
assert [
96-
args.disable_chunked_prefill,
97-
args.diverse_mode,
98-
args.use_reward_model,
99-
args.return_all_prompt_logprobs,
100-
].count(True) <= 1
101-
102-
# chuncked prefill 需要和 dynamic_prompt_cache 一起使能
94+
# mode setting check.
103 95
if not args.disable_chunked_prefill:
104 96
assert args.disable_dynamic_prompt_cache is False
97+
assert args.disable_chunked_prefill is False
105 98
if args.output_constraint_mode != "none":
106 99
assert args.disable_dynamic_prompt_cache is False
100+
assert args.disable_chunked_prefill is False
107 101
if args.token_healing_mode:
108 102
assert args.disable_dynamic_prompt_cache is False
103+
assert args.disable_chunked_prefill is False
104+
if args.diverse_mode:
105+
assert args.disable_dynamic_prompt_cache is False
106+
assert args.disable_chunked_prefill is False
107+
if args.use_reward_model:
108+
assert args.disable_dynamic_prompt_cache is True, f"need add --disable_dynamic_prompt_cache"
109+
assert args.disable_chunked_prefill is True, f"need add --disable_chunked_prefill"
110+
if args.return_all_prompt_logprobs:
111+
assert args.disable_dynamic_prompt_cache is True, f"need add --disable_dynamic_prompt_cache"
112+
assert args.disable_chunked_prefill is True, f"need add --disable_chunked_prefill"
109 113

110 114
# 部分模式还不能支持与高级动态调度算法协同,to do.
111 115
if args.diverse_mode:

lightllm/server/embed_cache/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
4 4
import multiprocessing.shared_memory as shm
5 5

6 6

7-
def tensor2bytes(t):
7+
def tensor2bytes(t:torch.Tensor):
8 8
# t = t.cpu().numpy().tobytes()
9 9
# return t
10 10
buf = BytesIO()
11-
torch.save(t, buf)
11+
torch.save(t.detach().cpu(), buf)
12 12
buf.seek(0)
13 13
return buf.read()
14 14

0 commit comments

Comments
 (0)