11import uuid
22import threading
33import dataclasses
4+ import requests
45from ..interface import CacheManager , CacheManagerFactory
56from typing import Union
67import torch
78import time
89from collections import deque
910import multiprocessing .shared_memory as shm
1011from ..utils import get_shm_name_data , get_shm_name_embed , free_shm
12+ from lightllm .utils .log_utils import init_logger
13+
14+ logger = init_logger (__name__ )
1115
1216
1317@dataclasses .dataclass
@@ -22,10 +26,11 @@ class Record(object):
2226 token_id : int
2327 token_num : int
2428
29+
2530@CacheManagerFactory .register ("naive" )
2631class InMemoryCache (CacheManager ):
27-
2832 def __init__ (self , args ) -> None :
33+ self .args = args
2934 self ._records = dict ()
3035 self ._md5_to_record = dict ()
3136 self .capacity = max (1 , args .cache_capacity )
@@ -34,12 +39,37 @@ def __init__(self, args) -> None:
3439 self .occupied = 0
3540 self .expired_secs = 60 * 60
3641 self .lock = threading .Lock ()
37-
38- from lightllm .server .tokenizer import get_tokenizer
39- tokenizer = get_tokenizer (
40- args .model_dir , args .tokenizer_mode , trust_remote_code = args .trust_remote_code
41- )
42- self .cur_token_id = tokenizer .vocab_size + 10000
42+ self .token_id_range_start = 0
43+ self .token_id_range_end = 0
44+ self .use_config_server = self .args .config_server_host and self .args .config_server_port
45+
46+ def _check_and_set_new_id_range (self , alloced_token_num ):
47+ need_update_range = self .token_id_range_start + alloced_token_num >= self .token_id_range_end
48+ if need_update_range :
49+ if not self .use_config_server :
50+ self .token_id_range_start = 100000000
51+ self .token_id_range_end = 2 ** 63 - 1
52+ else :
53+ while True :
54+ try :
55+ config_server_ip_port = f"{ self .args .config_server_host } :{ self .args .config_server_port } "
56+ url = f"http://{ config_server_ip_port } /allocate_global_unique_multimodal_id_range"
57+ response = requests .get (url )
58+ if response .status_code == 200 :
59+ id_range = response .json ()
60+ logger .info (f"get new multimodal id range { id_range } " )
61+ self .token_id_range_start = id_range ["start_id" ]
62+ self .token_id_range_end = id_range ["end_id" ]
63+ assert (
64+ self .token_id_range_start + alloced_token_num < self .token_id_range_end
65+ ), f"get multimodal id range error { self .token_id_range_start } { self .token_id_range_end } "
66+ return
67+ else :
68+ raise RuntimeError (f"Failed to fetch ID range from config server: { response .status_code } " )
69+ except BaseException as e :
70+ logger .exception (str (e ))
71+ time .sleep (3 )
72+ return
4373
4474 def _clear (self ):
4575 deleted = 0
@@ -73,6 +103,7 @@ def alloc(self, md5sum: str, token_num: int) -> dict:
73103
74104 id = uuid .uuid1 ()
75105 id = id .int
106+ self ._check_and_set_new_id_range (token_num )
76107 record = Record (
77108 id = id ,
78109 md5sum = md5sum ,
@@ -81,10 +112,10 @@ def alloc(self, md5sum: str, token_num: int) -> dict:
81112 embed = False ,
82113 createtime = t ,
83114 visittime = t ,
84- token_id = self .cur_token_id ,
115+ token_id = self .token_id_range_start ,
85116 token_num = token_num ,
86117 )
87- self .cur_token_id += token_num
118+ self .token_id_range_start += token_num
88119 self ._records [id ] = record
89120 self ._md5_to_record [md5sum ] = record
90121 self .occupied += 1
@@ -95,11 +126,7 @@ def alloc(self, md5sum: str, token_num: int) -> dict:
95126 record .visittime = t
96127 record .ref += 1
97128
98- return {
99- "id" : record .id ,
100- "token_id" : record .token_id ,
101- "token_num" : record .token_num
102- }
129+ return {"id" : record .id , "token_id" : record .token_id , "token_num" : record .token_num }
103130
104131 def release (self , id : int ) -> None :
105132 with self .lock :
0 commit comments