 from transformers import (
     AutoProcessor,
     CLIPImageProcessor,
-    LlamaForCausalLM,
     LlavaForConditionalGeneration,
 )

@@ -104,19 +103,19 @@ def __init__(

     def _translate_state_dict_for_text_model(self) -> Dict[str, Any]:
         # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_pt_objects.LlavaForConditionalGeneration` has no attribute `language_model`.
-        state_dict = self.model_.language_model.state_dict()
+        state_dict = self.model_.state_dict()
         key_map = {
             # fmt: off
-            r"model.layers.([0-9]+).self_attn.q_proj.": r"layers.\1.attention.wq.",
-            r"model.layers.([0-9]+).self_attn.k_proj.": r"layers.\1.attention.wk.",
-            r"model.layers.([0-9]+).self_attn.v_proj.": r"layers.\1.attention.wv.",
-            r"model.layers.([0-9]+).self_attn.o_proj.": r"layers.\1.attention.wo.",
-            r"model.layers.([0-9]+).input_layernorm.": r"layers.\1.attention_norm.",
-            r"model.layers.([0-9]+).mlp.gate_proj.": r"layers.\1.feed_forward.w1.",
-            r"model.layers.([0-9]+).mlp.down_proj.": r"layers.\1.feed_forward.w2.",
-            r"model.layers.([0-9]+).mlp.up_proj.": r"layers.\1.feed_forward.w3.",
-            r"model.layers.([0-9]+).post_attention_layernorm.": r"layers.\1.ffn_norm.",
-            r"model.norm.": r"norm.",
+            r"model.language_model.layers.([0-9]+).self_attn.q_proj.": r"layers.\1.attention.wq.",
+            r"model.language_model.layers.([0-9]+).self_attn.k_proj.": r"layers.\1.attention.wk.",
+            r"model.language_model.layers.([0-9]+).self_attn.v_proj.": r"layers.\1.attention.wv.",
+            r"model.language_model.layers.([0-9]+).self_attn.o_proj.": r"layers.\1.attention.wo.",
+            r"model.language_model.layers.([0-9]+).input_layernorm.": r"layers.\1.attention_norm.",
+            r"model.language_model.layers.([0-9]+).mlp.gate_proj.": r"layers.\1.feed_forward.w1.",
+            r"model.language_model.layers.([0-9]+).mlp.down_proj.": r"layers.\1.feed_forward.w2.",
+            r"model.language_model.layers.([0-9]+).mlp.up_proj.": r"layers.\1.feed_forward.w3.",
+            r"model.language_model.layers.([0-9]+).post_attention_layernorm.": r"layers.\1.ffn_norm.",
+            r"model.language_model.norm.": r"norm.",
             # r"model.embed_tokens.": r"tok_embeddings.", # load separately
             r"lm_head.": r"output.",
             # fmt: on
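
The key map now takes the full `LlavaForConditionalGeneration` state dict and matches the `model.language_model.*` prefix that recent transformers releases use for the text decoder, instead of pulling the state dict from the `language_model` submodule. As a rough sketch of how a regex key map like this is typically applied (the helper name and the sample key below are illustrative, not the file's actual code):

```python
import re
from typing import Any, Dict

# Illustrative helper: rename checkpoint keys by running every regex in the
# key map over each key. Not the repository's actual implementation.
def rename_keys(state_dict: Dict[str, Any], key_map: Dict[str, str]) -> Dict[str, Any]:
    renamed = {}
    for old_key, value in state_dict.items():
        new_key = old_key
        for pattern, replacement in key_map.items():
            new_key = re.sub(pattern, replacement, new_key)
        renamed[new_key] = value
    return renamed

key_map = {r"model.language_model.layers.([0-9]+).self_attn.q_proj.": r"layers.\1.attention.wq."}
print(rename_keys({"model.language_model.layers.0.self_attn.q_proj.weight": "..."}, key_map))
# {'layers.0.attention.wq.weight': '...'}
```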
@@ -157,7 +156,7 @@ def get_model(self):

     def embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
         # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_pt_objects.LlavaForConditionalGeneration` has no attribute `language_model`.
-        return self.model_.language_model.model.embed_tokens(tokens)
+        return self.model_.language_model.embed_tokens(tokens)

     def encode_images(self, images: torch.Tensor) -> torch.Tensor:
         # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_pt_objects.LlavaForConditionalGeneration` has no attribute `dtype`.
@@ -289,13 +288,8 @@ def prefill_ref(
         """Avoiding the torch.where() call to find <image> placeholder and insert image embedding. Taking 3 inputs instead."""
         embeds = self.prefill_embedding(prompt_before_image, images, prompt_after_image)
         # pyre-ignore: Undefined attribute [16]: Module `transformers` has no attribute `LlamaForCausalLM`.
-        return LlamaForCausalLM.forward(
-            # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_pt_objects.LlavaForConditionalGeneration` has no attribute `language_model`.
-            self.model_.language_model,
-            inputs_embeds=embeds,
-            return_dict=False,
-            use_cache=False,
-            output_hidden_states=False,
+        return self.model_.forward(
+            inputs_embeds=embeds, use_cache=False, return_dict=False, logits_to_keep=1
         )

     def forward(
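
`prefill_ref` now calls `forward` on the full `LlavaForConditionalGeneration` with `inputs_embeds`, rather than invoking `LlamaForCausalLM.forward` directly on the inner language model, and `logits_to_keep=1` limits the returned logits to the last position, which is all prefill needs to pick the next token. A minimal sketch of the `logits_to_keep` effect, assuming a transformers release recent enough to support the argument; the tiny randomly initialized config is only for illustration and is not the Llava checkpoint:

```python
import torch
from transformers import LlamaConfig, LlamaForCausalLM

# Tiny random model, just to show how logits_to_keep changes the output shape.
config = LlamaConfig(
    vocab_size=1000,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
)
model = LlamaForCausalLM(config)
embeds = torch.randn(1, 10, 64)  # stand-in for the output of prefill_embedding

full = model(inputs_embeds=embeds, use_cache=False, return_dict=False)
last = model(inputs_embeds=embeds, use_cache=False, return_dict=False, logits_to_keep=1)
print(full[0].shape)  # torch.Size([1, 10, 1000]) - logits for every position
print(last[0].shape)  # torch.Size([1, 1, 1000])  - only the final position
```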
@@ -309,25 +303,42 @@ class LlavaModel(EagerModelBase):
     def __init__(self, use_sdpa_with_kv_cache_op=True, max_seq_len=768):
         self.use_sdpa_with_kv_cache_op = use_sdpa_with_kv_cache_op
         self.max_seq_len = max_seq_len
-        self.processor = AutoProcessor.from_pretrained(
-            "llava-hf/llava-1.5-7b-hf",
-            revision="a272c74b2481d8aff3aa6fc2c4bf891fe57334fb",  # Need this for transformers >= 4.44.2
-        )
-        self.tokenizer = self.processor.tokenizer
-        self.image_processor = self.processor.image_processor
         self.model = LlavaForConditionalGeneration.from_pretrained(
             "llava-hf/llava-1.5-7b-hf",
             device_map="cpu",
             revision="a272c74b2481d8aff3aa6fc2c4bf891fe57334fb",  # Need this for transformers >= 4.44.2
         )
-        self.image = Image.open(
-            requests.get(
-                "https://llava-vl.github.io/static/images/view.jpg", stream=True
-            ).raw
+        self.processor = AutoProcessor.from_pretrained(
+            "llava-hf/llava-1.5-7b-hf",
+            revision="a272c74b2481d8aff3aa6fc2c4bf891fe57334fb",  # Need this for transformers >= 4.44.2
+            patch_size=self.model.vision_tower.config.patch_size,  # Required after transformers >= 4.52.0
         )
-        self.prompt = """A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>
-What are the things I should be cautious about when I visit here? ASSISTANT:"""
+        self.tokenizer = self.processor.tokenizer
+        self.image_processor = self.processor.image_processor
+        self.image_url = "https://llava-vl.github.io/static/images/view.jpg"
+        self.image = Image.open(requests.get(self.image_url, stream=True).raw)
+        self.system_prompt = """A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. """
+        current_template = self.processor.chat_template
+        # Prepend the system prompt to the template
+        new_template = self.system_prompt + current_template
+
+        # Set the modified template back on the processor
+        self.processor.chat_template = new_template
+
         self.model_name = "llava-1.5-7b-hf"
+
+        self.conversation = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image", "url": self.image_url},
+                    {
+                        "type": "text",
+                        "text": "What are the things I should be cautious about when I visit here?",
+                    },
+                ],
+            },
+        ]
         # set input to None and initialize them lazily
         self.input = None
         self.resized_image = None
@@ -358,11 +369,18 @@ def get_inputs_for_prefill(self):
         """Returns prompts as well as image."""
         if self.input:
             return self.input
-        self.input_ids = self.tokenizer.encode(self.prompt, return_tensors="pt").cpu()
+        inputs = self.processor.apply_chat_template(
+            self.conversation,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt",
+        )
+        self.input_ids = inputs["input_ids"]
         index = torch.where(self.input_ids == self.model.config.image_token_index)[1]
-        self.prompt_before_image = self.input_ids[:, :index]
+        self.prompt_before_image = self.input_ids[:, : index[0]]
         # print(prompt_before_image.shape)
-        self.prompt_after_image = self.input_ids[:, index + 1 :]
+        self.prompt_after_image = self.input_ids[:, index[-1] + 1 :]
         # print(prompt_after_image.shape)
         self.input = (
             self.prompt_before_image,
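
Because the processor (given the `patch_size` it now receives) expands the single `<image>` placeholder into one placeholder token per image patch, `torch.where` can return many indices, so the split uses the first and last match instead of a scalar index. A toy illustration of that slicing, with made-up token ids and an assumed placeholder id; the real code reads the id from `self.model.config.image_token_index`:

```python
import torch

# Toy example of splitting a prompt around the image placeholder tokens.
# The token ids and the placeholder id (32000) are invented for illustration.
image_token_index = 32000
input_ids = torch.tensor([[1, 319, 13563, 32000, 32000, 32000, 1724, 526, 2]])

index = torch.where(input_ids == image_token_index)[1]  # tensor([3, 4, 5])
prompt_before_image = input_ids[:, : index[0]]      # [[1, 319, 13563]]
prompt_after_image = input_ids[:, index[-1] + 1 :]  # [[1724, 526, 2]]
```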