Skip to content

Commit 803e095

Browse files
committed
[fix] fix rpyc usage in the multimodal process
1 parent 780a57f commit 803e095

File tree

8 files changed

+51
-89
lines changed

8 files changed

+51
-89
lines changed

lightllm/models/whisper/whisper_audio.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -190,8 +190,14 @@ def encode(self, audio_items: List[AudioItem]):
190190
audio_lens_after_cnn = np.array(audio_lens_after_cnn, dtype=np.int32)
191191
audio_token_num = (audio_lens_after_cnn - 2) // 2 + 1
192192

193-
for i in range(len(uuids)):
194-
if not self.cache_client.root.get_item_embed(uuids[i]):
195-
cur_embed_bytes = tensor2bytes(audios[i][: audio_token_num[i]])
196-
create_shm(get_shm_name_embed(uuids[i]), cur_embed_bytes)
197-
self.cache_client.root.set_item_embed(uuids[i])
193+
ready_audio = self.cache_client.root.get_items_data(uuids)
194+
ids_to_set = []
195+
for i, ready in enumerate(ready_audio):
196+
if ready:
197+
continue
198+
uid = uuids[i]
199+
cur_embed_bytes = tensor2bytes(audios[i][: audio_token_num[i]])
200+
create_shm(get_shm_name_data(uid), cur_embed_bytes)
201+
ids_to_set.append(uid)
202+
if ids_to_set:
203+
self.cache_client.root.set_items_data(ids=ids_to_set)

lightllm/server/audioserver/manager.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,11 @@ async def loop_for_fwd(self):
9494

9595
multimodal_params = group_req_indexes.multimodal_params
9696

97-
for audio in multimodal_params.audios:
98-
if not self.cache_client.root.get_item_embed(audio.uuid):
97+
audio_uuids = [audio.uuid for audio in multimodal_params.audios]
98+
ready_audio = self.cache_client.root.get_items_embed(audio_uuids)
99+
100+
for audio, ready in zip(multimodal_params.audios, ready_audio):
101+
if not ready:
99102
audios_need_infer.append(audio)
100103

101104
if len(audios_need_infer) == self.infer_batch_size:

lightllm/server/embed_cache/impl/naive_memory_cache.py

Lines changed: 10 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -89,52 +89,12 @@ def _clear(self):
8989
if deleted >= max_delete:
9090
break
9191

92-
# def alloc(self, md5sum: str, token_num: int) -> dict:
93-
# with self.lock:
94-
# t = time.time()
95-
# # add new record
96-
# if md5sum not in self._md5_to_record:
97-
98-
# # full, need to clear some unused items
99-
# if self.occupied >= self.capacity:
100-
# self._clear()
101-
# if self.occupied >= self.capacity:
102-
# return None
103-
104-
# id = uuid.uuid1()
105-
# id = id.int
106-
# self._check_and_set_new_id_range(token_num)
107-
# record = Record(
108-
# id=id,
109-
# md5sum=md5sum,
110-
# ref=1,
111-
# data=False,
112-
# embed=False,
113-
# createtime=t,
114-
# visittime=t,
115-
# token_id=self.token_id_range_start,
116-
# token_num=token_num,
117-
# )
118-
# self.token_id_range_start += token_num
119-
# self._records[id] = record
120-
# self._md5_to_record[md5sum] = record
121-
# self.occupied += 1
122-
123-
# # cache hit
124-
# else:
125-
# record = self._md5_to_record[md5sum]
126-
# record.visittime = t
127-
# record.ref += 1
128-
129-
# return {"id": record.id, "token_id": record.token_id, "token_num": record.token_num}
130-
13192
def alloc_batch(self, md5_list: list[str], token_num_list: list[int]) -> list[dict]:
13293
results = []
13394
with self.lock:
13495
for md5, tnum in zip(md5_list, token_num_list):
13596
t = time.time()
13697
if md5 not in self._md5_to_record:
137-
# 若不存在则分配新记录(与alloc逻辑相同)
13898
if self.occupied >= self.capacity:
13999
self._clear()
140100
if self.occupied >= self.capacity:
@@ -158,7 +118,6 @@ def alloc_batch(self, md5_list: list[str], token_num_list: list[int]) -> list[di
158118
self._md5_to_record[md5] = record
159119
self.occupied += 1
160120
else:
161-
# 缓存命中,更新引用计数和访问时间
162121
record = self._md5_to_record[md5]
163122
record.visittime = t
164123
record.ref += 1
@@ -169,23 +128,20 @@ def release(self, id: int) -> None:
169128
with self.lock:
170129
self._records[id].ref -= 1
171130

172-
# def set_item_data(self, id: int) -> None:
173-
# self._records[id].data = True
174-
175-
# def get_item_data(self, id: int) -> bool:
176-
# return self._records[id].data
131+
def set_items_data(self, ids: list[int]) -> None:
132+
with self.lock:
133+
for id in ids:
134+
self._records[id].data = True
177135

178136
def get_items_data(self, ids: list[int]) -> list[bool]:
179137
with self.lock:
180138
return [self._records.get(i).data if i in self._records else False for i in ids]
181139

182-
def set_items_data(self, ids: list[int]) -> None:
140+
def set_items_embed(self, ids: list[int]) -> None:
183141
with self.lock:
184-
for i in ids:
185-
self._records[i].data = True
142+
for id in ids:
143+
self._records[id].embed = True
186144

187-
def set_item_embed(self, id: int) -> None:
188-
self._records[id].embed = True
189-
190-
def get_item_embed(self, id: int) -> bool:
191-
return self._records[id].embed
145+
def get_items_embed(self, ids: list[int]) -> list[bool]:
146+
with self.lock:
147+
return [self._records.get(i).embed if i in self._records else False for i in ids]

lightllm/server/embed_cache/interface.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@ def set_items_data(self, ids: list[int]) -> None:
1919
def get_items_data(self, ids: list[int]) -> list[bool]:
2020
pass
2121

22-
def set_item_embed(self, id: int) -> None:
22+
def set_items_embed(self, ids: list[int]) -> None:
2323
pass
2424

25-
def get_item_embed(self, id: int) -> bool:
25+
def get_items_embed(self, ids: list[int]) -> list[bool]:
2626
pass
2727

2828

lightllm/server/embed_cache/manager.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def on_disconnect(self, conn):
2525
def exposed_alloc_batch(self, md5sum_list: list[str], token_num_list: list[int]) -> dict:
2626
md5sum_list = obtain(md5sum_list)
2727
token_num_list = obtain(token_num_list)
28-
record = self._impl.alloc(md5sum_list, token_num_list)
28+
record = self._impl.alloc_batch(md5sum_list, token_num_list)
2929
return record
3030

3131
def exposed_release(self, id: int) -> None:
@@ -40,13 +40,13 @@ def exposed_get_items_data(self, ids: list[int]) -> list[bool]:
4040
ids = obtain(ids)
4141
return self._impl.get_items_data(ids=ids)
4242

43-
def exposed_set_item_embed(self, id: int) -> None:
44-
id = obtain(id)
45-
return self._impl.set_item_embed(id=id)
43+
def exposed_set_items_embed(self, ids: list[int]) -> None:
44+
ids = obtain(ids)
45+
return self._impl.set_items_embed(ids=ids)
4646

47-
def exposed_get_item_embed(self, id: int) -> bool:
48-
id = obtain(id)
49-
return self._impl.get_item_embed(id=id)
47+
def exposed_get_items_embed(self, ids: list[int]) -> list[bool]:
48+
ids = obtain(ids)
49+
return self._impl.get_items_embed(ids=ids)
5050

5151

5252
def start_cache_manager(port: int, args, pipe_writer):

lightllm/server/httpserver/manager.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def __init__(
8282

8383
self.enable_multimodal = enable_multimodal
8484
if self.enable_multimodal:
85-
self.cache_client = rpyc.connect("localhost", cache_port, onfig={"allow_pickle": True})
85+
self.cache_client = rpyc.connect("localhost", cache_port, config={"allow_pickle": True})
8686
self.send_to_visual = context.socket(zmq.PUSH)
8787
self.send_to_visual.connect(f"{args.zmq_mode}127.0.0.1:{visual_port}")
8888

@@ -160,9 +160,6 @@ async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams,
160160
token_nums.append(num_tokens)
161161
datas.append(data)
162162
items.append(img)
163-
# img.uuid = record["id"]
164-
# img.token_id = record["token_id"]
165-
# img.token_num = record["token_num"]
166163
for audio in multimodal_params.audios:
167164
self.tokenizer.init_audioitem_extral_params(audio, multimodal_params, sampling_params)
168165
data = audio.read()
@@ -172,9 +169,6 @@ async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams,
172169
token_nums.append(num_tokens)
173170
datas.append(data)
174171
items.append(audio)
175-
# audio.uuid = record["id"]
176-
# audio.token_id = record["token_id"]
177-
# audio.token_num = record["token_num"]
178172
wait_time = 1
179173
while True:
180174
records = self.cache_client.root.alloc_batch(md5s, token_nums)
@@ -194,7 +188,6 @@ async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams,
194188
item.token_num = record["token_num"]
195189
if not ready:
196190
create_shm(get_shm_name_data(item.uuid), data)
197-
self.cache_client.root.set_items_data(item.uuid)
198191
uids_to_write.append(item.uuid)
199192
if uids_to_write:
200193
self.cache_client.root.set_items_data(uids_to_write)

lightllm/server/visualserver/manager.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def __init__(
3535

3636
self.recv_from_httpserver = context.socket(zmq.PULL)
3737
self.recv_from_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{visual_port}")
38-
self.cache_client = rpyc.connect("localhost", cache_port)
38+
self.cache_client = rpyc.connect("localhost", cache_port, config={"allow_pickle": True})
3939
self.cache_port = cache_port
4040
self.waiting_reqs: List[GroupReqIndexes] = []
4141
self.model_weightdir = args.model_dir
@@ -121,11 +121,9 @@ async def loop_for_fwd(self):
121121
multimodal_params = group_req_indexes.multimodal_params
122122

123123
img_uuids = [img.uuid for img in multimodal_params.images]
124-
ready_flags = []
125-
for uuid in img_uuids:
126-
ready_flags.append(self.cache_client.root.get_items_embed(uuid))
124+
ready_image = self.cache_client.root.get_items_embed(img_uuids)
127125

128-
for img, ready in zip(multimodal_params.images, ready_flags):
126+
for img, ready in zip(multimodal_params.images, ready_image):
129127
if not ready:
130128
images_need_infer.append(img)
131129

lightllm/server/visualserver/model_infer/model_rpc.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,14 +94,20 @@ def exposed_encode(self, images: List[ImageItem]):
9494
images = obtain(images)
9595
all_img_embeds, uuids, valid_ids = self.forward(images)
9696
all_img_embeds = all_img_embeds.to(torch.device("cpu"))
97+
9798
if self.tp_rank_id == 0:
98-
for i in range(len(uuids)):
99+
ready_flags = self.cache_client.root.get_items_embed(uuids)
100+
ids_to_set = []
101+
for i, ready in enumerate(ready_flags):
102+
if ready:
103+
continue
99104
uid = uuids[i]
100-
if not self.cache_client.root.get_item_embed(uid):
101-
start, end = valid_ids[i]
102-
cur_embed_bytes = tensor2bytes(all_img_embeds[start:end])
103-
create_shm(get_shm_name_embed(uuids[i]), cur_embed_bytes)
104-
self.cache_client.root.set_item_embed(uuids[i])
105+
start, end = valid_ids[i]
106+
cur_embed_bytes = tensor2bytes(all_img_embeds[start:end])
107+
create_shm(get_shm_name_embed(uid), cur_embed_bytes)
108+
ids_to_set.append(uid)
109+
if ids_to_set:
110+
self.cache_client.root.set_items_embed(ids_to_set)
105111
return
106112

107113

0 commit comments

Comments (0)