From 40238103e6f25ff5e5df0a7297d8d351364a3b1c Mon Sep 17 00:00:00 2001
From: SmudgedWings <2045955563@qq.com>
Date: Fri, 23 May 2025 15:29:38 +0800
Subject: [PATCH] update ci and support llava-next

---
 .github/workflows/main.yml |  3 +--
 ci_check/runtime.txt       | 36 ------------------------------------
 llmc/models/llava_lht.py   | 16 ++++++++++------
 3 files changed, 11 insertions(+), 44 deletions(-)
 delete mode 100644 ci_check/runtime.txt

diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 0ae98e9b8..80c27008b 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -26,8 +26,7 @@ jobs:
 
       - name: Install dependencies
         run: |
-          cd ci_check
-          pip install -r runtime.txt
+          pip install -r requirements.txt
 
       - name: Download dataset
         run: |
diff --git a/ci_check/runtime.txt b/ci_check/runtime.txt
deleted file mode 100644
index 7bb01c0e5..000000000
--- a/ci_check/runtime.txt
+++ /dev/null
@@ -1,36 +0,0 @@
-torch==2.7.0
-torchvision==0.22.0
-timm==1.0.15
-pillow==11.2.1
-loguru==0.7.3
-transformers==4.51.3
-huggingface-hub==0.30.2
-sentencepiece==0.1.99
-protobuf==3.20.0
-accelerate==1.6.0
-zstandard==0.23.0
-easydict==1.13
-evaluate==0.4.3
-datasets==2.16.1
-jsonlines==4.0.0
-numexpr==2.10.2
-peft==0.15.2
-pybind11==2.13.6
-pytablewriter==1.2.1
-rouge-score>=0.0.4
-sacrebleu==2.5.1
-scikit-learn==1.6.1
-sqlitedict==2.1.0
-tqdm-multiprocess==0.0.11
-dill==0.3.7
-word2number==1.1
-more_itertools
-qtorch==0.3.0
-einops==0.8.1
-qwen-vl-utils==0.0.11
-tiktoken==0.9.0
-librosa==0.11.0
-human_eval
-lmms-eval
-imageio==2.37.0
-diffusers==0.33.1
diff --git a/llmc/models/llava_lht.py b/llmc/models/llava_lht.py
index 8ee1c9b57..ffd3f6250 100644
--- a/llmc/models/llava_lht.py
+++ b/llmc/models/llava_lht.py
@@ -34,26 +34,27 @@ def build_tokenizer(self):
         pass
 
     def build_model(self):
-        self.llava_llama_config = LlavaConfig.from_pretrained(
+        self.llava_config = LlavaConfig.from_pretrained(
             self.model_path, trust_remote_code=True
         )
         self.vlm_model_config = AutoConfig.from_pretrained(
             self.model_path, trust_remote_code=True
         )
         if not self.use_cache:
-            self.llava_llama_config.use_cache = False
+            self.llava_config.use_cache = False
             self.vlm_model_config.use_cache = False
         logger.info(f'self.vlm_model_config : {self.vlm_model_config}')
-        self.tokenizer, self.vlm_model, self.image_processor, context_len = load_pretrained_model(
+        self.tokenizer, self.vlm_model, image_processor, context_len = load_pretrained_model(
             self.model_path,
             None,
             get_model_name_from_path(self.model_path),
             load_8bit=False,
             load_4bit=False,
-            torch_dtype=self.torch_dtype,
             device='cpu',
-            config=self.llava_llama_config,
+            torch_dtype=self.torch_dtype,
+            config=self.llava_config,
         )
+
         # llava-lht forward not support "cache_position"
         ori_forward = self.vlm_model.forward
 
@@ -62,6 +63,7 @@ def safe_forward(*args, **kwargs):
             kwargs.pop('cache_position', None)
             return ori_forward(*args, **kwargs)
         self.vlm_model.forward = safe_forward
+
         # llava-lht generate use "inputs" instead of "input_ids"
         ori_generate = self.vlm_model.generate
 
@@ -190,7 +192,7 @@ def __init__(
         conv_template='vicuna_v1',
         use_cache: bool = False,
         tie_weights: bool = True,
-        truncate_context=False,  # set it False for LLaVA-1.6
+        truncate_context=False,  # set it False for LLaVA-1.6 no matter truncate
         customized_config=None,  # ends in json
         **kwargs,
     ) -> None:
@@ -221,6 +223,7 @@ def __init__(
         if 'use_flash_attention_2' in kwargs:
             llava_model_args['use_flash_attention_2'] = kwargs['use_flash_attention_2']
         model_name = model_name if model_name is not None else get_model_name_from_path(pretrained)
+
         self._model = llmc_model.cuda()
         self._config = self._model.config
         self._tokenizer = AutoTokenizer.from_pretrained(pretrained, use_fast=False)
@@ -240,6 +243,7 @@ def __init__(
             self._max_length = self._config.max_sequence_length
         else:
             self._max_length = 2048
+
         self.model.eval()
         if tie_weights:
             self.model.tie_weights()
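
The two comments in llmc/models/llava_lht.py, dropping the "cache_position" kwarg in forward and calling generate with "inputs" instead of "input_ids", both describe the same wrapper technique. Only the forward shim appears in these hunks; the sketch below is a minimal illustration of that monkey-patch pattern under those assumptions, and patch_llava_lht / safe_generate are hypothetical names, not code from this repository.

# Illustrative sketch only, not part of the patch above. Assumes a model
# object whose forward() rejects the "cache_position" kwarg and whose
# generate() expects "inputs" rather than the HF-style "input_ids".
def patch_llava_lht(vlm_model):
    ori_forward = vlm_model.forward
    ori_generate = vlm_model.generate

    def safe_forward(*args, **kwargs):
        # Silently drop the kwarg newer transformers versions pass through.
        kwargs.pop('cache_position', None)
        return ori_forward(*args, **kwargs)

    def safe_generate(*args, **kwargs):
        # Rename input_ids -> inputs so callers written against the
        # standard HF generate() signature keep working.
        if 'input_ids' in kwargs and 'inputs' not in kwargs:
            kwargs['inputs'] = kwargs.pop('input_ids')
        return ori_generate(*args, **kwargs)

    vlm_model.forward = safe_forward
    vlm_model.generate = safe_generate
    return vlm_model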