Multimodal improve #951
@@ -0,0 +1,76 @@
from collections import OrderedDict
from lightllm.utils.dist_utils import get_current_device_id


class ImageCacheManager:
    def __init__(self):
        """
        Initialize the image cache manager with a simple GPU cache and an LRU CPU cache.
        """
        self._gpu_cache = dict()
        self._cpu_cache = OrderedDict()

    def set_max_size(self, max_size: int):
        """
        Set the maximum number of items to keep in the CPU cache.
        :param max_size: Maximum number of items to keep in the CPU cache.
        """
        if max_size <= 0:
            raise ValueError("max_size must be greater than 0")
        self._max_size = max_size

    def set_embed(self, uuid, embed):
        """
        Store the embedding for the given uuid in the GPU cache.
        :param uuid: Unique identifier for the image
        :param embed: Embedding tensor for the image (on GPU)
        """
        self._gpu_cache[uuid] = embed

    def get_embed(self, uuid):
        """
        Retrieve the embedding for the given uuid. Prefer the GPU cache;
        otherwise move the CPU copy back to the current GPU and return it.
        :param uuid: Unique identifier for the image
        :return: Embedding tensor on GPU, or None if the uuid is not cached
        """
        if uuid in self._gpu_cache:
            return self._gpu_cache[uuid]
        elif uuid in self._cpu_cache:
            self._cpu_cache.move_to_end(uuid)
            embed = self._cpu_cache[uuid].cuda(get_current_device_id())
            return embed
        return None

    def query_embed(self, uuid):
        """
        Query whether the embedding for the given uuid is in the cache.
        :param uuid: Unique identifier for the image
        :return: True if the embedding is in the cache, False otherwise
        """
        return uuid in self._gpu_cache or uuid in self._cpu_cache

    def filter(self, uuid_list):
        """
        Given a list of uuids, move their embeddings from the GPU cache to the CPU cache
        if present, evicting the least recently used CPU entries beyond the size limit.
        :param uuid_list: List of uuids
        """
        for uuid in uuid_list:
            if uuid in self._gpu_cache:
                # Move to CPU cache and remove from GPU cache
                embed_cpu = self._gpu_cache[uuid].cpu()
                self._gpu_cache.pop(uuid)
                if uuid in self._cpu_cache:
                    self._cpu_cache.move_to_end(uuid)
                self._cpu_cache[uuid] = embed_cpu
                if len(self._cpu_cache) > self._max_size:
                    self._cpu_cache.popitem(last=False)
            elif uuid in self._cpu_cache:
                self._cpu_cache.move_to_end(uuid)
        print(self._gpu_cache.keys())
        print(self._cpu_cache.keys())
        return


image_cache_manager = ImageCacheManager()
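
For orientation, here is a minimal usage sketch of the cache flow above, assuming a CUDA device is available and using the module path from the imports later in this PR; the shapes and max size are illustrative:

import torch

from lightllm.common.image_cache_manager import image_cache_manager

# Bound the CPU-side LRU cache before any filtering happens.
image_cache_manager.set_max_size(8)

# Cache a freshly computed GPU embedding for one image uuid (shape is illustrative).
image_cache_manager.set_embed("img-uuid-1", torch.randn(256, 4096, device="cuda"))

# A GPU hit returns the cached tensor directly.
embed = image_cache_manager.get_embed("img-uuid-1")

# After the request finishes, demote the entry to the CPU LRU cache to free GPU memory.
image_cache_manager.filter(["img-uuid-1"])

# A later request promotes the CPU copy back onto the current device.
embed_again = image_cache_manager.get_embed("img-uuid-1")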

@@ -6,7 +6,9 @@
from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer
from lightllm.utils.infer_utils import mark_cost_time
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.server.embed_cache.utils import bytes2tensor, read_shm, get_shm_name_embed
from lightllm.common.image_cache_manager import image_cache_manager
from lightllm.common.basemodel.triton_kernel.multimodal_emb import multimodal_emb
from lightllm.distributed.communication_op import all_reduce

@@ -29,8 +31,22 @@
class LlamaMultimodalPreLayerInfer(LlamaPreLayerInfer):
    def __init__(self, network_config, mode):
        super().__init__(network_config, mode)
        self.disable_extra_process_for_multimodal = get_env_start_args().disable_extra_process_for_multimodal
        return

    def _infer_image_embeds(self, infer_state, layer_weight):
        if not self.disable_extra_process_for_multimodal:
            return
        infer_images = []
        for _, p in enumerate(infer_state.multimodal_params):
            for img in p["images"] + p["audios"]:
                if not image_cache_manager.query_embed(img["uuid"]):
                    infer_images.append(img)
        if len(infer_images) > 0:
            img_embeds, uuids, valid_ids = layer_weight.visual_model.encode(infer_images)
            for uuid, valid_id in zip(uuids, valid_ids):
                image_cache_manager.set_embed(uuid, img_embeds[valid_id[0] : valid_id[1]])

    def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight):
        img_weight = []

@@ -42,14 +58,19 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei
        device = layer_weight.wte_weight_.device
        dtype = layer_weight.wte_weight_.dtype
        hidden_size = layer_weight.wte_weight_.shape[1]
        self._infer_image_embeds(infer_state, layer_weight)
        for batch_id, p in enumerate(infer_state.multimodal_params):
            for img in p["images"] + p["audios"]:
                # skip the same image
                if img["token_id"] in img_start_token_ids:
                    continue
                # pull the img_embeds by uid from shm
                data = read_shm(get_shm_name_embed(img["uuid"]))
                img_weight.append(bytes2tensor(data).cuda().reshape(img["token_num"], -1))
                if self.disable_extra_process_for_multimodal:
                    img_embed = image_cache_manager.get_embed(img["uuid"])
                    img_weight.append(img_embed.reshape(img["token_num"], -1))
Suggested change:
-                    img_embed = image_cache_manager.get_embed(img["uuid"])
-                    img_weight.append(img_embed.reshape(img["token_num"], -1))
+                    img_embed = image_cache_manager.get_embed(img["uuid"])
+                    if img_embed is None:
+                        raise ValueError(f"Image embedding for uuid {img['uuid']} not found in cache.")
+                    img_weight.append(img_embed.reshape(img["token_num"], -1))
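
Since the diff view flattens removed and added lines, the intent of this hunk is easier to see factored into a small helper. This is an illustrative sketch, not code from the PR, and it assumes the pre-existing shared-memory read stays in place as the fallback path when the extra visual process is still used:

from lightllm.common.image_cache_manager import image_cache_manager
from lightllm.server.embed_cache.utils import bytes2tensor, read_shm, get_shm_name_embed


def load_img_embed(img: dict, disable_extra_process_for_multimodal: bool):
    """Hypothetical helper: return the (token_num, hidden_size) embedding for one image item."""
    if disable_extra_process_for_multimodal:
        # Embeddings were computed in-process by _infer_image_embeds and cached by uuid.
        img_embed = image_cache_manager.get_embed(img["uuid"])
    else:
        # Assumed fallback: the original path reads the embedding bytes from shared memory.
        data = read_shm(get_shm_name_embed(img["uuid"]))
        img_embed = bytes2tensor(data).cuda()
    return img_embed.reshape(img["token_num"], -1)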

@@ -16,6 +16,7 @@ def __init__(self, network_config, mode):
        self.tp_world_size_ = get_dp_world_size()
        self.network_config_ = network_config
        self.mode = mode
        print(f"tp_rank_: {self.tp_rank_}, tp_world_size_: {self.tp_world_size_}")
        return

    def forward(self, pixel_values, layer_weight: ViTPreAndPostLayerWeight):

@@ -21,6 +21,7 @@ class ViTTransformerLayerInfer:
    def __init__(self, layer_num, network_config, mode=[]):
        self.tp_rank_ = get_current_rank_in_dp()
        self.tp_world_size_ = get_dp_world_size()
        print(f"tp_rank_: {self.tp_rank_}, tp_world_size_: {self.tp_world_size_}")
        self.eps_ = network_config["layer_norm_eps"]
        self.head_num = network_config["num_attention_heads"]
        self.tp_padding_head_num = network_config["padding_head_num"] // self.tp_world_size_

@@ -18,7 +18,8 @@
from io import BytesIO
from rpyc.utils.classic import obtain
from lightllm.common.quantization import Quantcfg
from lightllm.utils.dist_utils import get_dp_world_size
from lightllm.utils.dist_utils import get_dp_world_size, get_global_world_size
from lightllm.utils.envs_utils import get_env_start_args
from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager

@@ -150,6 +151,8 @@ def _init_infer_layer(self):
        return

    def _init_datatype(self):
        if isinstance(self.data_type, torch.dtype):
            return
        if self.data_type in ["fp16", "float16"]:
            self.data_type = torch.float16
        elif self.data_type in ["bf16", "bfloat16"]:

@@ -161,12 +164,14 @@ def _init_datatype(self):

    @torch.no_grad()
    def forward(self, pixel_values):
        g_cache_manager.cache_env_in()
        if not get_env_start_args().disable_extra_process_for_multimodal:
            g_cache_manager.cache_env_in()
        input_embs = self.pre_infer.forward(pixel_values, self.pre_post_weight)
        for i in range(self.layers_num + self.select_layer + 1):
            input_embs = self.layers_infer[i].forward(input_embs, self.trans_layers_weight[i])
        input_embs = self.post_infer.forward(input_embs[:, 1:, :], self.pre_post_weight)
        g_cache_manager.cache_env_out()
        if not get_env_start_args().disable_extra_process_for_multimodal:
            g_cache_manager.cache_env_out()
        return input_embs

    @torch.no_grad()

@@ -182,6 +187,12 @@ def encode(self, images: List[ImageItem]):
                image_data = Image.open(BytesIO(image_data))
                t = self.load_image_func(image_data, max_num=img.extra_params["image_patch_max_num"])
                img_tensors.append(t)
            elif isinstance(img, dict):
                uuids.append(img["uuid"])
                image_data = read_shm(get_shm_name_data(img["uuid"]))
                image_data = Image.open(BytesIO(image_data))
                t = self.load_image_func(image_data, max_num=img["extra_params"]["image_patch_max_num"])
                img_tensors.append(t)
            else:
                raise Exception("Unsupport input types: {} for {}".format(type(img), img))
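
For reference, the dict-shaped item handled by the new `elif isinstance(img, dict)` branch appears to carry at least the fields used across this PR; the example below is hypothetical and its values are placeholders:

# Hypothetical image item passed to encode() when the extra visual process is disabled.
img = {
    "uuid": "img-uuid-1",                         # key for shared-memory data and the embedding cache
    "token_id": 42,                               # image placeholder token id (value is illustrative)
    "token_num": 256,                             # number of image tokens the embedding expands to
    "extra_params": {"image_patch_max_num": 12},  # forwarded to load_image_func
}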

Review comment:
The `_max_size` attribute is used in the `filter` method but is not initialized in the `__init__` method, which can lead to an `AttributeError` if `filter()` is called before `set_max_size()`.
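
One possible remedy, sketched here as a reviewer-style suggestion rather than code from the PR, is to give `_max_size` a default in `__init__` so `filter()` never sees an uninitialized attribute:

from collections import OrderedDict


class ImageCacheManager:
    def __init__(self, max_size: int = 1024):  # default bound is illustrative
        """
        Initialize the image cache manager with a simple GPU cache and an LRU CPU cache.
        """
        self._gpu_cache = dict()
        self._cpu_cache = OrderedDict()
        # Set a default CPU-cache bound up front; set_max_size() can still override it later.
        self._max_size = max_size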