Skip to content

Commit 2f96b1a

Browse files
committed
[fix] fix rpyc in multimodal process
1 parent 780a57f commit 2f96b1a

File tree

8 files changed

+52
-78
lines changed

8 files changed

+52
-78
lines changed

lightllm/models/whisper/whisper_audio.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -190,8 +190,14 @@ def encode(self, audio_items: List[AudioItem]):
190190
audio_lens_after_cnn = np.array(audio_lens_after_cnn, dtype=np.int32)
191191
audio_token_num = (audio_lens_after_cnn - 2) // 2 + 1
192192

193-
for i in range(len(uuids)):
194-
if not self.cache_client.root.get_item_embed(uuids[i]):
195-
cur_embed_bytes = tensor2bytes(audios[i][: audio_token_num[i]])
196-
create_shm(get_shm_name_embed(uuids[i]), cur_embed_bytes)
197-
self.cache_client.root.set_item_embed(uuids[i])
193+
ready_audio = self.cache_client.root.get_items_data(uuids)
194+
ids_to_set = []
195+
for i, ready in enumerate(ready_audio):
196+
if ready:
197+
continue
198+
uid = uuids[i]
199+
cur_embed_bytes = tensor2bytes(audios[i][: audio_token_num[i]])
200+
create_shm(get_shm_name_data(uid), cur_embed_bytes)
201+
ids_to_set.append(uid)
202+
if ids_to_set:
203+
self.cache_client.root.set_items_data(ids=ids_to_set)

lightllm/server/audioserver/manager.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,11 @@ async def loop_for_fwd(self):
9494

9595
multimodal_params = group_req_indexes.multimodal_params
9696

97-
for audio in multimodal_params.audios:
98-
if not self.cache_client.root.get_item_embed(audio.uuid):
97+
audio_uuids = [audio.uuid for audio in multimodal_params.audios]
98+
ready_audio = self.cache_client.root.get_items_embed(audio_uuids)
99+
100+
for audio, ready in zip(multimodal_params.audios, ready_audio):
101+
if not ready:
99102
audios_need_infer.append(audio)
100103

101104
if len(audios_need_infer) == self.infer_batch_size:

lightllm/server/embed_cache/impl/naive_memory_cache.py

Lines changed: 11 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -89,52 +89,12 @@ def _clear(self):
8989
if deleted >= max_delete:
9090
break
9191

92-
# def alloc(self, md5sum: str, token_num: int) -> dict:
93-
# with self.lock:
94-
# t = time.time()
95-
# # add new record
96-
# if md5sum not in self._md5_to_record:
97-
98-
# # full, need to clear some unused items
99-
# if self.occupied >= self.capacity:
100-
# self._clear()
101-
# if self.occupied >= self.capacity:
102-
# return None
103-
104-
# id = uuid.uuid1()
105-
# id = id.int
106-
# self._check_and_set_new_id_range(token_num)
107-
# record = Record(
108-
# id=id,
109-
# md5sum=md5sum,
110-
# ref=1,
111-
# data=False,
112-
# embed=False,
113-
# createtime=t,
114-
# visittime=t,
115-
# token_id=self.token_id_range_start,
116-
# token_num=token_num,
117-
# )
118-
# self.token_id_range_start += token_num
119-
# self._records[id] = record
120-
# self._md5_to_record[md5sum] = record
121-
# self.occupied += 1
122-
123-
# # cache hit
124-
# else:
125-
# record = self._md5_to_record[md5sum]
126-
# record.visittime = t
127-
# record.ref += 1
128-
129-
# return {"id": record.id, "token_id": record.token_id, "token_num": record.token_num}
130-
13192
def alloc_batch(self, md5_list: list[str], token_num_list: list[int]) -> list[dict]:
13293
results = []
13394
with self.lock:
13495
for md5, tnum in zip(md5_list, token_num_list):
13596
t = time.time()
13697
if md5 not in self._md5_to_record:
137-
# 若不存在则分配新记录(与alloc逻辑相同)
13898
if self.occupied >= self.capacity:
13999
self._clear()
140100
if self.occupied >= self.capacity:
@@ -158,7 +118,6 @@ def alloc_batch(self, md5_list: list[str], token_num_list: list[int]) -> list[di
158118
self._md5_to_record[md5] = record
159119
self.occupied += 1
160120
else:
161-
# 缓存命中,更新引用计数和访问时间
162121
record = self._md5_to_record[md5]
163122
record.visittime = t
164123
record.ref += 1
@@ -175,17 +134,20 @@ def release(self, id: int) -> None:
175134
# def get_item_data(self, id: int) -> bool:
176135
# return self._records[id].data
177136

137+
def set_items_data(self, ids: list[int]) -> None:
138+
with self.lock:
139+
for id in ids:
140+
self._records[id].data = True
141+
178142
def get_items_data(self, ids: list[int]) -> list[bool]:
179143
with self.lock:
180144
return [self._records.get(i).data if i in self._records else False for i in ids]
181145

182-
def set_items_data(self, ids: list[int]) -> None:
146+
def set_items_embed(self, ids: list[int]) -> None:
183147
with self.lock:
184-
for i in ids:
185-
self._records[i].data = True
148+
for id in ids:
149+
self._records[id].embed = True
186150

187-
def set_item_embed(self, id: int) -> None:
188-
self._records[id].embed = True
189-
190-
def get_item_embed(self, id: int) -> bool:
191-
return self._records[id].embed
151+
def get_items_embed(self, ids: list[int]) -> list[bool]:
152+
with self.lock:
153+
return [self._records.get(i).embed if i in self._records else False for i in ids]

lightllm/server/embed_cache/interface.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@ def set_items_data(self, ids: list[int]) -> None:
1919
def get_items_data(self, ids: list[int]) -> list[bool]:
2020
pass
2121

22-
def set_item_embed(self, id: int) -> None:
22+
def set_items_embed(self, ids: list[int]) -> None:
2323
pass
2424

25-
def get_item_embed(self, id: int) -> bool:
25+
def get_items_embed(self, ids: list[int]) -> list[bool]:
2626
pass
2727

2828

lightllm/server/embed_cache/manager.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def on_disconnect(self, conn):
2525
def exposed_alloc_batch(self, md5sum_list: list[str], token_num_list: list[int]) -> dict:
2626
md5sum_list = obtain(md5sum_list)
2727
token_num_list = obtain(token_num_list)
28-
record = self._impl.alloc(md5sum_list, token_num_list)
28+
record = self._impl.alloc_batch(md5sum_list, token_num_list)
2929
return record
3030

3131
def exposed_release(self, id: int) -> None:
@@ -40,13 +40,13 @@ def exposed_get_items_data(self, ids: list[int]) -> list[bool]:
4040
ids = obtain(ids)
4141
return self._impl.get_items_data(ids=ids)
4242

43-
def exposed_set_item_embed(self, id: int) -> None:
44-
id = obtain(id)
45-
return self._impl.set_item_embed(id=id)
43+
def exposed_set_items_embed(self, ids: list[int]) -> None:
44+
ids = obtain(ids)
45+
return self._impl.set_items_embed(ids=ids)
4646

47-
def exposed_get_item_embed(self, id: int) -> bool:
48-
id = obtain(id)
49-
return self._impl.get_item_embed(id=id)
47+
def exposed_get_items_embed(self, ids: list[int]) -> list[bool]:
48+
ids = obtain(ids)
49+
return self._impl.get_items_embed(ids=ids)
5050

5151

5252
def start_cache_manager(port: int, args, pipe_writer):

lightllm/server/httpserver/manager.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def __init__(
8282

8383
self.enable_multimodal = enable_multimodal
8484
if self.enable_multimodal:
85-
self.cache_client = rpyc.connect("localhost", cache_port, onfig={"allow_pickle": True})
85+
self.cache_client = rpyc.connect("localhost", cache_port, config={"allow_pickle": True})
8686
self.send_to_visual = context.socket(zmq.PUSH)
8787
self.send_to_visual.connect(f"{args.zmq_mode}127.0.0.1:{visual_port}")
8888

@@ -194,7 +194,6 @@ async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams,
194194
item.token_num = record["token_num"]
195195
if not ready:
196196
create_shm(get_shm_name_data(item.uuid), data)
197-
self.cache_client.root.set_items_data(item.uuid)
198197
uids_to_write.append(item.uuid)
199198
if uids_to_write:
200199
self.cache_client.root.set_items_data(uids_to_write)

lightllm/server/visualserver/manager.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def __init__(
3535

3636
self.recv_from_httpserver = context.socket(zmq.PULL)
3737
self.recv_from_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{visual_port}")
38-
self.cache_client = rpyc.connect("localhost", cache_port)
38+
self.cache_client = rpyc.connect("localhost", cache_port, config={"allow_pickle": True})
3939
self.cache_port = cache_port
4040
self.waiting_reqs: List[GroupReqIndexes] = []
4141
self.model_weightdir = args.model_dir
@@ -121,11 +121,9 @@ async def loop_for_fwd(self):
121121
multimodal_params = group_req_indexes.multimodal_params
122122

123123
img_uuids = [img.uuid for img in multimodal_params.images]
124-
ready_flags = []
125-
for uuid in img_uuids:
126-
ready_flags.append(self.cache_client.root.get_items_embed(uuid))
124+
ready_image = self.cache_client.root.get_items_embed(img_uuids)
127125

128-
for img, ready in zip(multimodal_params.images, ready_flags):
126+
for img, ready in zip(multimodal_params.images, ready_image):
129127
if not ready:
130128
images_need_infer.append(img)
131129

lightllm/server/visualserver/model_infer/model_rpc.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,14 +94,20 @@ def exposed_encode(self, images: List[ImageItem]):
9494
images = obtain(images)
9595
all_img_embeds, uuids, valid_ids = self.forward(images)
9696
all_img_embeds = all_img_embeds.to(torch.device("cpu"))
97+
9798
if self.tp_rank_id == 0:
98-
for i in range(len(uuids)):
99+
ready_flags = self.cache_client.root.get_items_embed(uuids)
100+
ids_to_set = []
101+
for i, ready in enumerate(ready_flags):
102+
if ready:
103+
continue
99104
uid = uuids[i]
100-
if not self.cache_client.root.get_item_embed(uid):
101-
start, end = valid_ids[i]
102-
cur_embed_bytes = tensor2bytes(all_img_embeds[start:end])
103-
create_shm(get_shm_name_embed(uuids[i]), cur_embed_bytes)
104-
self.cache_client.root.set_item_embed(uuids[i])
105+
start, end = valid_ids[i]
106+
cur_embed_bytes = tensor2bytes(all_img_embeds[start:end])
107+
create_shm(get_shm_name_embed(uid), cur_embed_bytes)
108+
ids_to_set.append(uid)
109+
if ids_to_set:
110+
self.cache_client.root.set_items_embed(ids_to_set)
105111
return
106112

107113

0 commit comments

Comments (0)