Skip to content

Commit c9f128c

Browse files
committed
Implement multimodal resource alloc ("mmalloc") optimization
1 parent 81b9ecb commit c9f128c

File tree

8 files changed

+243
-45
lines changed

8 files changed

+243
-45
lines changed

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
# NOTE(review): hook repos must stay on public https URLs — the SSH form
# (git@github.com:...) introduced by this commit breaks `pre-commit install`
# for any contributor or CI runner without a GitHub SSH key.
repos:
  - repo: https://github.com/psf/black
    rev: 21.12b0
    hooks:
      - id: black
        language_version: python3
        args: [--line-length=120]
        additional_dependencies: ['click==8.0.4']
  - repo: https://github.com/pycqa/flake8
    rev: 3.9.0
    hooks:
      - id: flake8

lightllm/server/api_cli.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,10 @@ def make_argument_parser() -> argparse.ArgumentParser:
296296
parser.add_argument(
297297
"--cache_capacity", type=int, default=200, help="cache server capacity for multimodal resources"
298298
)
299+
parser.add_argument(
300+
"--enable_concurrent_alloc", action="store_true", help="alloc multimodal resources in threadpool to save time"
301+
)
302+
parser.add_argument("--concurrent_alloc_workers", type=int, default=4, help="max concurrent threadpool workers")
299303
parser.add_argument(
300304
"--data_type",
301305
type=str,

lightllm/server/audioserver/manager.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,9 @@ async def loop_for_fwd(self):
9696
multimodal_params = group_req_indexes.multimodal_params
9797

9898
audio_uuids = [audio.uuid for audio in multimodal_params.audios]
99-
ready_audio = obtain(self.cache_client.root.get_items_embed(audio_uuids))
99+
audio_uuids = pickle.dumps(audio_uuids)
100+
ready_audio = self.cache_client.root.get_items_embed_v2(audio_uuids)
101+
ready_audio = pickle.loads(ready_audio)
100102

101103
for audio, ready in zip(multimodal_params.audios, ready_audio):
102104
if not ready:

lightllm/server/embed_cache/manager.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from lightllm.utils.graceful_utils import graceful_registry
66
from lightllm.server.embed_cache.impl.naive_memory_cache import InMemoryCache
77
from rpyc.utils.classic import obtain
8+
import pickle
89

910

1011
class CacheServer(rpyc.Service):
@@ -48,6 +49,47 @@ def exposed_get_items_embed(self, ids: list[int]) -> list[bool]:
4849
ids = obtain(ids)
4950
return self._impl.get_items_embed(ids)
5051

52+
def exposed_alloc_v2(self, batch_md5_token_nums: bytes) -> bytes:
    """Batch-allocate cache records from a single pickled request blob.

    Args:
        batch_md5_token_nums: ``pickle.dumps([(md5sum, token_num), ...])``.

    Returns:
        ``pickle.dumps(records)`` — whatever ``self._impl.alloc`` returns
        (``None`` when the cache cannot satisfy the whole batch).

    The v2 API moves one bytes payload over rpyc instead of a netref per
    element. After ``pickle.loads`` every element is already a plain local
    object, so the per-element ``obtain()`` round-trips used by the v1 API
    are unnecessary here and have been removed.
    """
    # NOTE(review): pickle.loads on RPC input is only acceptable because the
    # cache server is reached by trusted local processes; never expose this
    # port to untrusted clients.
    batch_requests = pickle.loads(batch_md5_token_nums)
    if batch_requests:
        md5sum_list, token_num_list = (list(col) for col in zip(*batch_requests))
    else:
        md5sum_list, token_num_list = [], []
    record = self._impl.alloc(md5sum_list, token_num_list)
    return pickle.dumps(record)
62+
63+
def exposed_release_v2(self, ids_blob: bytes) -> None:
    """Release cached items given ``pickle.dumps([item_id, ...])``.

    The ids are plain local objects after unpickling, so no per-element
    ``obtain()`` copy is needed (that also avoided shadowing builtin ``id``).
    """
    ids = pickle.loads(ids_blob)
    return self._impl.release(ids)
67+
68+
def exposed_set_items_data_v2(self, ids_blob: bytes) -> bytes:
    """Mark items' payload data as ready; ids arrive as ``pickle.dumps(list)``.

    Returns the pickled status list from the underlying cache impl. The
    unpickled ids are already local objects, so the redundant per-element
    ``obtain()`` pass was dropped.
    """
    ids = pickle.loads(ids_blob)
    status_list = self._impl.set_items_data(ids)
    return pickle.dumps(status_list)
73+
74+
def exposed_get_items_data_v2(self, ids_blob: bytes) -> bytes:
    """Query whether items' payload data is ready; ids arrive pickled.

    Returns ``pickle.dumps(status_list)``. The unpickled ids are already
    local objects, so the redundant per-element ``obtain()`` pass was
    dropped.
    """
    ids = pickle.loads(ids_blob)
    status_list = self._impl.get_items_data(ids)
    return pickle.dumps(status_list)
79+
80+
def exposed_set_items_embed_v2(self, ids_blob: bytes) -> bytes:
    """Mark items' embeddings as ready; ids arrive as ``pickle.dumps(list)``.

    Returns ``pickle.dumps(status_list)`` — the original annotation said
    ``-> None`` but the body returns pickled bytes, so the annotation is
    corrected here for consistency with the other ``*_v2`` methods. The
    redundant per-element ``obtain()`` pass was also dropped.
    """
    ids = pickle.loads(ids_blob)
    status_list = self._impl.set_items_embed(ids)
    return pickle.dumps(status_list)
86+
87+
def exposed_get_items_embed_v2(self, ids_blob: bytes) -> bytes:
    """Query whether items' embeddings are ready; ids arrive pickled.

    Returns ``pickle.dumps(status_list)``. The unpickled ids are already
    local objects, so the redundant per-element ``obtain()`` pass was
    dropped.
    """
    ids = pickle.loads(ids_blob)
    status_list = self._impl.get_items_embed(ids)
    return pickle.dumps(status_list)
92+
5193

5294
def start_cache_manager(port: int, args, pipe_writer):
5395
# 注册graceful 退出的处理

lightllm/server/httpserver/manager.py

Lines changed: 139 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import datetime
1111
import pickle
1212
from frozendict import frozendict
13+
import concurrent.futures
1314

1415
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
1516
from typing import Union, List, Tuple, Dict, Optional
@@ -32,6 +33,7 @@
3233
from lightllm.utils.statics_utils import MovingAverage
3334
from lightllm.utils.config_utils import get_vocab_size
3435
from lightllm.utils.envs_utils import get_unique_server_name
36+
from lightllm.utils.infer_utils import calculate_cpu_time_async, calculate_cpu_time_sync
3537
from rpyc.utils.classic import obtain
3638

3739
logger = init_logger(__name__)
@@ -112,13 +114,19 @@ def __init__(
112114
# If the timemark is not updated for a pre-set time, a prob request will be sent to the backend.
113115
self.latest_success_infer_time_mark = SharedInt(f"{get_unique_server_name()}_latest_success_infer_time_mark")
114116
self.latest_success_infer_time_mark.set_value(int(time.time()))
117+
118+
# 线程池用于创建multimodal resource alloc
119+
self.enable_concurrent_alloc = self.args.enable_concurrent_alloc
120+
self.max_concurrent = self.args.concurrent_alloc_workers * 48
121+
if self.enable_concurrent_alloc:
122+
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=self.args.concurrent_alloc_workers)
115123
return
116124

117125
async def _alloc_resource(self, items, md5sums, token_nums, datas):
118-
119126
while True:
127+
t1 = time.time()
120128
records = obtain(self.cache_client.root.alloc(md5sums, token_nums))
121-
129+
logger.info(f"cache manager batch alloc time: {(time.time() - t1)*1000} ms")
122130
if records is None:
123131
await asyncio.sleep(0.1)
124132
continue
@@ -142,37 +150,139 @@ async def _alloc_resource(self, items, md5sums, token_nums, datas):
142150
self.cache_client.root.set_items_data(update_data_ids)
143151
return
144152

153+
async def _alloc_resource_v2(self, items, md5sums, token_nums, datas):
    """Allocate cache records for multimodal items via the pickled v2 RPC.

    ``items``/``md5sums``/``token_nums``/``datas`` are parallel lists, one
    entry per image/audio item. Retries every 100ms while the cache server
    reports it cannot satisfy the whole batch (records is None). For every
    item whose payload is not yet present, the raw bytes are written to
    shared memory through the thread pool, then the cache server is told
    the data is ready.
    """
    batch_requests = list(zip(md5sums, token_nums))
    # The request blob is loop-invariant: build it once, not on every retry.
    req_blob = pickle.dumps(batch_requests)
    while True:
        t1 = time.time()
        res_blob = self.cache_client.root.alloc_v2(req_blob)
        records = pickle.loads(res_blob)
        # Lazy %-style args: no string formatting unless the record is emitted.
        logger.info("cache manager batch alloc time: %s ms", (time.time() - t1) * 1000)
        if records is None:
            # Cache is saturated; back off and retry the whole batch.
            await asyncio.sleep(0.1)
            continue

        uid_list = []
        for item, rec in zip(items, records):
            item.uuid = rec["id"]
            item.token_id = rec["token_id"]
            item.token_num = rec["token_num"]
            uid_list.append(rec["id"])

        uid_blob = pickle.dumps(uid_list)
        ready_flags = pickle.loads(self.cache_client.root.get_items_data_v2(uid_blob))

        # get_event_loop() is deprecated inside coroutines; get_running_loop()
        # is the supported way to reach the current loop (hoisted out of the
        # helper so it is looked up once per batch).
        loop = asyncio.get_running_loop()
        # Bound in-flight shm writes; the executor also caps real parallelism.
        semaphore = asyncio.Semaphore(min(len(items), self.max_concurrent))

        async def _write_shm(uid, data):
            async with semaphore:
                return await loop.run_in_executor(self.executor, create_shm, get_shm_name_data(uid), data)

        update_data_ids = []
        shm_tasks = []
        for uid, ready, data in zip(uid_list, ready_flags, datas):
            if not ready:
                shm_tasks.append(_write_shm(uid, data))
                update_data_ids.append(uid)

        if shm_tasks:
            t_shm = time.time()
            await asyncio.gather(*shm_tasks)
            logger.info("concurrent create shm time: %s ms", (time.time() - t_shm) * 1000)

        if update_data_ids:
            update_dataids_blob = pickle.dumps(update_data_ids)
            self.cache_client.root.set_items_data_v2(update_dataids_blob)
        return
202+
@calculate_cpu_time_async(show=True)
async def _alloc_multimodal_resources(self, multimodal_params: MultimodalParams, sampling_params: SamplingParams):
    """Allocate cache resources for a request's multimodal items.

    Only P and NORMAL nodes actually manage multimodal resources; any other
    role returns immediately. Dispatches to the concurrent (v2) or
    sequential (v1) implementation based on --enable_concurrent_alloc.
    """
    if not self.pd_mode.is_P_or_NORMAL():
        return
    # The lock prevents a resource-competition deadlock when several
    # requests each carry many images: e.g. with cache_capacity 10, two
    # requests holding 6 images apiece could each pin 5 records in shm and
    # starve each other forever if they allocated concurrently.
    async with self._resource_lock:
        alloc = (
            self._alloc_multimodal_resources_v2
            if self.enable_concurrent_alloc
            else self._alloc_multimodal_resources_v1
        )
        await alloc(multimodal_params, sampling_params)
    return
175216

217+
async def _alloc_multimodal_resources_v1(
    self, multimodal_params: MultimodalParams, sampling_params: SamplingParams
):
    """Sequentially initialize, read and hash every image/audio item, then
    allocate their cache records in one batch via _alloc_resource."""
    items, md5sums, tokens_nums, datas = [], [], [], []

    def _collect(item, payload, token_count):
        # Cache key combines content hash with the item's extra params.
        digest = hashlib.md5(payload).hexdigest() + "_" + str(hash(frozendict(item.extra_params)))
        md5sums.append(digest)
        tokens_nums.append(token_count)
        datas.append(payload)
        items.append(item)

    for img in multimodal_params.images:
        self.tokenizer.init_imageitem_extral_params(img, multimodal_params, sampling_params)
        payload = img.read()
        # token length must be queried after init_imageitem_extral_params
        _collect(img, payload, self.tokenizer.get_image_token_length(img))

    for audio in multimodal_params.audios:
        self.tokenizer.init_audioitem_extral_params(audio, multimodal_params, sampling_params)
        payload = audio.read()
        _collect(audio, payload, self.tokenizer.get_audio_token_length(audio))

    await self._alloc_resource(items, md5sums, tokens_nums, datas)
242+
243+
async def _alloc_multimodal_resources_v2(
    self, multimodal_params: MultimodalParams, sampling_params: SamplingParams
):
    """Concurrent variant: reads and hashes items in the thread pool, then
    allocates cache records chunk by chunk via the pickled v2 RPC.

    Items are processed in chunks of ``self.max_concurrent`` so a request
    with very many attachments cannot flood the executor queue.
    """
    all_items = multimodal_params.images + multimodal_params.audios
    if not all_items:
        return
    # get_event_loop() is deprecated inside coroutines; use get_running_loop().
    loop = asyncio.get_running_loop()

    def _process_item(item):
        """Initialize item params, read its payload and compute the md5 cache
        key (runs inside the executor); multimodal/sampling params are
        captured from the enclosing scope."""
        if isinstance(item, ImageItem):
            self.tokenizer.init_imageitem_extral_params(item, multimodal_params, sampling_params)
        elif isinstance(item, AudioItem):
            self.tokenizer.init_audioitem_extral_params(item, multimodal_params, sampling_params)
        else:
            # Fail loudly: the original silently skipped init here and then
            # crashed later with an unbound token_num for unknown item types.
            raise ValueError(f"unsupported multimodal item type: {type(item)!r}")
        data = item.read()
        md5sum = hashlib.md5(data).hexdigest() + "_" + str(hash(frozendict(item.extra_params)))
        return data, md5sum

    def _token_length(item):
        """Token count for an item; valid only after init_*_extral_params ran."""
        if isinstance(item, ImageItem):
            return self.tokenizer.get_image_token_length(item)
        return self.tokenizer.get_audio_token_length(item)

    chunk_size = self.max_concurrent  # tune if needed
    for i in range(0, len(all_items), chunk_size):
        chunk = all_items[i : i + chunk_size]

        # Read/hash every item of the chunk concurrently in the thread pool.
        process_tasks = [loop.run_in_executor(self.executor, _process_item, item) for item in chunk]
        chunk_results = await asyncio.gather(*process_tasks)

        chunk_items, chunk_md5sums, chunk_tokens_nums, chunk_datas = [], [], [], []
        for item, (data, md5sum) in zip(chunk, chunk_results):
            chunk_items.append(item)
            chunk_md5sums.append(md5sum)
            chunk_tokens_nums.append(_token_length(item))
            chunk_datas.append(data)

        await self._alloc_resource_v2(chunk_items, chunk_md5sums, chunk_tokens_nums, chunk_datas)
285+
176286
async def _release_multimodal_resources(self, multimodal_params: MultimodalParams):
177287
# 只有 P 和 NORMAL 节点需要真的管理多模态资源
178288
if self.pd_mode.is_P_or_NORMAL():
@@ -193,7 +303,11 @@ async def _release_multimodal_resources(self, multimodal_params: MultimodalParam
193303
audio.token_id = None
194304
audio.token_num = None
195305
if ids_to_release:
196-
self.cache_client.root.release(ids_to_release)
306+
if self.enable_concurrent_alloc:
307+
release_id_blobs = pickle.dumps(ids_to_release)
308+
self.cache_client.root.release_v2(release_id_blobs)
309+
else:
310+
self.cache_client.root.release(ids_to_release)
197311
return
198312

199313
def tokens(self, prompt, multimodal_params, samping_params: SamplingParams, kwargs=None):
@@ -341,7 +455,6 @@ async def generate(
341455
return
342456

343457
async def _log_req_header(self, request_headers, group_request_id: int):
344-
345458
x_request_id = request_headers.get("X-Request-Id", "")
346459
x_session_id = request_headers.get("X-Session-Id", "")
347460

@@ -436,7 +549,6 @@ async def transfer_to_next_module(
436549
self,
437550
group_req_objs: Optional[GroupReqObjs] = None,
438551
):
439-
440552
if self.pd_mode == NodeRole.P:
441553
if self.enable_multimodal:
442554
self.send_to_visual.send_pyobj(
@@ -483,7 +595,6 @@ async def _wait_to_token_package(
483595
req_status: "ReqStatus",
484596
request: Request,
485597
):
486-
487598
event = req_status.event
488599
unfinished_count = sampling_params.best_of
489600
out_token_counter = 0
@@ -589,7 +700,6 @@ async def recycle_resource_loop(self):
589700
pre_time_mark = time.time()
590701

591702
while True:
592-
593703
try:
594704
await asyncio.wait_for(self.recycle_event.wait(), timeout=0.02)
595705
except asyncio.TimeoutError:
@@ -660,7 +770,6 @@ async def handle_loop(self):
660770

661771
for _ in range(read_token_count):
662772
if not req.out_tokens_queue.is_empty():
663-
664773
text, src_index, special, count_output_tokens = req.out_tokens_queue.peek()
665774
req.cumlogprob += float(req.shm_logprobs.arr[src_index])
666775
metadata = {

lightllm/server/visualserver/manager.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,9 @@ async def loop_for_fwd(self):
122122
multimodal_params = group_req_indexes.multimodal_params
123123

124124
img_uuids = [img.uuid for img in multimodal_params.images]
125-
ready_image = obtain(self.cache_client.root.get_items_embed(img_uuids))
125+
img_uuids = pickle.dumps(img_uuids)
126+
ready_image = self.cache_client.root.get_items_embed_v2(img_uuids)
127+
ready_image = pickle.loads(ready_image)
126128

127129
for img, ready in zip(multimodal_params.images, ready_image):
128130
if not ready:

0 commit comments

Comments
 (0)