
Commit 2aec5a1

fix
1 parent 60fc7f5 commit 2aec5a1

4 files changed, +92 -15 lines
lightllm/server/embed_cache/copy_to_cache.py

Lines changed: 71 additions & 0 deletions
@@ -0,0 +1,71 @@
+import torch
+
+import triton
+import triton.language as tl
+from typing import Optional
+
+
+@triton.jit
+def _offload_embed_tensor_to_cache(
+    embed_tensor_ptr,
+    gpu_stride0,
+    gpu_stride1,
+    gpu_stride2,
+    cache_tensor_ptr,
+    cpu_stride0,
+    cpu_stride1,
+    cpu_stride2,
+    start_index_in_cache,
+    layer_num,
+    hidden_size,
+    BLOCK: tl.constexpr,
+):
+    token_index = tl.program_id(0).to(tl.int64)
+    dest_index = (start_index_in_cache + token_index).to(tl.int64)
+
+    for layer_index in range(layer_num):
+        layer_index = layer_index.to(tl.int64)
+        for block_index in range(tl.cdiv(hidden_size, BLOCK)):
+            off = block_index * BLOCK + tl.arange(0, BLOCK)
+            mask = off < hidden_size
+            gpu_data = tl.load(
+                embed_tensor_ptr + token_index * gpu_stride0 + layer_index * gpu_stride1 + off * gpu_stride2, mask=mask
+            )
+            tl.store(
+                cache_tensor_ptr + dest_index * cpu_stride0 + layer_index * cpu_stride1 + off * cpu_stride2,
+                gpu_data,
+                mask=mask,
+            )
+
+    return
+
+
+@torch.no_grad()
+def offload_embed_tensor_to_cache(
+    embed_tensor: torch.Tensor,
+    cache_tensor: torch.Tensor,
+    start_index_in_cache: int,
+):
+    if len(embed_tensor.shape) == 2:
+        embed_tensor = embed_tensor.reshape(embed_tensor.shape[0], 1, embed_tensor.shape[1])
+
+    token_num = embed_tensor.shape[0]
+    grid = (token_num,)
+
+    _offload_embed_tensor_to_cache[grid](
+        embed_tensor_ptr=embed_tensor,
+        gpu_stride0=embed_tensor.stride(0),
+        gpu_stride1=embed_tensor.stride(1),
+        gpu_stride2=embed_tensor.stride(2),
+        cache_tensor_ptr=cache_tensor,
+        cpu_stride0=cache_tensor.stride(0),
+        cpu_stride1=cache_tensor.stride(1),
+        cpu_stride2=cache_tensor.stride(2),
+        start_index_in_cache=start_index_in_cache,
+        layer_num=embed_tensor.shape[1],
+        hidden_size=embed_tensor.shape[2],
+        BLOCK=256,
+        num_warps=4,
+        num_stages=1,
+    )
+    return
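
A quick sketch of how this new helper might be driven, separate from the commit itself. It assumes the destination sits in pinned (page-locked) host memory that the GPU can address directly via CUDA unified virtual addressing; in the actual code the cache tensor comes from the shared-memory buffer set up in embed_cache_client.py. The module path is inferred from the relative import in that file, and cpu_cache, gpu_embeds, and all sizes below are illustrative:

import torch
from lightllm.server.embed_cache.copy_to_cache import offload_embed_tensor_to_cache

token_num, layer_num, hidden_size = 16, 1, 4096

# Destination cache: pinned host memory, so the kernel's pointer arithmetic
# can write to it straight from the GPU.
cpu_cache = torch.empty((1024, layer_num, hidden_size), dtype=torch.float16, pin_memory=True)

# Source embeddings, already resident on the GPU.
gpu_embeds = torch.randn((token_num, layer_num, hidden_size), dtype=torch.float16, device="cuda")

# One Triton program per token copies its row into cache slots
# [100, 100 + token_num); the launch is asynchronous.
offload_embed_tensor_to_cache(embed_tensor=gpu_embeds, cache_tensor=cpu_cache, start_index_in_cache=100)

# Wait for the stream to drain before reading cpu_cache on the host.
torch.cuda.current_stream().synchronize()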

lightllm/server/embed_cache/embed_cache_client.py

Lines changed: 12 additions & 3 deletions
@@ -20,19 +20,19 @@ class CpuEmbedCacheClient(object):
    This class is responsible for handling cpu kv cache meta data.
    """

-    def __init__(self, only_create_meta_data: bool, init_shm_data: bool):
+    def __init__(self, create_meta_data: bool, init_shm_data: bool):
        self.args = get_env_start_args()
        # TODO: this needs to be calculated from settings.
        self.embed_cache_tensor_meta = calcu_embed_cache_meta()
        self.token_num: int = self.embed_cache_tensor_meta.token_num

-        if not only_create_meta_data:
+        if create_meta_data:
            self.token_index_manager = MemoryManager(total_size=self.token_num)
        else:
            if init_shm_data:
                self._create_shm_embed_kv_cache()
            else:
-                self._create_shm_embed_kv_cache()
+                self._attach_shm_cpu_embed_cache()
        return

    def alloc_indexes(self, token_num: int) -> Optional["MemoryBlock"]:
@@ -42,6 +42,15 @@ def release_indexes(self, block: "MemoryBlock"):
        self.token_index_manager.release(block)
        return

+    def copy_to_cache(self, embed_tensor: torch.Tensor, start_index_in_cache: int):
+        from .copy_to_cache import offload_embed_tensor_to_cache
+
+        offload_embed_tensor_to_cache(
+            embed_tensor=embed_tensor,
+            cache_tensor=self.cpu_embed_cache_tensor,
+            start_index_in_cache=start_index_in_cache,
+        )
+
    def _create_shm_embed_kv_cache(self):
        shm_ptr = create_shm_embed_cache_ptr()
        numpy_array = np.frombuffer(
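
For orientation, a minimal sketch of how the two constructor modes above are meant to pair up. In the real system the two clients live in different processes (cache server vs. vision worker); gpu_embeds and block.start_index are hypothetical stand-ins, since the MemoryBlock layout is not part of this diff:

from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient

# Cache-server side: owns only the index bookkeeping.
meta_client = CpuEmbedCacheClient(create_meta_data=True, init_shm_data=False)
block = meta_client.alloc_indexes(token_num=16)  # reserve 16 slots, or None if full

# Vision-worker side: attaches to (or initializes) the shared-memory cache
# tensor and performs the actual GPU -> CPU copy via the Triton kernel above.
worker_client = CpuEmbedCacheClient(create_meta_data=False, init_shm_data=True)
worker_client.copy_to_cache(
    embed_tensor=gpu_embeds,                 # CUDA tensor, shape (16, layer_num, hidden_size)
    start_index_in_cache=block.start_index,  # hypothetical attribute of MemoryBlock
)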

lightllm/server/embed_cache/impl/naive_memory_cache.py

Lines changed: 1 addition & 1 deletion
@@ -42,7 +42,7 @@ def __init__(self, args) -> None:
        self.token_id_range_start = 0
        self.token_id_range_end = 0
        self.use_config_server = self.args.config_server_host and self.args.config_server_port
-        self.cpu_embed_cache_client = CpuEmbedCacheClient(only_create_meta_data=True, init_shm_data=False)
+        self.cpu_embed_cache_client = CpuEmbedCacheClient(create_meta_data=True, init_shm_data=False)

    def _check_and_set_new_id_range(self, alloced_token_num):
        need_update_range = self.token_id_range_start + alloced_token_num >= self.token_id_range_end

lightllm/server/visualserver/model_infer/model_rpc.py

Lines changed: 8 additions & 11 deletions
@@ -19,19 +19,11 @@
from lightllm.models.qwen2_5_vl.qwen2_5_visual import Qwen2_5_VisionTransformerPretrainedModel
from lightllm.models.qwen3_vl.qwen3_visual import Qwen3VisionTransformerPretrainedModel
from lightllm.models.tarsier2.tarsier2_visual import TarsierVisionTransformerPretrainedModel
-from lightllm.server.embed_cache.utils import (
-    tensor2bytes,
-    read_shm,
-    create_shm,
-    create_shm_and_dump,
-    get_shm_name_data,
-    get_shm_name_embed,
-)
from lightllm.utils.infer_utils import set_random_seed
-from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end
from lightllm.utils.dist_utils import init_vision_distributed_env
from lightllm.utils.graceful_utils import graceful_registry
from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.server.embed_cache.embed_cache_client import CpuEmbedCacheClient


class VisualModelRpcServer(rpyc.Service):
@@ -91,6 +83,7 @@ def exposed_init_model(self, kvargs):

            self.model.load_model(weight_dir)
            self.model = self.model.cuda()
+            self.cpu_embed_cache_client = CpuEmbedCacheClient(create_meta_data=False, init_shm_data=True)
        except Exception as e:
            print("#" * 16)
            print("load model error:", str(e), e, type(e))
@@ -111,7 +104,7 @@ def forward(self, images: List[ImageItem]):
    def exposed_encode(self, images: List[ImageItem]):
        images = obtain(images)
        all_img_embeds, uuids, valid_ids = self.forward(images)
-        all_img_embeds = all_img_embeds.to(torch.device("cpu"))
+        all_img_embeds = all_img_embeds.to(torch.device("cuda"))

        if self.tp_rank_id == 0:
            ready_flags = obtain(self.cache_client.root.get_items_embed(uuids))
@@ -121,9 +114,13 @@ def exposed_encode(self, images: List[ImageItem]):
                    continue
                uid = uuids[i]
                start, end = valid_ids[i]
-                create_shm_and_dump(get_shm_name_embed(uid), all_img_embeds[start:end])
+                image = images[i]
+                self.cpu_embed_cache_client.copy_to_cache(
+                    embed_tensor=all_img_embeds[start:end], start_index_in_cache=image.start_index_in_embed_cache
+                )
                ids_to_set.append(uid)
            if ids_to_set:
+                torch.cuda.current_stream().synchronize()
                self.cache_client.root.set_items_embed(ids_to_set)
        return

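The tail of exposed_encode, restated as a condensed sketch (the ready-flag filtering is elided) to spell out why the new synchronize sits where it does:

# copy_to_cache launches asynchronous Triton kernels; publishing the uuids
# before the stream drains could let a consumer read half-written cache rows.
ids_to_set = []
for i, uid in enumerate(uuids):
    start, end = valid_ids[i]
    self.cpu_embed_cache_client.copy_to_cache(
        embed_tensor=all_img_embeds[start:end],
        start_index_in_cache=images[i].start_index_in_embed_cache,
    )
    ids_to_set.append(uid)
if ids_to_set:
    torch.cuda.current_stream().synchronize()           # barrier: all cache rows fully written
    self.cache_client.root.set_items_embed(ids_to_set)  # only now publish readiness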
