4 changes: 0 additions & 4 deletions docs/CN/source/tutorial/api_server_args_zh.rst
@@ -274,10 +274,6 @@ Attention type selection parameters

Cache server capacity for multimodal resources, default is ``200``

.. option:: --cache_reserved_ratio

Reserved capacity ratio after cache server cleanup, default is ``0.5``

.. option:: --visual_infer_batch_size

Number of images processed in each inference batch, default is ``1``
4 changes: 0 additions & 4 deletions docs/EN/source/tutorial/api_server_args_zh.rst
@@ -273,10 +273,6 @@ Multimodal Parameters

Cache server capacity for multimodal resources, default is ``200``

.. option:: --cache_reserved_ratio

Reserved capacity ratio after cache server cleanup, default is ``0.5``

.. option:: --visual_infer_batch_size

Number of images processed in each inference batch, default is ``1``
17 changes: 11 additions & 6 deletions lightllm/models/whisper/whisper_audio.py
@@ -11,7 +11,7 @@
from transformers.processing_utils import ProcessorMixin
from lightllm.server.embed_cache.utils import tensor2bytes, read_shm, create_shm, get_shm_name_data, get_shm_name_embed
from lightllm.server.multimodal_params import AudioItem

from rpyc.utils.classic import obtain

# tokenizer_class removed
class WhisperProcessor(ProcessorMixin):
@@ -89,7 +89,7 @@ def __init__(self, kvargs):
self.sampling_rate = 16000
self.max_length = self.max_seconds * self.sampling_rate
self.cache_port = kvargs["cache_port"]
self.cache_client = rpyc.connect("localhost", self.cache_port)
self.cache_client = rpyc.connect("localhost", self.cache_port, config={"allow_pickle": True})
data_type = kvargs["data_type"]
if data_type in ["bf16", "bfloat16"]:
self.data_type = torch.bfloat16
@@ -190,8 +190,13 @@ def encode(self, audio_items: List[AudioItem]):
audio_lens_after_cnn = np.array(audio_lens_after_cnn, dtype=np.int32)
audio_token_num = (audio_lens_after_cnn - 2) // 2 + 1

for i in range(len(uuids)):
if not self.cache_client.root.get_item_embed(uuids[i]):
ready_audio = obtain(self.cache_client.root.get_items_embed(uuids))
ids_to_set = []
for i, ready in enumerate(ready_audio):
if not ready:
uid = uuids[i]
cur_embed_bytes = tensor2bytes(audios[i][: audio_token_num[i]])
create_shm(get_shm_name_embed(uuids[i]), cur_embed_bytes)
self.cache_client.root.set_item_embed(uuids[i])
create_shm(get_shm_name_embed(uid), cur_embed_bytes)
ids_to_set.append(uid)
if ids_to_set:
self.cache_client.root.set_items_embed(ids=ids_to_set)
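
The pattern above replaces one RPC per audio clip with a single batched call. obtain() matters here because rpyc hands back netref proxies for remote lists, so without it every element access would be a further network round trip; that is also why rpyc.connect now passes config={"allow_pickle": True}, which obtain needs to copy the object by value. A minimal, self-contained sketch of that round trip (the DemoCacheService name, port, and ids are placeholders, not part of this PR):

# Sketch of the batched-RPC pattern: the server returns a list of readiness
# flags, and obtain() copies it to the local process in one shot.
import threading
import time
import rpyc
from rpyc.utils.classic import obtain
from rpyc.utils.server import ThreadedServer


class DemoCacheService(rpyc.Service):
    def exposed_get_items_embed(self, ids):
        # pretend every second id already has a cached embedding
        return [i % 2 == 0 for i in ids]


server = ThreadedServer(DemoCacheService, port=18861,
                        protocol_config={"allow_pickle": True})
threading.Thread(target=server.start, daemon=True).start()
time.sleep(0.5)  # give the listener a moment to come up

conn = rpyc.connect("localhost", 18861, config={"allow_pickle": True})
ready = obtain(conn.root.get_items_embed([101, 102, 103]))  # plain local list
print(ready)  # [False, True, False]
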
3 changes: 0 additions & 3 deletions lightllm/server/api_cli.py
@@ -288,9 +288,6 @@ def make_argument_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--cache_capacity", type=int, default=200, help="cache server capacity for multimodal resources"
)
parser.add_argument(
"--cache_reserved_ratio", type=float, default=0.5, help="cache server reserved capacity ratio after clear"
)
parser.add_argument(
"--data_type",
type=str,
10 changes: 7 additions & 3 deletions lightllm/server/audioserver/manager.py
@@ -14,6 +14,7 @@
from lightllm.server.multimodal_params import AudioItem
from .model_infer.model_rpc import start_model_process, AudioModelRpcClient
from lightllm.utils.graceful_utils import graceful_registry
from rpyc.utils.classic import obtain

logger = init_logger(__name__)

@@ -33,7 +34,7 @@ def __init__(

self.recv_from_visualserver = context.socket(zmq.PULL)
self.recv_from_visualserver.bind(f"{args.zmq_mode}127.0.0.1:{audio_port}")
self.cache_client = rpyc.connect("localhost", cache_port)
self.cache_client = rpyc.connect("localhost", cache_port, config={"allow_pickle": True})
self.cache_port = cache_port
self.waiting_reqs: List[GroupReqIndexes] = []
self.model_weightdir = args.model_dir
@@ -94,8 +95,11 @@ async def loop_for_fwd(self):

multimodal_params = group_req_indexes.multimodal_params

for audio in multimodal_params.audios:
if not self.cache_client.root.get_item_embed(audio.uuid):
audio_uuids = [audio.uuid for audio in multimodal_params.audios]
ready_audio = obtain(self.cache_client.root.get_items_embed(audio_uuids))

for audio, ready in zip(multimodal_params.audios, ready_audio):
if not ready:
audios_need_infer.append(audio)

if len(audios_need_infer) == self.infer_batch_size:
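
In loop_for_fwd the manager now asks the cache once per request group which audio embeddings already exist, then queues only the missing clips and flushes whenever infer_batch_size items have accumulated. A rough, standalone sketch of that accumulate-and-flush pattern (run_infer and the constants are illustrative stand-ins, not the real RPC call):

# Hypothetical illustration of the accumulate-and-flush batching in loop_for_fwd.
from typing import List

INFER_BATCH_SIZE = 4  # stand-in for self.infer_batch_size


def run_infer(batch: List[str]) -> None:
    # placeholder for the call into the audio model worker
    print(f"infer batch of {len(batch)}: {batch}")


def filter_and_batch(audio_ids: List[str], ready_flags: List[bool]) -> None:
    need_infer: List[str] = []
    for audio_id, ready in zip(audio_ids, ready_flags):
        if ready:
            continue  # embedding already cached, nothing to do
        need_infer.append(audio_id)
        if len(need_infer) == INFER_BATCH_SIZE:
            run_infer(need_infer)
            need_infer = []
    if need_infer:  # flush the tail batch
        run_infer(need_infer)


filter_and_batch([f"a{i}" for i in range(10)],
                 [i % 3 == 0 for i in range(10)])
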
1 change: 0 additions & 1 deletion lightllm/server/core/objs/start_args_type.py
@@ -57,7 +57,6 @@ class StartArgs:
enable_decode_microbatch_overlap: bool = field(default=False)
enable_prefill_microbatch_overlap: bool = field(default=False)
cache_capacity: int = field(default=200)
cache_reserved_ratio: float = field(default=0.5)
data_type: Optional[str] = field(
default=None, metadata={"choices": ["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"]}
)
110 changes: 54 additions & 56 deletions lightllm/server/embed_cache/impl/naive_memory_cache.py
@@ -2,8 +2,7 @@
import threading
import dataclasses
import requests
from ..interface import CacheManager, CacheManagerFactory
from typing import Union
from typing import Union, Optional
import torch
import time
from collections import deque
@@ -27,15 +26,12 @@ class Record(object):
token_num: int


@CacheManagerFactory.register("naive")
class InMemoryCache(CacheManager):
class InMemoryCache:
def __init__(self, args) -> None:
self.args = args
self._records = dict()
self._md5_to_record = dict()
self.capacity = max(1, args.cache_capacity)
self.reserved = max(0, int(self.capacity * args.cache_reserved_ratio))
self.reserved = min(self.reserved, self.capacity - 1)
self.occupied = 0
self.expired_secs = 60 * 60
self.lock = threading.Lock()
@@ -71,9 +67,9 @@ def _check_and_set_new_id_range(self, alloced_token_num):
time.sleep(3)
return

def _clear(self):
def _clear(self, free_max_count: int):
deleted = 0
max_delete = max(1, self.occupied - self.reserved)
max_delete = free_max_count
items = sorted(self._records.items(), key=lambda x: x[1].visittime)
t = time.time()
for id, record in items:
@@ -89,57 +85,59 @@ def _clear(self):
if deleted >= max_delete:
break

def alloc(self, md5sum: str, token_num: int) -> dict:
def alloc(self, md5sum_list: list[str], token_num_list: list[int]) -> Optional[list[dict]]:
now = time.time()
with self.lock:
t = time.time()
# add new record
if md5sum not in self._md5_to_record:

# full, need to clear some unused items
if self.occupied >= self.capacity:
self._clear()
if self.occupied >= self.capacity:
return None

id = uuid.uuid1()
id = id.int
self._check_and_set_new_id_range(token_num)
record = Record(
id=id,
md5sum=md5sum,
ref=1,
data=False,
embed=False,
createtime=t,
visittime=t,
token_id=self.token_id_range_start,
token_num=token_num,
)
self.token_id_range_start += token_num
self._records[id] = record
self._md5_to_record[md5sum] = record
self.occupied += 1

# cache hit
else:
record = self._md5_to_record[md5sum]
record.visittime = t
record.ref += 1

return {"id": record.id, "token_id": record.token_id, "token_num": record.token_num}

def release(self, id: int) -> None:
new_md5s = [m for m in md5sum_list if m not in self._md5_to_record]
new_needed = len(set(new_md5s))

if self.occupied + new_needed > self.capacity:
self._clear(free_max_count=new_needed - (self.capacity - self.occupied))
if self.occupied + new_needed > self.capacity:
return None

results: list[dict] = []
for md5sum, token_num in zip(md5sum_list, token_num_list):
if md5sum in self._md5_to_record:
rec = self._md5_to_record[md5sum]
rec.visittime = now
rec.ref += 1
else:
uid_int = uuid.uuid1().int
self._check_and_set_new_id_range(token_num)
rec = Record(
id=uid_int,
md5sum=md5sum,
ref=1,
data=False,
embed=False,
createtime=now,
visittime=now,
token_id=self.token_id_range_start,
token_num=token_num,
)
self.token_id_range_start += token_num
self._records[uid_int] = rec
self._md5_to_record[md5sum] = rec
self.occupied += 1
results.append({"id": rec.id, "token_id": rec.token_id, "token_num": rec.token_num})
return results

def release(self, ids: list[int]) -> None:
with self.lock:
self._records[id].ref -= 1
for id_ in ids:
self._records[id_].ref -= 1

def set_item_data(self, id: int) -> None:
self._records[id].data = True
def set_items_data(self, ids: list[int]) -> None:
for id_ in ids:
self._records[id_].data = True

def get_item_data(self, id: int) -> bool:
return self._records[id].data
def get_items_data(self, ids: list[int]) -> list[Optional[bool]]:
return [self._records.get(id_).data if id_ in self._records else False for id_ in ids]

def set_item_embed(self, id: int) -> None:
self._records[id].embed = True
def set_items_embed(self, ids: list[int]) -> None:
for id_ in ids:
self._records[id_].embed = True

def get_item_embed(self, id: int) -> bool:
return self._records[id].embed
def get_items_embed(self, ids: list[int]) -> list[Optional[bool]]:
return [self._records.get(id_).embed if id_ in self._records else False for id_ in ids]
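
The batched alloc() contract above is: a cache hit bumps the record's ref count, a miss creates a new record, eviction is attempted only when the whole batch would overflow cache_capacity, and None is returned when even eviction cannot make room, so callers must retry the entire batch. A simplified, self-contained stand-in that mirrors those semantics (a hypothetical MiniCache, not the real InMemoryCache; shared-memory cleanup, token-id ranges, and the expiry logic hidden in the collapsed part of _clear are left out, and release takes md5 keys instead of ids purely for brevity):

# Hypothetical stand-in for the batched alloc()/release() semantics above.
import time
import uuid
from dataclasses import dataclass
from typing import Optional


@dataclass
class Rec:
    id: int
    md5sum: str
    ref: int
    visittime: float
    token_num: int


class MiniCache:
    def __init__(self, capacity: int) -> None:
        self.capacity = max(1, capacity)
        self.by_md5: dict[str, Rec] = {}

    def _evict(self, need: int) -> None:
        # drop least-recently-visited, unreferenced records until `need` slots free up
        for md5 in sorted(self.by_md5, key=lambda m: self.by_md5[m].visittime):
            if need <= 0:
                break
            if self.by_md5[md5].ref <= 0:
                del self.by_md5[md5]
                need -= 1

    def alloc(self, md5s: list[str], token_nums: list[int]) -> Optional[list[dict]]:
        new_needed = len({m for m in md5s if m not in self.by_md5})
        if len(self.by_md5) + new_needed > self.capacity:
            self._evict(new_needed - (self.capacity - len(self.by_md5)))
        if len(self.by_md5) + new_needed > self.capacity:
            return None  # caller must retry the whole batch later
        out, now = [], time.time()
        for md5, tok in zip(md5s, token_nums):
            rec = self.by_md5.get(md5)
            if rec is None:  # miss: create a fresh record
                rec = Rec(uuid.uuid1().int, md5, 0, now, tok)
                self.by_md5[md5] = rec
            rec.ref += 1  # hit or miss, the caller now holds a reference
            rec.visittime = now
            out.append({"id": rec.id, "token_num": rec.token_num})
        return out

    def release(self, md5s: list[str]) -> None:
        for md5 in md5s:
            self.by_md5[md5].ref -= 1


cache = MiniCache(capacity=2)
print(cache.alloc(["a", "b"], [16, 32]))        # two fresh records
print(cache.alloc(["a"], [16]))                 # cache hit, ref bumped
print(cache.alloc(["c", "d", "e"], [8, 8, 8]))  # None: batch cannot fit
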
48 changes: 0 additions & 48 deletions lightllm/server/embed_cache/interface.py

This file was deleted.

49 changes: 23 additions & 26 deletions lightllm/server/embed_cache/manager.py
@@ -1,14 +1,14 @@
import rpyc
import uuid
import inspect
from typing import Union
from typing import Union, Optional
from lightllm.utils.graceful_utils import graceful_registry
from .interface import CacheManager
from lightllm.server.embed_cache.impl.naive_memory_cache import InMemoryCache
from rpyc.utils.classic import obtain


class CacheServer(rpyc.Service):
def __init__(self, manager_impl: CacheManager) -> None:
def __init__(self, manager_impl: InMemoryCache) -> None:
super().__init__()
self._impl = manager_impl

@@ -22,41 +22,38 @@ def on_disconnect(self, conn):
# (to finalize the service, if needed)
pass

def exposed_alloc(self, md5sum: str, token_num: int) -> dict:
md5sum = obtain(md5sum)
token_num = obtain(token_num)
record = self._impl.alloc(md5sum, token_num)
def exposed_alloc(self, md5sum_list: list[str], token_num_list: list[int]) -> Optional[list[dict]]:
md5sum_list = obtain(md5sum_list)
token_num_list = obtain(token_num_list)
record = self._impl.alloc(md5sum_list, token_num_list)
return record

def exposed_release(self, id: int) -> None:
id = obtain(id)
return self._impl.release(id)
def exposed_release(self, ids: list[int]) -> None:
ids = obtain(ids)
return self._impl.release(ids)

def exposed_set_item_data(self, id: int) -> None:
id = obtain(id)
return self._impl.set_item_data(id=id)
def exposed_set_items_data(self, ids: list[int]) -> None:
ids = obtain(ids)
return self._impl.set_items_data(ids)

def exposed_get_item_data(self, id: int) -> bool:
id = obtain(id)
return self._impl.get_item_data(id=id)
def exposed_get_items_data(self, ids: list[int]) -> list[bool]:
ids = obtain(ids)
return self._impl.get_items_data(ids)

def exposed_set_item_embed(self, id: int) -> None:
id = obtain(id)
return self._impl.set_item_embed(id=id)
def exposed_set_items_embed(self, ids: list[int]) -> None:
ids = obtain(ids)
return self._impl.set_items_embed(ids)

def exposed_get_item_embed(self, id: int) -> bool:
id = obtain(id)
return self._impl.get_item_embed(id=id)
def exposed_get_items_embed(self, ids: list[int]) -> list[bool]:
ids = obtain(ids)
return self._impl.get_items_embed(ids)


def start_cache_manager(port: int, args, pipe_writer):
# register the graceful-exit handler
graceful_registry(inspect.currentframe().f_code.co_name)

from .interface import CacheManagerFactory

manager_cls = CacheManagerFactory.get_impl("naive")
manager = manager_cls(args)
manager = InMemoryCache(args)
service = CacheServer(manager)
from rpyc.utils.server import ThreadedServer

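
With the CacheManager interface gone, the RPC surface is list-based end to end: one alloc call reserves records for every multimodal item of a request, and the set/get/release helpers all take id lists. A hedged sketch of a client-side round trip, assuming a cache server started by start_cache_manager is already listening (the port value, md5 strings, and token counts are placeholders):

# Hypothetical client-side round trip against the batched CacheServer API.
import rpyc
from rpyc.utils.classic import obtain

cache_port = 10010  # placeholder; the real port comes from the server start args
conn = rpyc.connect("localhost", cache_port, config={"allow_pickle": True})

# one RPC reserves cache records for all items of a request
records = obtain(conn.root.alloc(["md5-of-image-0", "md5-of-audio-1"], [576, 128]))
if records is None:
    raise RuntimeError("embed cache is full, retry later")

ids = [r["id"] for r in records]
conn.root.set_items_data(ids)                  # raw tensors written to shared memory
print(obtain(conn.root.get_items_embed(ids)))  # [False, False] until embeds are set
conn.root.release(ids)                         # drop the per-request references
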