diff --git a/examples/models/llava/export_llava.py b/examples/models/llava/export_llava.py
index e0580aa859a..d95bd7fb054 100644
--- a/examples/models/llava/export_llava.py
+++ b/examples/models/llava/export_llava.py
@@ -186,7 +186,7 @@ def quant_embedding(model):
             packed=False,
         ).quantized_model()
 
-    quantized_token_embed = quant_embedding(llava.model_.language_model.model)
+    quantized_token_embed = quant_embedding(llava.model_.model.language_model)
     token_dim_1 = Dim("token_dim_1", min=2, max=llava.text_model_args.max_seq_len)
     dynamic_shapes = [{1: token_dim_1}]
     with torch.no_grad():
diff --git a/examples/models/llava/model.py b/examples/models/llava/model.py
index 1050fbdfae1..3973d756e9c 100644
--- a/examples/models/llava/model.py
+++ b/examples/models/llava/model.py
@@ -31,7 +31,6 @@
 from transformers import (
     AutoProcessor,
     CLIPImageProcessor,
-    LlamaForCausalLM,
     LlavaForConditionalGeneration,
 )
 
@@ -104,19 +103,19 @@ def __init__(
 
     def _translate_state_dict_for_text_model(self) -> Dict[str, Any]:
         # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_pt_objects.LlavaForConditionalGeneration` has no attribute `language_model`.
-        state_dict = self.model_.language_model.state_dict()
+        state_dict = self.model_.state_dict()
         key_map = {
             # fmt: off
-            r"model.layers.([0-9]+).self_attn.q_proj.": r"layers.\1.attention.wq.",
-            r"model.layers.([0-9]+).self_attn.k_proj.": r"layers.\1.attention.wk.",
-            r"model.layers.([0-9]+).self_attn.v_proj.": r"layers.\1.attention.wv.",
-            r"model.layers.([0-9]+).self_attn.o_proj.": r"layers.\1.attention.wo.",
-            r"model.layers.([0-9]+).input_layernorm.": r"layers.\1.attention_norm.",
-            r"model.layers.([0-9]+).mlp.gate_proj.": r"layers.\1.feed_forward.w1.",
-            r"model.layers.([0-9]+).mlp.down_proj.": r"layers.\1.feed_forward.w2.",
-            r"model.layers.([0-9]+).mlp.up_proj.": r"layers.\1.feed_forward.w3.",
-            r"model.layers.([0-9]+).post_attention_layernorm.": r"layers.\1.ffn_norm.",
-            r"model.norm.": r"norm.",
+            r"model.language_model.layers.([0-9]+).self_attn.q_proj.": r"layers.\1.attention.wq.",
+            r"model.language_model.layers.([0-9]+).self_attn.k_proj.": r"layers.\1.attention.wk.",
+            r"model.language_model.layers.([0-9]+).self_attn.v_proj.": r"layers.\1.attention.wv.",
+            r"model.language_model.layers.([0-9]+).self_attn.o_proj.": r"layers.\1.attention.wo.",
+            r"model.language_model.layers.([0-9]+).input_layernorm.": r"layers.\1.attention_norm.",
+            r"model.language_model.layers.([0-9]+).mlp.gate_proj.": r"layers.\1.feed_forward.w1.",
+            r"model.language_model.layers.([0-9]+).mlp.down_proj.": r"layers.\1.feed_forward.w2.",
+            r"model.language_model.layers.([0-9]+).mlp.up_proj.": r"layers.\1.feed_forward.w3.",
+            r"model.language_model.layers.([0-9]+).post_attention_layernorm.": r"layers.\1.ffn_norm.",
+            r"model.language_model.norm.": r"norm.",
             # r"model.embed_tokens.": r"tok_embeddings.",  # load separately
             r"lm_head.": r"output.",
             # fmt: on
@@ -157,7 +156,7 @@ def get_model(self):
 
     def embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
         # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_pt_objects.LlavaForConditionalGeneration` has no attribute `language_model`.
-        return self.model_.language_model.model.embed_tokens(tokens)
+        return self.model_.language_model.embed_tokens(tokens)
 
     def encode_images(self, images: torch.Tensor) -> torch.Tensor:
         # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_pt_objects.LlavaForConditionalGeneration` has no attribute `dtype`.
@@ -289,13 +288,8 @@ def prefill_ref(
         """Avoiding the torch.where() call to find placeholder and insert image embedding. Taking 3 inputs instead."""
         embeds = self.prefill_embedding(prompt_before_image, images, prompt_after_image)
         # pyre-ignore: Undefined attribute [16]: Module `transformers` has no attribute `LlamaForCausalLM`.
-        return LlamaForCausalLM.forward(
-            # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_pt_objects.LlavaForConditionalGeneration` has no attribute `language_model`.
-            self.model_.language_model,
-            inputs_embeds=embeds,
-            return_dict=False,
-            use_cache=False,
-            output_hidden_states=False,
+        return self.model_.forward(
+            inputs_embeds=embeds, use_cache=False, return_dict=False, logits_to_keep=1
         )
 
     def forward(
@@ -309,25 +303,42 @@ class LlavaModel(EagerModelBase):
     def __init__(self, use_sdpa_with_kv_cache_op=True, max_seq_len=768):
         self.use_sdpa_with_kv_cache_op = use_sdpa_with_kv_cache_op
         self.max_seq_len = max_seq_len
-        self.processor = AutoProcessor.from_pretrained(
-            "llava-hf/llava-1.5-7b-hf",
-            revision="a272c74b2481d8aff3aa6fc2c4bf891fe57334fb",  # Need this for transformers >= 4.44.2
-        )
-        self.tokenizer = self.processor.tokenizer
-        self.image_processor = self.processor.image_processor
         self.model = LlavaForConditionalGeneration.from_pretrained(
             "llava-hf/llava-1.5-7b-hf",
             device_map="cpu",
             revision="a272c74b2481d8aff3aa6fc2c4bf891fe57334fb",  # Need this for transformers >= 4.44.2
         )
-        self.image = Image.open(
-            requests.get(
-                "https://llava-vl.github.io/static/images/view.jpg", stream=True
-            ).raw
+        self.processor = AutoProcessor.from_pretrained(
+            "llava-hf/llava-1.5-7b-hf",
+            revision="a272c74b2481d8aff3aa6fc2c4bf891fe57334fb",  # Need this for transformers >= 4.44.2
+            patch_size=self.model.vision_tower.config.patch_size,  # Required after transformers >= 4.52.0
         )
-        self.prompt = """A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>
-What are the things I should be cautious about when I visit here? ASSISTANT:"""
-
+        self.tokenizer = self.processor.tokenizer
+        self.image_processor = self.processor.image_processor
+        self.image_url = "https://llava-vl.github.io/static/images/view.jpg"
+        self.image = Image.open(requests.get(self.image_url, stream=True).raw)
+        self.system_prompt = """A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. 
+"""
+        current_template = self.processor.chat_template
+        # Prepend the system prompt to the template
+        new_template = self.system_prompt + current_template
+
+        # Set the modified template back to the tokenizer
+        self.processor.chat_template = new_template
+        self.model_name = "llava-1.5-7b-hf"
+
+        self.conversation = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image", "url": self.image_url},
+                    {
+                        "type": "text",
+                        "text": "What are the things I should be cautious about when I visit here?",
+                    },
+                ],
+            },
+        ]
         # set input to None and initialize them lazily
         self.input = None
         self.resized_image = None
@@ -358,11 +369,18 @@ def get_inputs_for_prefill(self):
         """Returns prompts as well as image."""
         if self.input:
             return self.input
-        self.input_ids = self.tokenizer.encode(self.prompt, return_tensors="pt").cpu()
+        inputs = self.processor.apply_chat_template(
+            self.conversation,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt",
+        )
+        self.input_ids = inputs["input_ids"]
         index = torch.where(self.input_ids == self.model.config.image_token_index)[1]
-        self.prompt_before_image = self.input_ids[:, :index]
+        self.prompt_before_image = self.input_ids[:, : index[0]]
         # print(prompt_before_image.shape)
-        self.prompt_after_image = self.input_ids[:, index + 1 :]
+        self.prompt_after_image = self.input_ids[:, index[-1] + 1 :]
         # print(prompt_after_image.shape)
         self.input = (
             self.prompt_before_image,
diff --git a/examples/models/llava/test/test_llava.py b/examples/models/llava/test/test_llava.py
index 36381b27124..05cfd5b1497 100644
--- a/examples/models/llava/test/test_llava.py
+++ b/examples/models/llava/test/test_llava.py
@@ -41,8 +41,9 @@ def test_prefill_logits(self):
         # The reference implementation in HF genetates the full logits. Get the last one.
         prefill_logits_ref = self.llava.prefill_ref(
             self.prompt_before_image, self.resized, self.prompt_after_image
-        )[0][:, -1, :]
-        self.assertTrue(torch.allclose(prefill_logits, prefill_logits_ref, atol=3e-2))
+        )[0]
+
+        torch.testing.assert_close(prefill_logits, prefill_logits_ref.squeeze(0))
 
     def test_generated_output(self):
         # source of truth, using HF llava
diff --git a/requirements-examples.txt b/requirements-examples.txt
index 83f3d6bac4c..75785b56975 100644
--- a/requirements-examples.txt
+++ b/requirements-examples.txt
@@ -4,4 +4,4 @@ datasets == 3.6.0 # 4.0.0 deprecates trust_remote_code and load scripts. For now
 timm == 1.0.7
 torchsr == 1.0.4
 torchtune >= 0.6.1
-transformers ==4.47.1
+transformers >= 4.52.1