diff --git a/llmc/models/__init__.py b/llmc/models/__init__.py index 85d6ffcc5..7a495c0db 100755 --- a/llmc/models/__init__.py +++ b/llmc/models/__init__.py @@ -10,6 +10,7 @@ from .internvl2 import InternVL2 from .llama import Llama from .llava import Llava +from .llava_lht import LlavaLHT from .minicpm import MiniCPM from .minicpmv import MiniCPMV from .mistral import Mistral diff --git a/llmc/models/llava_lht.py b/llmc/models/llava_lht.py new file mode 100644 index 000000000..8ee1c9b57 --- /dev/null +++ b/llmc/models/llava_lht.py @@ -0,0 +1,300 @@ +import types +from datetime import timedelta +from typing import Optional, Union + +import torch +from accelerate import Accelerator, DistributedType, InitProcessGroupKwargs +from accelerate.state import AcceleratorState +from lmms_eval.api.model import lmms +from lmms_eval.models.llava import Llava +from loguru import logger +from packaging import version +from transformers import AutoConfig, AutoTokenizer + +from llmc.utils.registry_factory import MODEL_REGISTRY + +from .llama import Llama + +try: + from llava.constants import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN, + DEFAULT_IMAGE_PATCH_TOKEN) + from llava.mm_utils import get_model_name_from_path + from llava.model.builder import load_pretrained_model + from llava.model.language_model.llava_llama import LlavaConfig +except Exception as e: + logger.debug('LLaVA is not installed. Please install LLaVA to use this model.\nError: %s' % e) + + +@MODEL_REGISTRY +class LlavaLHT(Llama): + def __init__(self, config, device_map=None, use_cache=False): + super().__init__(config, device_map, use_cache) + + def build_tokenizer(self): + pass + + def build_model(self): + self.llava_llama_config = LlavaConfig.from_pretrained( + self.model_path, trust_remote_code=True + ) + self.vlm_model_config = AutoConfig.from_pretrained( + self.model_path, trust_remote_code=True + ) + if not self.use_cache: + self.llava_llama_config.use_cache = False + self.vlm_model_config.use_cache = False + logger.info(f'self.vlm_model_config : {self.vlm_model_config}') + self.tokenizer, self.vlm_model, self.image_processor, context_len = load_pretrained_model( + self.model_path, + None, + get_model_name_from_path(self.model_path), + load_8bit=False, + load_4bit=False, + torch_dtype=self.torch_dtype, + device='cpu', + config=self.llava_llama_config, + ) + # llava-lht forward not support "cache_position" + ori_forward = self.vlm_model.forward + + def safe_forward(*args, **kwargs): + kwargs['use_cache'] = False + kwargs.pop('cache_position', None) + return ori_forward(*args, **kwargs) + self.vlm_model.forward = safe_forward + # llava-lht generate use "inputs" instead of "input_ids" + ori_generate = self.vlm_model.generate + + def safe_generate(*args, **kwargs): + if 'input_ids' in kwargs: + kwargs['inputs'] = kwargs.pop('input_ids') + return ori_generate(*args, **kwargs) + self.vlm_model.generate = safe_generate + + # "attention_mask" is passed via kwargs rather than as an explicit keyword argument. + ori_prepare_inputs_for_generation = self.vlm_model.prepare_inputs_for_generation + + def safe_prepare_inputs_for_generation( + self, input_ids, past_key_values=None, + inputs_embeds=None, attention_mask=None, **kwargs): + if attention_mask is not None: + kwargs['attention_mask'] = attention_mask + return ori_prepare_inputs_for_generation( + input_ids, past_key_values, inputs_embeds, **kwargs) + self.vlm_model.prepare_inputs_for_generation = types.MethodType( + safe_prepare_inputs_for_generation, self.vlm_model + ) + + self.eval_name = 'LlavaLHTEval' + self.mm_model = self.vlm_model + logger.info(f'self.vlm_model : {self.vlm_model}') + self.vision_model = self.vlm_model.get_vision_tower() + self.vision_projector = self.vlm_model.model.mm_projector + # Llava-lht merges the language model with the vision projector and vision model + self.model = self.vlm_model + self.model_config = self.vlm_model_config.text_config + self.pruning_config = { + 'image_token_start_index': 5, + 'image_token_length': self.vlm_model_config.image_seq_length, + 'select_layer': self.vlm_model_config.vision_feature_layer, + 'select_feature': self.vlm_model_config.vision_feature_select_strategy, + 'image_token_index': self.vlm_model_config.image_token_index + } + self.processor = None + + def get_extra_rot_module_besides_embed_layers(self): + return [self.vision_projector[2]] + + def find_blocks(self): + if self.get_modality() == 'language': + super().find_blocks() + elif self.get_modality() == 'vision': + self.blocks = self.vision_model.vision_tower.vision_model.encoder.layers + else: + raise Exception(f'Llava do not support {self.get_modality()} modality.') + + def get_layernorms_in_block(self, block): + if self.get_modality() == 'language': + return super().get_layernorms_in_block(block) + elif self.get_modality() == 'vision': + return { + 'layer_norm1': block.layer_norm1, + 'layer_norm2': block.layer_norm2, + } + else: + raise Exception(f'Llava do not support {self.get_modality()} modality.') + + def get_subsets_in_block(self, block): + if self.get_modality() == 'language': + return super().get_subsets_in_block(block) + elif self.get_modality() == 'vision': + return [ + { + 'layers': { + 'self_attn.q_proj': block.self_attn.q_proj, + 'self_attn.k_proj': block.self_attn.k_proj, + 'self_attn.v_proj': block.self_attn.v_proj, + }, + 'prev_op': [block.layer_norm1], + 'input': ['self_attn.q_proj'], + 'inspect': block.self_attn, + 'has_kwargs': True, + }, + { + 'layers': {'self_attn.out_proj': block.self_attn.out_proj}, + 'prev_op': [block.self_attn.v_proj], + 'input': ['self_attn.out_proj'], + 'inspect': block.self_attn.out_proj, + 'has_kwargs': False, + }, + { + 'layers': {'mlp.fc1': block.mlp.fc1}, + 'prev_op': [block.layer_norm2], + 'input': ['mlp.fc1'], + 'inspect': block.mlp.fc1, + 'has_kwargs': False, + 'is_mlp': True, + }, + { + 'layers': {'mlp.fc2': block.mlp.fc2}, + 'prev_op': [block.mlp.fc1], + 'input': ['mlp.fc2'], + 'inspect': block.mlp.fc2, + 'has_kwargs': False, + 'is_mlp': True, + 'do_trans': False + }, + ] + else: + raise Exception(f'Llava do not support {self.get_modality()} modality.') + + +if version.parse(torch.__version__) >= version.parse('2.1.2'): + best_fit_attn_implementation = 'sdpa' +else: + best_fit_attn_implementation = 'eager' + + +@MODEL_REGISTRY +class LlavaLHTEval(Llava): + def __init__( + self, + llmc_model, + pretrained: str = 'liuhaotian/llava-v1.5-7b', + truncation: Optional[bool] = True, + device: Optional[str] = 'cuda', + batch_size: Optional[Union[int, str]] = 1, + model_name=None, + attn_implementation=best_fit_attn_implementation, + device_map: str = '', + conv_template='vicuna_v1', + use_cache: bool = False, + tie_weights: bool = True, + truncate_context=False, # set it False for LLaVA-1.6 + customized_config=None, # ends in json + **kwargs, + ) -> None: + lmms.__init__(self) + # Do not use kwargs for now + assert kwargs == {}, f'Unexpected kwargs: {kwargs}' + + accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52)) + accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs]) + self.accelerator = accelerator + if accelerator.num_processes > 1: + self._device = torch.device(f'cuda:{accelerator.local_process_index}') + self.device_map = f'cuda:{accelerator.local_process_index}' + elif accelerator.num_processes == 1 and device_map == 'auto': + self._device = torch.device(device) + self.device_map = device_map + else: + self._device = torch.device(f'cuda:{accelerator.local_process_index}') + self.device_map = f'cuda:{accelerator.local_process_index}' + + llava_model_args = { + 'multimodal': True, + } + if customized_config is not None: + llava_model_args['customized_config'] = customized_config + if attn_implementation is not None: + llava_model_args['attn_implementation'] = attn_implementation + if 'use_flash_attention_2' in kwargs: + llava_model_args['use_flash_attention_2'] = kwargs['use_flash_attention_2'] + model_name = model_name if model_name is not None else get_model_name_from_path(pretrained) + self._model = llmc_model.cuda() + self._config = self._model.config + self._tokenizer = AutoTokenizer.from_pretrained(pretrained, use_fast=False) + self._image_processor = None + if 'llava' in model_name.lower(): + mm_use_im_start_end = getattr(self._config, 'mm_use_im_start_end', False) + mm_use_im_patch_token = getattr(self._config, 'mm_use_im_patch_token', True) + if mm_use_im_patch_token: + self._tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True) + if mm_use_im_start_end: + self._tokenizer.add_tokens( + [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], + special_tokens=True + ) + self._image_processor = self._model.get_vision_tower().image_processor + if hasattr(self._config, 'max_sequence_length'): + self._max_length = self._config.max_sequence_length + else: + self._max_length = 2048 + self.model.eval() + if tie_weights: + self.model.tie_weights() + + self.truncation = truncation + self.batch_size_per_gpu = int(batch_size) + self.conv_template = conv_template + self.use_cache = use_cache + self.truncate_context = truncate_context + # assert self.batch_size_per_gpu == 1, ( + # "Llava currently does not support batched generation. " + # "See: https://github.com/haotian-liu/LLaVA/issues/754. " + # "HF Llava also has this issue." + # ) + if accelerator.num_processes > 1: + assert accelerator.distributed_type in [ + DistributedType.FSDP, + DistributedType.MULTI_GPU, + DistributedType.DEEPSPEED], ( + 'Unsupported distributed type provided. ' + 'Only DDP and FSDP are supported.') + # To use DistributedType.DEEPSPEED, run `accelerate config` first. + # You must select zero stage 0 (equivalent to DDP) for model preparation to work. + # Attempts to support zero stage 2 via kwargs failed. + if accelerator.distributed_type == DistributedType.DEEPSPEED: + kwargs = { + 'train_micro_batch_size_per_gpu': self.batch_size_per_gpu, + 'train_batch_size': self.batch_size_per_gpu * accelerator.num_processes, + } + AcceleratorState().deepspeed_plugin.deepspeed_config_process( + must_match=True, **kwargs + ) + logger.info( + 'Detected that you are using DistributedType.DEEPSPEED. ' + 'Make sure you run `accelerate config` and set zero stage to 0' + ) + + if ( + accelerator.distributed_type == DistributedType.FSDP + or accelerator.distributed_type == DistributedType.DEEPSPEED + ): + self._model = accelerator.prepare(self.model) + else: + self._model = accelerator.prepare_model(self.model, evaluation_mode=True) + self.accelerator = accelerator + if self.accelerator.is_local_main_process: + logger.info(f'Using {accelerator.num_processes} devices with data parallelism') + self._rank = self.accelerator.local_process_index + self._world_size = self.accelerator.num_processes + elif accelerator.num_processes == 1 and device_map == 'auto': + logger.info(f'Using {accelerator.num_processes} devices with tensor parallelism') + self._rank = 0 + self._word_size = 1 + else: + logger.info(f'Using single device: {self._device}') + self.model.to(self._device) + self._rank = 0 + self._world_size = 1 diff --git a/requirements/runtime.txt b/requirements/runtime.txt index d7d132aee..66c367926 100755 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -1,9 +1,10 @@ -torch>=2.1.0 +torch>=2.2.0 torchvision timm pillow loguru transformers>=4.45.2 +lmms-eval huggingface-hub sentencepiece protobuf @@ -31,6 +32,5 @@ qwen-vl-utils tiktoken librosa human_eval -lmms-eval imageio diffusers \ No newline at end of file