Skip to content

Commit aa1c586

Browse files
SangChengC and sangchengmeng
authored and committed
[fix]fix rpyc in multimodal process
1 parent 780a57f commit aa1c586

File tree

8 files changed

+62
-128
lines changed

8 files changed

+62
-128
lines changed

lightllm/models/whisper/whisper_audio.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -190,8 +190,14 @@ def encode(self, audio_items: List[AudioItem]):
190190
audio_lens_after_cnn = np.array(audio_lens_after_cnn, dtype=np.int32)
191191
audio_token_num = (audio_lens_after_cnn - 2) // 2 + 1
192192

193-
for i in range(len(uuids)):
194-
if not self.cache_client.root.get_item_embed(uuids[i]):
195-
cur_embed_bytes = tensor2bytes(audios[i][: audio_token_num[i]])
196-
create_shm(get_shm_name_embed(uuids[i]), cur_embed_bytes)
197-
self.cache_client.root.set_item_embed(uuids[i])
193+
ready_audio = self.cache_client.root.get_items_data(uuids)
194+
ids_to_set = []
195+
for i, ready in enumerate(ready_audio):
196+
if ready:
197+
continue
198+
uid = uuids[i]
199+
cur_embed_bytes = tensor2bytes(audios[i][: audio_token_num[i]])
200+
create_shm(get_shm_name_data(uid), cur_embed_bytes)
201+
ids_to_set.append(uid)
202+
if ids_to_set:
203+
self.cache_client.root.set_items_data(ids=ids_to_set)

lightllm/server/audioserver/manager.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,11 @@ async def loop_for_fwd(self):
9494

9595
multimodal_params = group_req_indexes.multimodal_params
9696

97-
for audio in multimodal_params.audios:
98-
if not self.cache_client.root.get_item_embed(audio.uuid):
97+
audio_uuids = [audio.uuid for audio in multimodal_params.audios]
98+
ready_audio = self.cache_client.root.get_items_embed(audio_uuids)
99+
100+
for audio, ready in zip(multimodal_params.audios, ready_audio):
101+
if not ready:
99102
audios_need_infer.append(audio)
100103

101104
if len(audios_need_infer) == self.infer_batch_size:

lightllm/server/embed_cache/impl/naive_memory_cache.py

Lines changed: 12 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -89,52 +89,12 @@ def _clear(self):
8989
if deleted >= max_delete:
9090
break
9191

92-
# def alloc(self, md5sum: str, token_num: int) -> dict:
93-
# with self.lock:
94-
# t = time.time()
95-
# # add new record
96-
# if md5sum not in self._md5_to_record:
97-
98-
# # full, need to clear some unused items
99-
# if self.occupied >= self.capacity:
100-
# self._clear()
101-
# if self.occupied >= self.capacity:
102-
# return None
103-
104-
# id = uuid.uuid1()
105-
# id = id.int
106-
# self._check_and_set_new_id_range(token_num)
107-
# record = Record(
108-
# id=id,
109-
# md5sum=md5sum,
110-
# ref=1,
111-
# data=False,
112-
# embed=False,
113-
# createtime=t,
114-
# visittime=t,
115-
# token_id=self.token_id_range_start,
116-
# token_num=token_num,
117-
# )
118-
# self.token_id_range_start += token_num
119-
# self._records[id] = record
120-
# self._md5_to_record[md5sum] = record
121-
# self.occupied += 1
122-
123-
# # cache hit
124-
# else:
125-
# record = self._md5_to_record[md5sum]
126-
# record.visittime = t
127-
# record.ref += 1
128-
129-
# return {"id": record.id, "token_id": record.token_id, "token_num": record.token_num}
130-
13192
def alloc_batch(self, md5_list: list[str], token_num_list: list[int]) -> list[dict]:
13293
results = []
13394
with self.lock:
13495
for md5, tnum in zip(md5_list, token_num_list):
13596
t = time.time()
13697
if md5 not in self._md5_to_record:
137-
# 若不存在则分配新记录(与alloc逻辑相同)
13898
if self.occupied >= self.capacity:
13999
self._clear()
140100
if self.occupied >= self.capacity:
@@ -158,34 +118,27 @@ def alloc_batch(self, md5_list: list[str], token_num_list: list[int]) -> list[di
158118
self._md5_to_record[md5] = record
159119
self.occupied += 1
160120
else:
161-
# 缓存命中,更新引用计数和访问时间
162121
record = self._md5_to_record[md5]
163122
record.visittime = t
164123
record.ref += 1
165124
results.append({"id": record.id, "token_id": record.token_id, "token_num": record.token_num})
166125
return results
167126

168-
def release(self, id: int) -> None:
127+
def release(self, ids: list[int]) -> None:
169128
with self.lock:
170-
self._records[id].ref -= 1
171-
172-
# def set_item_data(self, id: int) -> None:
173-
# self._records[id].data = True
129+
for id in ids:
130+
self._records[id].ref -= 1
174131

175-
# def get_item_data(self, id: int) -> bool:
176-
# return self._records[id].data
132+
def set_items_data(self, ids: list[int]) -> None:
133+
for id in ids:
134+
self._records[id].data = True
177135

178136
def get_items_data(self, ids: list[int]) -> list[bool]:
179-
with self.lock:
180-
return [self._records.get(i).data if i in self._records else False for i in ids]
181-
182-
def set_items_data(self, ids: list[int]) -> None:
183-
with self.lock:
184-
for i in ids:
185-
self._records[i].data = True
137+
return [self._records.get(i).data if i in self._records else False for i in ids]
186138

187-
def set_item_embed(self, id: int) -> None:
188-
self._records[id].embed = True
139+
def set_items_embed(self, ids: list[int]) -> None:
140+
for id in ids:
141+
self._records[id].embed = True
189142

190-
def get_item_embed(self, id: int) -> bool:
191-
return self._records[id].embed
143+
def get_items_embed(self, ids: list[int]) -> list[bool]:
144+
return [self._records.get(i).embed if i in self._records else False for i in ids]

lightllm/server/embed_cache/interface.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ def __init__(self) -> None:
1010
def alloc_batch(self, md5sum_list: list[str], token_num_list: list[int]) -> list[dict]:
1111
pass
1212

13-
def release(self, id: int) -> None:
13+
def release(self, ids: list[int]) -> None:
1414
pass
1515

1616
def set_items_data(self, ids: list[int]) -> None:
@@ -19,10 +19,10 @@ def set_items_data(self, ids: list[int]) -> None:
1919
def get_items_data(self, ids: list[int]) -> list[bool]:
2020
pass
2121

22-
def set_item_embed(self, id: int) -> None:
22+
def set_items_embed(self, ids: list[int]) -> None:
2323
pass
2424

25-
def get_item_embed(self, id: int) -> bool:
25+
def get_items_embed(self, ids: list[int]) -> list[bool]:
2626
pass
2727

2828

lightllm/server/embed_cache/manager.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,12 @@ def on_disconnect(self, conn):
2525
def exposed_alloc_batch(self, md5sum_list: list[str], token_num_list: list[int]) -> dict:
2626
md5sum_list = obtain(md5sum_list)
2727
token_num_list = obtain(token_num_list)
28-
record = self._impl.alloc(md5sum_list, token_num_list)
28+
record = self._impl.alloc_batch(md5sum_list, token_num_list)
2929
return record
3030

31-
def exposed_release(self, id: int) -> None:
32-
id = obtain(id)
33-
return self._impl.release(id)
31+
def exposed_release(self, ids: list[int]) -> None:
32+
ids = obtain(ids)
33+
return self._impl.release(ids)
3434

3535
def exposed_set_items_data(self, ids: list[int]) -> None:
3636
ids = obtain(ids)
@@ -40,13 +40,13 @@ def exposed_get_items_data(self, ids: list[int]) -> list[bool]:
4040
ids = obtain(ids)
4141
return self._impl.get_items_data(ids=ids)
4242

43-
def exposed_set_item_embed(self, id: int) -> None:
44-
id = obtain(id)
45-
return self._impl.set_item_embed(id=id)
43+
def exposed_set_items_embed(self, ids: list[int]) -> None:
44+
ids = obtain(ids)
45+
return self._impl.set_items_embed(ids=ids)
4646

47-
def exposed_get_item_embed(self, id: int) -> bool:
48-
id = obtain(id)
49-
return self._impl.get_item_embed(id=id)
47+
def exposed_get_items_embed(self, ids: list[int]) -> list[bool]:
48+
ids = obtain(ids)
49+
return self._impl.get_items_embed(ids=ids)
5050

5151

5252
def start_cache_manager(port: int, args, pipe_writer):

lightllm/server/httpserver/manager.py

Lines changed: 6 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def __init__(
8282

8383
self.enable_multimodal = enable_multimodal
8484
if self.enable_multimodal:
85-
self.cache_client = rpyc.connect("localhost", cache_port, onfig={"allow_pickle": True})
85+
self.cache_client = rpyc.connect("localhost", cache_port, config={"allow_pickle": True})
8686
self.send_to_visual = context.socket(zmq.PUSH)
8787
self.send_to_visual.connect(f"{args.zmq_mode}127.0.0.1:{visual_port}")
8888

@@ -114,34 +114,6 @@ def __init__(
114114
self.latest_success_infer_time_mark.set_value(int(time.time()))
115115
return
116116

117-
# connect cache server, calculate md5, alloc resource, return uuid
118-
async def _alloc_resource(self, item: Union[ImageItem, AudioItem]):
119-
if isinstance(item, ImageItem):
120-
data = item.read()
121-
# must after init_imageitem_extral_params
122-
num_tokens = self.tokenizer.get_image_token_length(item)
123-
elif isinstance(item, AudioItem):
124-
data = item.read()
125-
num_tokens = self.tokenizer.get_audio_token_length(item)
126-
else:
127-
raise ValueError(f"unexpected item type {type(item)}")
128-
129-
md5sum = hashlib.md5(data).hexdigest() + "_" + str(hash(frozendict(item.extra_params)))
130-
wait_time = 1
131-
while True:
132-
record = self.cache_client.root.alloc(md5sum, num_tokens)
133-
# hit or new
134-
if record:
135-
uid = record["id"]
136-
if not self.cache_client.root.get_item_data(uid):
137-
create_shm(get_shm_name_data(uid), data)
138-
self.cache_client.root.set_item_data(uid)
139-
return record
140-
# cache full
141-
else:
142-
await asyncio.sleep(wait_time)
143-
wait_time = min(wait_time + 2, 9)
144-
145117
async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams, sampling_params: SamplingParams):
146118
# 只有 P 和 NORMAL 节点需要真的管理多模态资源
147119
if self.pd_mode.is_P_or_NORMAL():
@@ -160,9 +132,6 @@ async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams,
160132
token_nums.append(num_tokens)
161133
datas.append(data)
162134
items.append(img)
163-
# img.uuid = record["id"]
164-
# img.token_id = record["token_id"]
165-
# img.token_num = record["token_num"]
166135
for audio in multimodal_params.audios:
167136
self.tokenizer.init_audioitem_extral_params(audio, multimodal_params, sampling_params)
168137
data = audio.read()
@@ -172,9 +141,6 @@ async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams,
172141
token_nums.append(num_tokens)
173142
datas.append(data)
174143
items.append(audio)
175-
# audio.uuid = record["id"]
176-
# audio.token_id = record["token_id"]
177-
# audio.token_num = record["token_num"]
178144
wait_time = 1
179145
while True:
180146
records = self.cache_client.root.alloc_batch(md5s, token_nums)
@@ -194,7 +160,6 @@ async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams,
194160
item.token_num = record["token_num"]
195161
if not ready:
196162
create_shm(get_shm_name_data(item.uuid), data)
197-
self.cache_client.root.set_items_data(item.uuid)
198163
uids_to_write.append(item.uuid)
199164
if uids_to_write:
200165
self.cache_client.root.set_items_data(uids_to_write)
@@ -203,20 +168,23 @@ async def _release_multimodal_resources(self, multimodal_params: MultimodalParam
203168
# 只有 P 和 NORMAL 节点需要真的管理多模态资源
204169
if self.pd_mode.is_P_or_NORMAL():
205170
if multimodal_params is not None:
171+
ids_to_release = []
206172
for img in multimodal_params.images:
207173
if img.uuid is not None:
208-
self.cache_client.root.release(img.uuid)
174+
ids_to_release.append(img.uuid)
209175
# 将 uuid 等 赋值为 None, 防止因为abort等异常情况造成重复释放异常
210176
img.uuid = None
211177
img.token_id = None
212178
img.token_num = None
213179
for audio in multimodal_params.audios:
214180
if audio.uuid is not None:
215-
self.cache_client.root.release(audio.uuid)
181+
ids_to_release.append(audio.uuid)
216182
# 将 uuid 等 赋值为 None, 防止因为abort等异常情况造成重复释放异常
217183
audio.uuid = None
218184
audio.token_id = None
219185
audio.token_num = None
186+
if ids_to_release:
187+
self.cache_client.root.release(ids_to_release)
220188
return
221189

222190
def tokens(self, prompt, multimodal_params, samping_params: SamplingParams, kwargs=None):

lightllm/server/visualserver/manager.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def __init__(
3535

3636
self.recv_from_httpserver = context.socket(zmq.PULL)
3737
self.recv_from_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{visual_port}")
38-
self.cache_client = rpyc.connect("localhost", cache_port)
38+
self.cache_client = rpyc.connect("localhost", cache_port, config={"allow_pickle": True})
3939
self.cache_port = cache_port
4040
self.waiting_reqs: List[GroupReqIndexes] = []
4141
self.model_weightdir = args.model_dir
@@ -121,11 +121,9 @@ async def loop_for_fwd(self):
121121
multimodal_params = group_req_indexes.multimodal_params
122122

123123
img_uuids = [img.uuid for img in multimodal_params.images]
124-
ready_flags = []
125-
for uuid in img_uuids:
126-
ready_flags.append(self.cache_client.root.get_items_embed(uuid))
124+
ready_image = self.cache_client.root.get_items_embed(img_uuids)
127125

128-
for img, ready in zip(multimodal_params.images, ready_flags):
126+
for img, ready in zip(multimodal_params.images, ready_image):
129127
if not ready:
130128
images_need_infer.append(img)
131129

lightllm/server/visualserver/model_infer/model_rpc.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -94,14 +94,20 @@ def exposed_encode(self, images: List[ImageItem]):
9494
images = obtain(images)
9595
all_img_embeds, uuids, valid_ids = self.forward(images)
9696
all_img_embeds = all_img_embeds.to(torch.device("cpu"))
97+
9798
if self.tp_rank_id == 0:
98-
for i in range(len(uuids)):
99+
ready_flags = self.cache_client.root.get_items_embed(uuids)
100+
ids_to_set = []
101+
for i, ready in enumerate(ready_flags):
102+
if ready:
103+
continue
99104
uid = uuids[i]
100-
if not self.cache_client.root.get_item_embed(uid):
101-
start, end = valid_ids[i]
102-
cur_embed_bytes = tensor2bytes(all_img_embeds[start:end])
103-
create_shm(get_shm_name_embed(uuids[i]), cur_embed_bytes)
104-
self.cache_client.root.set_item_embed(uuids[i])
105+
start, end = valid_ids[i]
106+
cur_embed_bytes = tensor2bytes(all_img_embeds[start:end])
107+
create_shm(get_shm_name_embed(uid), cur_embed_bytes)
108+
ids_to_set.append(uid)
109+
if ids_to_set:
110+
self.cache_client.root.set_items_embed(ids_to_set)
105111
return
106112

107113

0 commit comments

Comments (0)