
Commit 6f3755c

update ci and support llava-next
support llava-next (only Vicuna)
2 parents 671c271 + 4023810 commit 6f3755c

File tree

3 files changed: +11 -44 lines changed

.github/workflows/main.yml

Lines changed: 1 addition & 2 deletions
@@ -26,8 +26,7 @@ jobs:
 
       - name: Install dependencies
         run: |
-          cd ci_check
-          pip install -r runtime.txt
+          pip install -r requirements.txt
 
       - name: Download dataset
         run: |

ci_check/runtime.txt

Lines changed: 0 additions & 36 deletions
This file was deleted.

llmc/models/llava_lht.py

Lines changed: 10 additions & 6 deletions
@@ -34,26 +34,27 @@ def build_tokenizer(self):
         pass
 
     def build_model(self):
-        self.llava_llama_config = LlavaConfig.from_pretrained(
+        self.llava_config = LlavaConfig.from_pretrained(
             self.model_path, trust_remote_code=True
         )
         self.vlm_model_config = AutoConfig.from_pretrained(
             self.model_path, trust_remote_code=True
         )
         if not self.use_cache:
-            self.llava_llama_config.use_cache = False
+            self.llava_config.use_cache = False
             self.vlm_model_config.use_cache = False
         logger.info(f'self.vlm_model_config : {self.vlm_model_config}')
-        self.tokenizer, self.vlm_model, self.image_processor, context_len = load_pretrained_model(
+        self.tokenizer, self.vlm_model, image_processor, context_len = load_pretrained_model(
             self.model_path,
             None,
             get_model_name_from_path(self.model_path),
             load_8bit=False,
             load_4bit=False,
-            torch_dtype=self.torch_dtype,
             device='cpu',
-            config=self.llava_llama_config,
+            torch_dtype=self.torch_dtype,
+            config=self.llava_config,
         )
+
         # llava-lht forward not support "cache_position"
         ori_forward = self.vlm_model.forward
 

@@ -62,6 +63,7 @@ def safe_forward(*args, **kwargs):
             kwargs.pop('cache_position', None)
             return ori_forward(*args, **kwargs)
         self.vlm_model.forward = safe_forward
+
         # llava-lht generate use "inputs" instead of "input_ids"
         ori_generate = self.vlm_model.generate
 

@@ -190,7 +192,7 @@ def __init__(
         conv_template='vicuna_v1',
         use_cache: bool = False,
         tie_weights: bool = True,
-        truncate_context=False,  # set it False for LLaVA-1.6
+        truncate_context=False,  # set it False for LLaVA-1.6 no matter truncate
         customized_config=None,  # ends in json
         **kwargs,
     ) -> None:

@@ -221,6 +223,7 @@ def __init__(
         if 'use_flash_attention_2' in kwargs:
             llava_model_args['use_flash_attention_2'] = kwargs['use_flash_attention_2']
         model_name = model_name if model_name is not None else get_model_name_from_path(pretrained)
+
         self._model = llmc_model.cuda()
         self._config = self._model.config
         self._tokenizer = AutoTokenizer.from_pretrained(pretrained, use_fast=False)

@@ -240,6 +243,7 @@ def __init__(
             self._max_length = self._config.max_sequence_length
         else:
             self._max_length = 2048
+
         self.model.eval()
         if tie_weights:
             self.model.tie_weights()
