@@ -34,26 +34,27 @@ def build_tokenizer(self):
         pass
 
     def build_model(self):
-        self.llava_llama_config = LlavaConfig.from_pretrained(
+        self.llava_config = LlavaConfig.from_pretrained(
             self.model_path, trust_remote_code=True
         )
         self.vlm_model_config = AutoConfig.from_pretrained(
             self.model_path, trust_remote_code=True
         )
         if not self.use_cache:
-            self.llava_llama_config.use_cache = False
+            self.llava_config.use_cache = False
             self.vlm_model_config.use_cache = False
         logger.info(f'self.vlm_model_config : {self.vlm_model_config}')
-        self.tokenizer, self.vlm_model, self.image_processor, context_len = load_pretrained_model(
+        self.tokenizer, self.vlm_model, image_processor, context_len = load_pretrained_model(
             self.model_path,
             None,
             get_model_name_from_path(self.model_path),
             load_8bit=False,
             load_4bit=False,
-            torch_dtype=self.torch_dtype,
             device='cpu',
-            config=self.llava_llama_config,
+            torch_dtype=self.torch_dtype,
+            config=self.llava_config,
         )
+
         # llava-lht forward not support "cache_position"
         ori_forward = self.vlm_model.forward
 
@@ -62,6 +63,7 @@ def safe_forward(*args, **kwargs):
             kwargs.pop('cache_position', None)
             return ori_forward(*args, **kwargs)
         self.vlm_model.forward = safe_forward
+
         # llava-lht generate use "inputs" instead of "input_ids"
         ori_generate = self.vlm_model.generate
 
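The first two hunks patch the loaded model in place: forward() is wrapped so the unsupported cache_position kwarg is dropped, and the trailing comment notes that generate() needs the analogous change because llava-lht's generate() takes "inputs" rather than "input_ids". That second wrapper lies outside the hunks shown here; a minimal sketch of the same wrapper pattern, with an illustrative safe_generate name that is not taken from this commit, could look like:

# Illustrative sketch only -- the real wrapper body is outside the hunks above.
# It mirrors safe_forward: rename the kwarg, then delegate to the original method.
ori_generate = self.vlm_model.generate

def safe_generate(*args, **kwargs):
    # llava-lht's generate() expects "inputs" instead of "input_ids"
    if 'input_ids' in kwargs:
        kwargs['inputs'] = kwargs.pop('input_ids')
    return ori_generate(*args, **kwargs)

self.vlm_model.generate = safe_generate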
@@ -190,7 +192,7 @@ def __init__(
         conv_template='vicuna_v1',
         use_cache: bool = False,
         tie_weights: bool = True,
-        truncate_context=False,  # set it False for LLaVA-1.6
+        truncate_context=False,  # set it False for LLaVA-1.6 no matter truncate
         customized_config=None,  # ends in json
         **kwargs,
     ) -> None:
@@ -221,6 +223,7 @@ def __init__(
         if 'use_flash_attention_2' in kwargs:
             llava_model_args['use_flash_attention_2'] = kwargs['use_flash_attention_2']
         model_name = model_name if model_name is not None else get_model_name_from_path(pretrained)
+
         self._model = llmc_model.cuda()
         self._config = self._model.config
         self._tokenizer = AutoTokenizer.from_pretrained(pretrained, use_fast=False)
@@ -240,6 +243,7 @@ def __init__(
             self._max_length = self._config.max_sequence_length
         else:
             self._max_length = 2048
+
         self.model.eval()
         if tie_weights:
             self.model.tie_weights()