Skip to content

Commit 4a4515d

Browse files
committed
Config server: add an interface for allocating multimodal token ID ranges.
1 parent 3bd5aab commit 4a4515d

File tree

2 files changed

+59
-14
lines changed

2 files changed

+59
-14
lines changed

lightllm/server/config_server/api_http.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@
1818
global_req_id = 0
1919
global_req_id_lock = Lock()
2020

21+
# This is a global ID for multimodal embedding, starting from 100000000
22+
global_multimodal_embedding_id = 100000000
23+
global_multimodal_embedding_id_lock = Lock()
24+
2125

2226
@app.get("/liveness")
2327
@app.post("/liveness")
@@ -94,3 +98,17 @@ async def allocate_global_id_range():
9498
end_id = global_req_id
9599

96100
return {"start_id": start_id, "end_id": end_id}
101+
102+
103+
@app.get("/allocate_global_unique_multimodal_id_range")
async def allocate_global_unique_multimodal_id_range():
    """Hand out a fresh process-wide range of multimodal embedding ids.

    Returns a JSON dict ``{"start_id": ..., "end_id": ...}``; callers
    treat the range as half-open, presumably [start_id, end_id).
    """
    global global_multimodal_embedding_id
    block_size = 8000000
    with global_multimodal_embedding_id_lock:
        # Wrap around rather than exceed the signed-64-bit ceiling.
        # NOTE(review): after a wrap, previously issued ranges may be
        # re-issued — confirm this is acceptable for the deployment.
        if global_multimodal_embedding_id + block_size > 2 ** 63 - 1:
            global_multimodal_embedding_id = 100000000
        start_id = global_multimodal_embedding_id
        end_id = start_id + block_size
        global_multimodal_embedding_id = end_id
    return {"start_id": start_id, "end_id": end_id}

lightllm/server/embed_cache/impl/naive_memory_cache.py

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
import uuid
22
import threading
33
import dataclasses
4+
import requests
45
from ..interface import CacheManager, CacheManagerFactory
56
from typing import Union
67
import torch
78
import time
89
from collections import deque
910
import multiprocessing.shared_memory as shm
1011
from ..utils import get_shm_name_data, get_shm_name_embed, free_shm
12+
from lightllm.utils.log_utils import init_logger
13+
14+
logger = init_logger(__name__)
1115

1216

1317
@dataclasses.dataclass
@@ -22,10 +26,11 @@ class Record(object):
2226
token_id: int
2327
token_num: int
2428

29+
2530
@CacheManagerFactory.register("naive")
2631
class InMemoryCache(CacheManager):
27-
2832
def __init__(self, args) -> None:
33+
self.args = args
2934
self._records = dict()
3035
self._md5_to_record = dict()
3136
self.capacity = max(1, args.cache_capacity)
@@ -34,12 +39,37 @@ def __init__(self, args) -> None:
3439
self.occupied = 0
3540
self.expired_secs = 60 * 60
3641
self.lock = threading.Lock()
37-
38-
from lightllm.server.tokenizer import get_tokenizer
39-
tokenizer = get_tokenizer(
40-
args.model_dir, args.tokenizer_mode, trust_remote_code=args.trust_remote_code
41-
)
42-
self.cur_token_id = tokenizer.vocab_size + 10000
42+
self.token_id_range_start = 0
43+
self.token_id_range_end = 0
44+
self.use_config_server = self.args.config_server_host and self.args.config_server_port
45+
46+
def _check_and_set_new_id_range(self, alloced_token_num):
47+
need_update_range = self.token_id_range_start + alloced_token_num >= self.token_id_range_end
48+
if need_update_range:
49+
if not self.use_config_server:
50+
self.token_id_range_start = 100000000
51+
self.token_id_range_end = 2 ** 63 - 1
52+
else:
53+
while True:
54+
try:
55+
config_server_ip_port = f"{self.args.config_server_host}:{self.args.config_server_port}"
56+
url = f"http://{config_server_ip_port}/allocate_global_unique_multimodal_id_range"
57+
response = requests.get(url)
58+
if response.status_code == 200:
59+
id_range = response.json()
60+
logger.info(f"get new multimodal id range {id_range}")
61+
self.token_id_range_start = id_range["start_id"]
62+
self.token_id_range_end = id_range["end_id"]
63+
assert (
64+
self.token_id_range_start + alloced_token_num < self.token_id_range_end
65+
), f"get multimodal id range error {self.token_id_range_start} {self.token_id_range_end}"
66+
return
67+
else:
68+
raise RuntimeError(f"Failed to fetch ID range from config server: {response.status_code}")
69+
except BaseException as e:
70+
logger.exception(str(e))
71+
time.sleep(3)
72+
return
4373

4474
def _clear(self):
4575
deleted = 0
@@ -73,6 +103,7 @@ def alloc(self, md5sum: str, token_num: int) -> dict:
73103

74104
id = uuid.uuid1()
75105
id = id.int
106+
self._check_and_set_new_id_range(token_num)
76107
record = Record(
77108
id=id,
78109
md5sum=md5sum,
@@ -81,10 +112,10 @@ def alloc(self, md5sum: str, token_num: int) -> dict:
81112
embed=False,
82113
createtime=t,
83114
visittime=t,
84-
token_id=self.cur_token_id,
115+
token_id=self.token_id_range_start,
85116
token_num=token_num,
86117
)
87-
self.cur_token_id += token_num
118+
self.token_id_range_start += token_num
88119
self._records[id] = record
89120
self._md5_to_record[md5sum] = record
90121
self.occupied += 1
@@ -95,11 +126,7 @@ def alloc(self, md5sum: str, token_num: int) -> dict:
95126
record.visittime = t
96127
record.ref += 1
97128

98-
return {
99-
"id": record.id,
100-
"token_id": record.token_id,
101-
"token_num": record.token_num
102-
}
129+
return {"id": record.id, "token_id": record.token_id, "token_num": record.token_num}
103130

104131
def release(self, id: int) -> None:
105132
with self.lock:

0 commit comments

Comments
 (0)