
Commit 5e7f2d9

Fix repeated update_refs; rename disable -> enable
1 parent 1741884 commit 5e7f2d9

4 files changed (+8, -11 lines)

lightllm/server/api_cli.py

Lines changed: 2 additions & 2 deletions
@@ -538,9 +538,9 @@ def make_argument_parser() -> argparse.ArgumentParser:
         "--disk_cache_storage_size", type=float, default=10, help="""The capacity of disk cache. GB used."""
     )
     parser.add_argument(
-        "--disable_dp_prompt_cache_fetch",
+        "--enable_dp_prompt_cache_fetch",
         action="store_true",
         default=False,
-        help="""Disable prefix prompt cache fetch for data parallel inference, enabled by default.""",
+        help="""Enable prefix prompt cache fetch for data parallel inference, disabled by default.""",
     )
     return parser
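
With this change, DP prompt cache fetch becomes opt-in rather than opt-out. A minimal sketch of how the renamed flag behaves through the parser (assuming lightllm is importable and no option in make_argument_parser is required, so all other options keep their defaults):

    # Minimal sketch: the renamed flag is a store_true option, off by default.
    from lightllm.server.api_cli import make_argument_parser

    parser = make_argument_parser()
    args = parser.parse_args([])  # nothing passed: feature stays disabled
    assert args.enable_dp_prompt_cache_fetch is False

    args = parser.parse_args(["--enable_dp_prompt_cache_fetch"])  # explicit opt-in
    assert args.enable_dp_prompt_cache_fetch is True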

lightllm/server/api_start.py

Lines changed: 3 additions & 3 deletions
@@ -270,11 +270,11 @@ def normal_or_p_d_start(args):

     send_and_receive_node_ip(args)  # multi-node: send and receive node IPs
     # DP prompt cache fetch must be disabled in PD-disaggregated mode, and requires dp > 1
-    if not args.disable_dp_prompt_cache_fetch and args.dp <= 1:
-        args.disable_dp_prompt_cache_fetch = True
+    if args.enable_dp_prompt_cache_fetch and args.dp <= 1:
+        args.enable_dp_prompt_cache_fetch = False
         logger.warning(
             """dp <= 1 does not support dp_prompt_cache_fetch;
-            overriding disable_dp_prompt_cache_fetch to True"""
+            overriding enable_dp_prompt_cache_fetch to False"""
         )

     set_env_start_args(args)
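
The start-up guard now flips the opt-in flag back off instead of forcing the old disable flag on. A standalone sketch of the same check, using a bare argparse.Namespace in place of the full parsed start args:

    import argparse

    # Sketch of the dp <= 1 override, detached from the server start-up path.
    args = argparse.Namespace(enable_dp_prompt_cache_fetch=True, dp=1)
    if args.enable_dp_prompt_cache_fetch and args.dp <= 1:
        # dp <= 1 cannot use dp_prompt_cache_fetch, so the request is overridden
        args.enable_dp_prompt_cache_fetch = False
    print(args.enable_dp_prompt_cache_fetch)  # -> False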

lightllm/server/core/objs/start_args_type.py

Lines changed: 1 addition & 1 deletion
@@ -111,7 +111,7 @@ class StartArgs:
     cpu_cache_token_page_size: int = field(default=64)
     enable_disk_cache: bool = field(default=False)
     disk_cache_storage_size: float = field(default=10)
-    disable_dp_prompt_cache_fetch: bool = field(default=False)
+    enable_dp_prompt_cache_fetch: bool = field(default=False)
     # zmp ports
     router_port: int = field(default=None)
     detokenization_port: int = field(default=None)

lightllm/server/router/model_infer/mode_backend/dp_backend/impl.py

Lines changed: 2 additions & 5 deletions
@@ -38,7 +38,7 @@ def __init__(self, mem_queue: mp.Queue, mem_queues: List[mp.Queue] = None) -> None:

         # controls whether each step runs prefill and decode or skips them
         self.control_state_machine = DPControlState(backend=self)
-        self.disable_dp_prompt_cache_fetch = get_env_start_args().disable_dp_prompt_cache_fetch
+        self.enable_dp_prompt_cache_fetch = get_env_start_args().enable_dp_prompt_cache_fetch
         self.min_trans_token_num = min_trans_token_num

         # in mtp mode, switch the bound prefill and decode functions
@@ -72,9 +72,6 @@ def __init__(self, mem_queue: mp.Queue, mem_queues: List[mp.Queue] = None) -> None:
         return

     def init_custom(self):
-        self.enable_dp_prompt_cache_fetch = (
-            not self.disable_dp_prompt_cache_fetch and self.dp_size_in_node > 1 and self.mem_queues is not None
-        )
         if self.enable_dp_prompt_cache_fetch:
             torch.cuda.set_device(get_current_device_id())

@@ -116,7 +113,7 @@ def _match_radix_cache(self, shm_req):
         input_token_ids = shm_req.shm_prompt_ids.arr[0 : shm_req.input_len]
         key = torch.tensor(input_token_ids, dtype=torch.int64, device="cpu")
         key = key[0 : len(key) - 1]  # drop the last token: an extra token is needed so prefill can emit the next token's output
-        _, kv_len, value_tensor = g_infer_context.radix_cache.match_prefix(key, update_refs=True)
+        _, kv_len, value_tensor = g_infer_context.radix_cache.match_prefix(key, update_refs=False)
         return kv_len, value_tensor

     def _post_init_reqs(self, infer_reqs: List[InferReq], other_reqs: List[Tuple] = []):
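
Switching _match_radix_cache to update_refs=False (together with dropping the recomputation of enable_dp_prompt_cache_fetch in init_custom) lines up with the commit message about repeated ref updates: matching the same prompt several times with update_refs=True keeps pinning the cached prefix. A toy sketch of that effect, with hypothetical names and a heavily simplified cache rather than the real radix cache:

    # Toy sketch (hypothetical ToyRadixCache, not lightllm's implementation) of
    # how repeated match_prefix calls with update_refs=True accumulate refs.
    class ToyRadixCache:
        def __init__(self):
            self.ref_count = 0

        def match_prefix(self, key, update_refs):
            if update_refs:
                self.ref_count += 1  # the real cache pins the matched KV blocks
            return None, max(len(key) - 1, 0), None  # (node, matched_len, value)

    cache = ToyRadixCache()
    for _ in range(3):  # the same request matched on several scheduling steps
        cache.match_prefix([1, 2, 3], update_refs=True)
    print(cache.ref_count)  # 3 pinned refs that are never released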
