
Commit d498aaf

committed
ViT and LLM inference in a single process
1 parent f7dfc16 commit d498aaf

File tree

10 files changed: +215 -14 lines changed

lightllm/common/image_cache_manager.py

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
+from collections import OrderedDict
+from lightllm.utils.dist_utils import get_current_device_id
+
+
+class ImageCacheManager:
+    def __init__(self):
+        """
+        Initialize the image cache manager with a simple GPU cache and an LRU CPU cache.
+        """
+        self._gpu_cache = dict()
+        self._cpu_cache = OrderedDict()
+
+    def set_max_size(self, max_size: int):
+        """
+        Set the maximum number of items to keep in the CPU cache.
+        :param max_size: Maximum number of items to keep in the CPU cache.
+        """
+        if max_size <= 0:
+            raise ValueError("max_size must be greater than 0")
+        self._max_size = max_size
+
+    def set_embed(self, uuid, embed):
+        """
+        Store the embedding for the given uuid in the GPU cache.
+        :param uuid: Unique identifier for the image
+        :param embed: Embedding vector for the image (on GPU)
+        """
+        self._gpu_cache[uuid] = embed
+
+    def get_embed(self, uuid):
+        """
+        Retrieve the embedding for the given uuid. Prefer GPU cache,
+        otherwise return CPU cache and move to GPU (simulate .cuda()).
+        :param uuid: Unique identifier for the image
+        :return: Embedding vector (on GPU if possible, else move from CPU to GPU)
+        """
+        if uuid in self._gpu_cache:
+            return self._gpu_cache[uuid]
+        elif uuid in self._cpu_cache:
+            self._cpu_cache.move_to_end(uuid)
+            embed = self._cpu_cache[uuid].cuda(get_current_device_id())
+            return embed
+        return None
+
+    def query_embed(self, uuid):
+        """
+        Query if the embedding for the given uuid is in the cache.
+        :param uuid: Unique identifier for the image
+        :return: True if the embedding is in the cache, False otherwise
+        """
+        return uuid in self._gpu_cache or uuid in self._cpu_cache
+
+    def filter(self, uuid_list):
+        """
+        Given a list of uuids, move their embeddings from GPU cache to CPU cache if present,
+        and return a dict of those found in the cache and their embeddings (on CPU).
+        :param uuid_list: List of uuids
+        """
+        for uuid in uuid_list:
+            if uuid in self._gpu_cache:
+                embed_cpu = self._gpu_cache[uuid].cpu(non_blocking=True)
+                # Move to CPU cache and remove from GPU cache
+                self._gpu_cache.pop(uuid)
+                if uuid in self._cpu_cache:
+                    self._cpu_cache.move_to_end(uuid)
+                self._cpu_cache[uuid] = embed_cpu
+                if len(self._cpu_cache) > self._max_size:
+                    self._cpu_cache.popitem(last=False)
+            elif uuid in self._cpu_cache:
+                self._cpu_cache.move_to_end(uuid)
+        return
+
+
+image_cache_manager = ImageCacheManager()
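
For orientation, a minimal usage sketch of the cache lifecycle implemented above, assuming a CUDA device and an initialized lightllm distributed context; the uuid and tensor shape are made up. One caveat: filter() demotes entries with Tensor.cpu(non_blocking=True), and recent PyTorch versions may not accept a non_blocking keyword there (.to("cpu", non_blocking=True) is the more common spelling).

import torch

from lightllm.common.image_cache_manager import image_cache_manager

# Illustrative walk through the cache lifecycle.
image_cache_manager.set_max_size(8)  # CPU LRU capacity; must be set before filter() needs to evict

embed = torch.randn(256, 4096, device="cuda")  # stand-in for one image's ViT embedding
image_cache_manager.set_embed("img-uuid-1", embed)  # kept on GPU right after encoding

assert image_cache_manager.query_embed("img-uuid-1")
hit = image_cache_manager.get_embed("img-uuid-1")  # GPU hit, returned as-is

# After the prefill that consumed the image, demote it to the CPU LRU cache.
image_cache_manager.filter(["img-uuid-1"])
hit = image_cache_manager.get_embed("img-uuid-1")  # CPU hit, copied back to the current device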

lightllm/models/internvl/layer_weights/pre_and_post_layer_weight.py

Lines changed: 65 additions & 0 deletions
@@ -3,6 +3,9 @@
 from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight

 from lightllm.models.internlm2.layer_weights.pre_and_post_layer_weight import Internlm2PreAndPostLayerWeight
+from lightllm.models.vit.model import VisionTransformer
+from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.common.image_cache_manager import image_cache_manager


 # add key: language_model.xxx -> xxx
@@ -15,9 +18,45 @@ def rename_weight_keys(weights):
             weights[k[len(prefix) :]] = weights[k]


+class InternVLPreAndPostLayerWeight(LlamaPreAndPostLayerWeight):
+    def __init__(self, data_type, network_config, mode):
+        super().__init__(data_type, network_config, mode)
+        # if we don't assign an extra process for the visual model, we need to initialize the image cache manager here
+        if get_env_start_args().disable_extra_process_for_multimodal:
+            kvargs = {
+                "weight_dir": get_env_start_args().model_dir,
+                "data_type": self.data_type_,
+                "quant_type": get_env_start_args().vit_quant_type,
+                "quant_cfg": get_env_start_args().vit_quant_cfg,
+                "max_batch_size": get_env_start_args().visual_infer_batch_size,
+            }
+            self.visual_model = VisionTransformer(
+                kvargs=kvargs,
+            )
+            image_cache_manager.set_max_size(get_env_start_args().cache_capacity * 2)
+        return
+
+    def load_hf_weights(self, weights):
+        rename_weight_keys(weights)
+        super().load_hf_weights(weights)
+
+
 class InternVLPhi3PreAndPostLayerWeight(LlamaPreAndPostLayerWeight):
     def __init__(self, data_type, network_config, mode):
         super().__init__(data_type, network_config, mode)
+        # if we don't assign an extra process for the visual model, we need to initialize the image cache manager here
+        if get_env_start_args().disable_extra_process_for_multimodal:
+            kvargs = {
+                "weight_dir": get_env_start_args().model_dir,
+                "data_type": self.data_type_,
+                "quant_type": get_env_start_args().vit_quant_type,
+                "quant_cfg": get_env_start_args().vit_quant_cfg,
+                "max_batch_size": get_env_start_args().visual_infer_batch_size,
+            }
+            self.visual_model = VisionTransformer(
+                kvargs=kvargs,
+            )
+            image_cache_manager.set_max_size(get_env_start_args().cache_capacity * 2)
         return

     def load_hf_weights(self, weights):
@@ -29,6 +68,19 @@ def load_hf_weights(self, weights):
 class InternVLInternlm2PreAndPostLayerWeight(Internlm2PreAndPostLayerWeight):
     def __init__(self, data_type, network_config, mode):
         super().__init__(data_type, network_config, mode)
+        # if we don't assign an extra process for the visual model, we need to initialize the image cache manager here
+        if get_env_start_args().disable_extra_process_for_multimodal:
+            kvargs = {
+                "weight_dir": get_env_start_args().model_dir,
+                "data_type": self.data_type_,
+                "quant_type": get_env_start_args().vit_quant_type,
+                "quant_cfg": get_env_start_args().vit_quant_cfg,
+                "max_batch_size": get_env_start_args().visual_infer_batch_size,
+            }
+            self.visual_model = VisionTransformer(
+                kvargs=kvargs,
+            )
+            image_cache_manager.set_max_size(get_env_start_args().cache_capacity * 2)
         return

     def load_hf_weights(self, weights):
@@ -40,6 +92,19 @@ def load_hf_weights(self, weights):
 class InternVLLlamaPreAndPostLayerWeight(LlamaPreAndPostLayerWeight):
     def __init__(self, data_type, network_config, mode):
         super().__init__(data_type, network_config, mode)
+        # if we don't assign an extra process for the visual model, we need to initialize the image cache manager here
+        if get_env_start_args().disable_extra_process_for_multimodal:
+            kvargs = {
+                "weight_dir": get_env_start_args().model_dir,
+                "data_type": self.data_type_,
+                "quant_type": get_env_start_args().vit_quant_type,
+                "quant_cfg": get_env_start_args().vit_quant_cfg,
+                "max_batch_size": get_env_start_args().visual_infer_batch_size,
+            }
+            self.visual_model = VisionTransformer(
+                kvargs=kvargs,
+            )
+            image_cache_manager.set_max_size(get_env_start_args().cache_capacity * 2)
         return

     def load_hf_weights(self, weights):
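
The same in-process ViT setup is repeated verbatim in all four InternVL weight classes above. A minimal sketch of how it could be factored out, using a hypothetical helper name that is not part of this commit:

from lightllm.models.vit.model import VisionTransformer
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.common.image_cache_manager import image_cache_manager


def build_inprocess_visual_model(data_type):
    # Hypothetical helper mirroring the per-class __init__ logic above.
    args = get_env_start_args()
    if not args.disable_extra_process_for_multimodal:
        return None
    kvargs = {
        "weight_dir": args.model_dir,
        "data_type": data_type,
        "quant_type": args.vit_quant_type,
        "quant_cfg": args.vit_quant_cfg,
        "max_batch_size": args.visual_infer_batch_size,
    }
    visual_model = VisionTransformer(kvargs=kvargs)
    # CPU-side image cache holds up to twice the embed-cache capacity, as in the commit.
    image_cache_manager.set_max_size(args.cache_capacity * 2)
    return visual_model

Each __init__ would then reduce to self.visual_model = build_inprocess_visual_model(self.data_type_).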

lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py

Lines changed: 24 additions & 2 deletions
@@ -6,7 +6,9 @@

 from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer
 from lightllm.utils.infer_utils import mark_cost_time
+from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.server.embed_cache.utils import bytes2tensor, read_shm, get_shm_name_embed
+from lightllm.common.image_cache_manager import image_cache_manager
 from lightllm.common.basemodel.triton_kernel.multimodal_emb import multimodal_emb
 from lightllm.distributed.communication_op import all_reduce

@@ -29,8 +31,22 @@
 class LlamaMultimodalPreLayerInfer(LlamaPreLayerInfer):
     def __init__(self, network_config, mode):
         super().__init__(network_config, mode)
+        self.disable_extra_process_for_multimodal = get_env_start_args().disable_extra_process_for_multimodal
         return

+    def _infer_image_embeds(self, infer_state, layer_weight):
+        if not self.disable_extra_process_for_multimodal:
+            return
+        infer_images = []
+        for _, p in enumerate(infer_state.multimodal_params):
+            for img in p["images"] + p["audios"]:
+                if not image_cache_manager.query_embed(img["uuid"]):
+                    infer_images.append(img)
+        if len(infer_images) > 0:
+            img_embeds, uuids, valid_ids = layer_weight.visual_model.encode(infer_images)
+            for uuid, valid_id in zip(uuids, valid_ids):
+                image_cache_manager.set_embed(uuid, img_embeds[valid_id[0] : valid_id[1]])
+
     def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight):

         img_weight = []
@@ -42,14 +58,20 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei
         device = layer_weight.wte_weight_.device
         dtype = layer_weight.wte_weight_.dtype
         hidden_size = layer_weight.wte_weight_.shape[1]
+        self._infer_image_embeds(infer_state, layer_weight)
         for batch_id, p in enumerate(infer_state.multimodal_params):
             for img in p["images"] + p["audios"]:
                 # skip the same image
                 if img["token_id"] in img_start_token_ids:
                     continue
                 # pull the img_embeds by uid from shm
-                data = read_shm(get_shm_name_embed(img["uuid"]))
-                img_weight.append(bytes2tensor(data).cuda().reshape(img["token_num"], -1))
+                if self.disable_extra_process_for_multimodal:
+                    img_embed = image_cache_manager.get_embed(img["uuid"])
+                    img_weight.append(img_embed.reshape(img["token_num"], -1))
+                    print(img_weight[-1].shape)
+                else:
+                    data = read_shm(get_shm_name_embed(img["uuid"]))
+                    img_weight.append(bytes2tensor(data).cuda().reshape(img["token_num"], -1))
                 img_start_token_ids.append(img["token_id"])
                 img_token_lens.append(img["token_num"])
                 img_start_locs.append(img_start_loc)
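
_infer_image_embeds assumes that visual_model.encode() returns a flat embedding tensor together with, for each uuid, a (start, end) row range into that tensor; the slice img_embeds[valid_id[0] : valid_id[1]] is what gets cached per image. A small sketch of that contract with made-up shapes:

import torch

# Suppose a batch of two uncached images produced 256 and 64 embedding rows.
img_embeds = torch.randn(320, 4096)      # all rows, concatenated in encode order
uuids = ["img-a", "img-b"]
valid_ids = [(0, 256), (256, 320)]       # per-uuid row range into img_embeds

for uuid, (start, end) in zip(uuids, valid_ids):
    per_image = img_embeds[start:end]    # the rows later reshaped to (token_num, -1)
    assert per_image.shape[0] == end - start
    # image_cache_manager.set_embed(uuid, per_image)  # as done in _infer_image_embeds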

lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py

Lines changed: 10 additions & 1 deletion
@@ -3,7 +3,12 @@
 import numpy as np
 import torch.nn.functional as F
 from lightllm.common.basemodel import PreAndPostLayerWeight
-from lightllm.utils.dist_utils import get_current_device_id
+from lightllm.utils.dist_utils import (
+    get_current_device_id,
+    get_global_rank,
+    get_global_world_size,
+)
+from lightllm.utils.envs_utils import get_env_start_args


 class ViTPreAndPostLayerWeight(PreAndPostLayerWeight):
@@ -13,6 +18,10 @@ def __init__(self, data_type, network_config, mode):
         self.image_size = self.network_config_["image_size"]
         self.patch_size = self.network_config_["patch_size"]
         self.llm_hidden_size = self.network_config_["llm_hidden_size"]
+        if get_env_start_args().disable_extra_process_for_multimodal:
+            self.tp_world_size_ = get_global_world_size()
+            self.tp_rank_ = get_global_rank()
+
         return

     def _cuda(self, cpu_tensor):

lightllm/models/vit/layer_weights/transformer_layer_weight.py

Lines changed: 9 additions & 1 deletion
@@ -11,12 +11,20 @@
     MultiROWMMWeight,
     TpNormWeight,
 )
-from lightllm.utils.dist_utils import get_current_device_id
+from lightllm.utils.dist_utils import (
+    get_current_device_id,
+    get_global_rank,
+    get_global_world_size,
+)
+from lightllm.utils.envs_utils import get_env_start_args


 class ViTTransformerLayerWeight(TransformerLayerWeight):
     def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None):
         super().__init__(layer_num, data_type, network_config, mode, quant_cfg)
+        if get_env_start_args().disable_extra_process_for_multimodal:
+            self.tp_world_size_ = get_global_world_size()
+            self.tp_rank_ = get_global_rank()
         return

     def _cuda(self, cpu_tensor):
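
Both ViT weight classes above (and the pre/post weight class in the previous file) re-point tensor parallelism at the LLM's global process group when the visual model runs in-process. A minimal sketch of the intent, with a hypothetical helper that is not part of the commit:

from lightllm.utils.dist_utils import get_global_rank, get_global_world_size
from lightllm.utils.envs_utils import get_env_start_args


def resolve_vit_parallel(default_world_size, default_rank):
    # With no separate visual process, ViT weights are sharded across every rank
    # of the LLM's global group; otherwise the defaults from the base classes stand.
    if get_env_start_args().disable_extra_process_for_multimodal:
        return get_global_world_size(), get_global_rank()
    return default_world_size, default_rank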

lightllm/models/vit/model.py

Lines changed: 19 additions & 4 deletions
@@ -18,7 +18,8 @@
 from io import BytesIO
 from rpyc.utils.classic import obtain
 from lightllm.common.quantization import Quantcfg
-from lightllm.utils.dist_utils import get_dp_world_size
+from lightllm.utils.dist_utils import get_dp_world_size, get_global_world_size
+from lightllm.utils.envs_utils import get_env_start_args

 from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager

@@ -37,7 +38,11 @@ class VisionTransformer:
     post_layer_infer_class = ViTPostLayerInfer

     def __init__(self, kvargs):
-        self.tp_world_size_ = get_dp_world_size()
+        if get_env_start_args().disable_extra_process_for_multimodal:
+            # if we don't assign an extra process for the visual model, the visual model uses tensor parallel by default.
+            self.tp_world_size_ = get_global_world_size()
+        else:
+            self.tp_world_size_ = get_dp_world_size()
         self.weight_dir_ = kvargs["weight_dir"]
         self.load_way = kvargs.get("load_way", "HF")
         self.mode = [m.replace("int4weight", "w4a16").replace("int8weight", "w8a16") for m in kvargs.get("mode", [])]
@@ -150,6 +155,8 @@ def _init_infer_layer(self):
         return

     def _init_datatype(self):
+        if isinstance(self.data_type, torch.dtype):
+            return
         if self.data_type in ["fp16", "float16"]:
             self.data_type = torch.float16
         elif self.data_type in ["bf16", "bfloat16"]:
@@ -161,12 +168,14 @@

     @torch.no_grad()
     def forward(self, pixel_values):
-        g_cache_manager.cache_env_in()
+        if not get_env_start_args().disable_extra_process_for_multimodal:
+            g_cache_manager.cache_env_in()
         input_embs = self.pre_infer.forward(pixel_values, self.pre_post_weight)
         for i in range(self.layers_num + self.select_layer + 1):
             input_embs = self.layers_infer[i].forward(input_embs, self.trans_layers_weight[i])
         input_embs = self.post_infer.forward(input_embs[:, 1:, :], self.pre_post_weight)
-        g_cache_manager.cache_env_out()
+        if not get_env_start_args().disable_extra_process_for_multimodal:
+            g_cache_manager.cache_env_out()
         return input_embs

     @torch.no_grad()
@@ -182,6 +191,12 @@ def encode(self, images: List[ImageItem]):
                 image_data = Image.open(BytesIO(image_data))
                 t = self.load_image_func(image_data, max_num=img.extra_params["image_patch_max_num"])
                 img_tensors.append(t)
+            elif isinstance(img, dict):
+                uuids.append(img["uuid"])
+                image_data = read_shm(get_shm_name_data(img["uuid"]))
+                image_data = Image.open(BytesIO(image_data))
+                t = self.load_image_func(image_data, max_num=img["extra_params"]["image_patch_max_num"])
+                img_tensors.append(t)
             else:
                 raise Exception("Unsupport input types: {} for {}".format(type(img), img))

lightllm/server/api_cli.py

Lines changed: 5 additions & 0 deletions
@@ -233,6 +233,11 @@ def make_argument_parser() -> argparse.ArgumentParser:
     parser.add_argument(
         "--enable_multimodal", action="store_true", help="Whether or not to allow to load additional visual models."
     )
+    parser.add_argument(
+        "--disable_extra_process_for_multimodal",
+        action="store_true",
+        help="Whether or not to disable extra process for multimodal.",
+    )
     parser.add_argument(
         "--enable_multimodal_audio",
         action="store_true",

lightllm/server/api_start.py

Lines changed: 2 additions & 2 deletions
@@ -243,7 +243,7 @@ def normal_or_p_d_start(args):
             ],
             start_args=[(cache_port, args)],
         )
-        if args.enable_multimodal_audio:
+        if args.enable_multimodal_audio and not args.disable_extra_process_for_multimodal:
             from .audioserver.manager import start_audio_process

             process_manager.start_submodule_processes(
@@ -263,7 +263,7 @@
                 ],
             )

-        else:
+        elif not args.disable_extra_process_for_multimodal:
             process_manager.start_submodule_processes(
                 start_funcs=[
                     start_visual_process,

lightllm/server/httpserver/manager.py

Lines changed: 6 additions & 4 deletions
@@ -81,10 +81,12 @@ def __init__(
         )

         self.enable_multimodal = enable_multimodal
+        self.disable_extra_process_for_multimodal = args.disable_extra_process_for_multimodal
         if self.enable_multimodal:
             self.cache_client = rpyc.connect("localhost", cache_port)
-            self.send_to_visual = context.socket(zmq.PUSH)
-            self.send_to_visual.connect(f"{args.zmq_mode}127.0.0.1:{visual_port}")
+            if not self.disable_extra_process_for_multimodal:
+                self.send_to_visual = context.socket(zmq.PUSH)
+                self.send_to_visual.connect(f"{args.zmq_mode}127.0.0.1:{visual_port}")

         self.shm_req_manager = ShmReqManager()

@@ -449,7 +451,7 @@ async def transfer_to_next_module(
     ):

         if self.pd_mode == NodeRole.P:
-            if self.enable_multimodal:
+            if self.enable_multimodal and not self.disable_extra_process_for_multimodal:
                 self.send_to_visual.send_pyobj(
                     group_req_objs.to_group_req_index(),
                     protocol=pickle.HIGHEST_PROTOCOL,
@@ -470,7 +472,7 @@
             return

         if self.pd_mode == NodeRole.NORMAL:
-            if self.enable_multimodal:
+            if self.enable_multimodal and not self.disable_extra_process_for_multimodal:
                 self.send_to_visual.send_pyobj(
                     group_req_objs.to_group_req_index(),
                     protocol=pickle.HIGHEST_PROTOCOL,
