examples/models/llava/export_llava.py (2 changes: 1 addition & 1 deletion)
@@ -186,7 +186,7 @@ def quant_embedding(model):
             packed=False,
         ).quantized_model()
 
-    quantized_token_embed = quant_embedding(llava.model_.language_model.model)
+    quantized_token_embed = quant_embedding(llava.model_.model.language_model)
     token_dim_1 = Dim("token_dim_1", min=2, max=llava.text_model_args.max_seq_len)
    dynamic_shapes = [{1: token_dim_1}]
    with torch.no_grad():
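For context on the one-line change above: in older transformers releases the Llama backbone that owns the token embedding lived at `model_.language_model.model`, while transformers >= 4.52 exposes it at `model_.model.language_model`. A small, hedged sketch of how a caller could resolve the right submodule across both layouts (the helper and its fallback logic are illustrative, not part of this PR, which simply hard-codes the new path):

```python
import torch.nn as nn


def resolve_text_backbone(hf_llava: nn.Module) -> nn.Module:
    """Return the decoder backbone that owns `embed_tokens`.

    Illustrative helper only: tries the transformers >= 4.52 layout first,
    then falls back to the pre-4.52 layout.
    """
    inner = getattr(hf_llava, "model", None)
    if inner is not None and hasattr(inner, "language_model"):
        return inner.language_model  # new layout: LlavaModel.language_model
    return hf_llava.language_model.model  # old layout: LlamaForCausalLM.model
```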
examples/models/llava/model.py (88 changes: 53 additions & 35 deletions)
@@ -31,7 +31,6 @@
 from transformers import (
     AutoProcessor,
     CLIPImageProcessor,
-    LlamaForCausalLM,
     LlavaForConditionalGeneration,
 )
 
@@ -104,19 +103,19 @@ def __init__(
 
     def _translate_state_dict_for_text_model(self) -> Dict[str, Any]:
         # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_pt_objects.LlavaForConditionalGeneration` has no attribute `language_model`.
-        state_dict = self.model_.language_model.state_dict()
+        state_dict = self.model_.state_dict()
         key_map = {
             # fmt: off
-            r"model.layers.([0-9]+).self_attn.q_proj.": r"layers.\1.attention.wq.",
-            r"model.layers.([0-9]+).self_attn.k_proj.": r"layers.\1.attention.wk.",
-            r"model.layers.([0-9]+).self_attn.v_proj.": r"layers.\1.attention.wv.",
-            r"model.layers.([0-9]+).self_attn.o_proj.": r"layers.\1.attention.wo.",
-            r"model.layers.([0-9]+).input_layernorm.": r"layers.\1.attention_norm.",
-            r"model.layers.([0-9]+).mlp.gate_proj.": r"layers.\1.feed_forward.w1.",
-            r"model.layers.([0-9]+).mlp.down_proj.": r"layers.\1.feed_forward.w2.",
-            r"model.layers.([0-9]+).mlp.up_proj.": r"layers.\1.feed_forward.w3.",
-            r"model.layers.([0-9]+).post_attention_layernorm.": r"layers.\1.ffn_norm.",
-            r"model.norm.": r"norm.",
+            r"model.language_model.layers.([0-9]+).self_attn.q_proj.": r"layers.\1.attention.wq.",
+            r"model.language_model.layers.([0-9]+).self_attn.k_proj.": r"layers.\1.attention.wk.",
+            r"model.language_model.layers.([0-9]+).self_attn.v_proj.": r"layers.\1.attention.wv.",
+            r"model.language_model.layers.([0-9]+).self_attn.o_proj.": r"layers.\1.attention.wo.",
+            r"model.language_model.layers.([0-9]+).input_layernorm.": r"layers.\1.attention_norm.",
+            r"model.language_model.layers.([0-9]+).mlp.gate_proj.": r"layers.\1.feed_forward.w1.",
+            r"model.language_model.layers.([0-9]+).mlp.down_proj.": r"layers.\1.feed_forward.w2.",
+            r"model.language_model.layers.([0-9]+).mlp.up_proj.": r"layers.\1.feed_forward.w3.",
+            r"model.language_model.layers.([0-9]+).post_attention_layernorm.": r"layers.\1.ffn_norm.",
+            r"model.language_model.norm.": r"norm.",
             # r"model.embed_tokens.": r"tok_embeddings.", # load separately
             r"lm_head.": r"output.",
             # fmt: on
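The key map now matches checkpoint keys of the full LlavaForConditionalGeneration state dict (prefixed with `model.language_model.`) instead of keys taken from the standalone language model. As a reminder of how such a regex map is applied, here is a minimal sketch; the helper name and the skip-unmatched behaviour are assumptions, not necessarily what the rest of `_translate_state_dict_for_text_model` does:

```python
import re
from typing import Any, Dict


def remap_keys(state_dict: Dict[str, Any], key_map: Dict[str, str]) -> Dict[str, Any]:
    """Rewrite each checkpoint key with the first matching regex pattern."""
    remapped: Dict[str, Any] = {}
    for old_key, value in state_dict.items():
        new_key = old_key
        for pattern, replacement in key_map.items():
            if re.match(pattern, old_key):
                new_key = re.sub(pattern, replacement, old_key)
                break
        remapped[new_key] = value
    return remapped


# e.g. "model.language_model.layers.0.self_attn.q_proj.weight"
#  ->  "layers.0.attention.wq.weight"
```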
@@ -157,7 +156,7 @@ def get_model(self):
 
     def embed_tokens(self, tokens: torch.Tensor) -> torch.Tensor:
         # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_pt_objects.LlavaForConditionalGeneration` has no attribute `language_model`.
-        return self.model_.language_model.model.embed_tokens(tokens)
+        return self.model_.language_model.embed_tokens(tokens)
 
     def encode_images(self, images: torch.Tensor) -> torch.Tensor:
         # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_pt_objects.LlavaForConditionalGeneration` has no attribute `dtype`.
@@ -289,13 +288,8 @@ def prefill_ref(
         """Avoiding the torch.where() call to find <image> placeholder and insert image embedding. Taking 3 inputs instead."""
         embeds = self.prefill_embedding(prompt_before_image, images, prompt_after_image)
         # pyre-ignore: Undefined attribute [16]: Module `transformers` has no attribute `LlamaForCausalLM`.
-        return LlamaForCausalLM.forward(
-            # pyre-ignore: Undefined attribute [16]: `transformers.utils.dummy_pt_objects.LlavaForConditionalGeneration` has no attribute `language_model`.
-            self.model_.language_model,
-            inputs_embeds=embeds,
-            return_dict=False,
-            use_cache=False,
-            output_hidden_states=False,
+        return self.model_.forward(
+            inputs_embeds=embeds, use_cache=False, return_dict=False, logits_to_keep=1
         )
 
     def forward(
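The rewritten `prefill_ref` calls the full model's forward instead of `LlamaForCausalLM.forward` on the inner language model, and passes `logits_to_keep=1` so only the last position's logits are materialized. The sketch below illustrates the shape effect that makes the old `[:, -1, :]` slice unnecessary; the tensor sizes are made up, and the behaviour of `logits_to_keep` is described as I understand recent transformers releases:

```python
import torch

# Illustrative shapes: batch=1, seq_len=7, vocab=32000.
full_logits = torch.randn(1, 7, 32000)  # logits for every position (old behaviour)
last_only = full_logits[:, -1:, :]      # shape (1, 1, 32000), roughly what
                                        # logits_to_keep=1 leaves you with

# The old test sliced out the last position; the new output already is that position.
assert torch.equal(full_logits[:, -1, :], last_only.squeeze(0))
```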
@@ -309,25 +303,42 @@ class LlavaModel(EagerModelBase):
     def __init__(self, use_sdpa_with_kv_cache_op=True, max_seq_len=768):
         self.use_sdpa_with_kv_cache_op = use_sdpa_with_kv_cache_op
         self.max_seq_len = max_seq_len
-        self.processor = AutoProcessor.from_pretrained(
-            "llava-hf/llava-1.5-7b-hf",
-            revision="a272c74b2481d8aff3aa6fc2c4bf891fe57334fb", # Need this for transformers >= 4.44.2
-        )
-        self.tokenizer = self.processor.tokenizer
-        self.image_processor = self.processor.image_processor
         self.model = LlavaForConditionalGeneration.from_pretrained(
             "llava-hf/llava-1.5-7b-hf",
             device_map="cpu",
             revision="a272c74b2481d8aff3aa6fc2c4bf891fe57334fb", # Need this for transformers >= 4.44.2
         )
-        self.image = Image.open(
-            requests.get(
-                "https://llava-vl.github.io/static/images/view.jpg", stream=True
-            ).raw
+        self.processor = AutoProcessor.from_pretrained(
+            "llava-hf/llava-1.5-7b-hf",
+            revision="a272c74b2481d8aff3aa6fc2c4bf891fe57334fb", # Need this for transformers >= 4.44.2
+            patch_size=self.model.vision_tower.config.patch_size, # Required after transformers >= 4.52.0
         )
-        self.prompt = """A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. USER: <image>
-What are the things I should be cautious about when I visit here? ASSISTANT:"""
+        self.tokenizer = self.processor.tokenizer
+        self.image_processor = self.processor.image_processor
+        self.image_url = "https://llava-vl.github.io/static/images/view.jpg"
+        self.image = Image.open(requests.get(self.image_url, stream=True).raw)
+        self.system_prompt = """A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions. """
+        current_template = self.processor.chat_template
+        # Prepend the system prompt to the template
+        new_template = self.system_prompt + current_template
+
+        # Set the modified template back to the tokenizer
+        self.processor.chat_template = new_template
+
+        self.model_name = "llava-1.5-7b-hf"
+
+        self.conversation = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image", "url": self.image_url},
+                    {
+                        "type": "text",
+                        "text": "What are the things I should be cautious about when I visit here?",
+                    },
+                ],
+            },
+        ]
         # set input to None and initialize them lazily
         self.input = None
         self.resized_image = None
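The eager model now builds its prompt from a structured conversation plus a chat template with the system prompt prepended, rather than a hand-written string containing a literal `<image>` marker. A short usage sketch, assuming the `LlavaModel` class defined in this file (rendering only, no tokenization; the exact text depends on the checkpoint's chat template, and constructing the model downloads the checkpoint):

```python
# Hedged sketch: inspect the rendered prompt without tokenizing.
llava_model = LlavaModel()
rendered = llava_model.processor.apply_chat_template(
    llava_model.conversation,
    add_generation_prompt=True,
    tokenize=False,
)
print(rendered)
# Expected to start with the prepended system prompt and to contain the image
# placeholder followed by the user question.
```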
@@ -358,11 +369,18 @@ def get_inputs_for_prefill(self):
         """Returns prompts as well as image."""
         if self.input:
             return self.input
-        self.input_ids = self.tokenizer.encode(self.prompt, return_tensors="pt").cpu()
+        inputs = self.processor.apply_chat_template(
+            self.conversation,
+            add_generation_prompt=True,
+            tokenize=True,
+            return_dict=True,
+            return_tensors="pt",
+        )
+        self.input_ids = inputs["input_ids"]
         index = torch.where(self.input_ids == self.model.config.image_token_index)[1]
-        self.prompt_before_image = self.input_ids[:, :index]
+        self.prompt_before_image = self.input_ids[:, : index[0]]
         # print(prompt_before_image.shape)
-        self.prompt_after_image = self.input_ids[:, index + 1 :]
+        self.prompt_after_image = self.input_ids[:, index[-1] + 1 :]
         # print(prompt_after_image.shape)
         self.input = (
             self.prompt_before_image,
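With `apply_chat_template`, the processor may expand the `<image>` placeholder into a run of image tokens (the count depends on the configured patch size), so `torch.where` can return several indices; the new slices therefore take everything before the first image token and everything after the last one. A self-contained sketch of that splitting logic with made-up token ids:

```python
import torch

IMAGE_TOKEN_ID = 32000  # illustrative value, not the real checkpoint's id

# Pretend tokenized prompt: text, then a run of image tokens, then text.
input_ids = torch.tensor([[1, 3232, 889, 32000, 32000, 32000, 1724, 526, 2]])
index = torch.where(input_ids == IMAGE_TOKEN_ID)[1]

prompt_before_image = input_ids[:, : index[0]]      # tokens before the first image token
prompt_after_image = input_ids[:, index[-1] + 1 :]  # tokens after the last image token

assert prompt_before_image.tolist() == [[1, 3232, 889]]
assert prompt_after_image.tolist() == [[1724, 526, 2]]
```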
examples/models/llava/test/test_llava.py (5 changes: 3 additions & 2 deletions)
@@ -41,8 +41,9 @@ def test_prefill_logits(self):
         # The reference implementation in HF generates the full logits. Get the last one.
         prefill_logits_ref = self.llava.prefill_ref(
             self.prompt_before_image, self.resized, self.prompt_after_image
-        )[0][:, -1, :]
-        self.assertTrue(torch.allclose(prefill_logits, prefill_logits_ref, atol=3e-2))
+        )[0]
+
+        torch.testing.assert_close(prefill_logits, prefill_logits_ref.squeeze(0))
 
     def test_generated_output(self):
         # source of truth, using HF llava
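The prefill test now keeps the whole reference logits tensor (already reduced to the last position by `logits_to_keep=1`), squeezes the batch dimension, and compares with `torch.testing.assert_close` instead of `assertTrue(torch.allclose(..., atol=3e-2))`. Without explicit tolerances, `assert_close` uses dtype-based defaults and, unlike a bare boolean check, reports which elements differ and by how much. A tiny illustration with made-up tensors:

```python
import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.0, 2.0, 3.1])

# allclose only yields a boolean, so a failing assertTrue says nothing useful:
print(torch.allclose(a, b))  # False

# assert_close raises with the number of mismatched elements and the greatest
# absolute/relative difference, which is easier to debug:
try:
    torch.testing.assert_close(a, b)
except AssertionError as e:
    print(e)
```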
requirements-examples.txt (2 changes: 1 addition & 1 deletion)
@@ -4,4 +4,4 @@ datasets == 3.6.0 # 4.0.0 deprecates trust_remote_code and load scripts. For now
 timm == 1.0.7
 torchsr == 1.0.4
 torchtune >= 0.6.1
-transformers ==4.47.1
+transformers >= 4.52.1
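The pin is relaxed from `transformers ==4.47.1` to `transformers >= 4.52.1`, matching the new module paths and the `patch_size` processor argument used above. If a script wanted to fail fast on an older install, a guard along these lines would work (this check is not part of the PR; `packaging` is assumed to be available, as it is pulled in by transformers' own dependencies):

```python
from importlib.metadata import version

from packaging.version import Version

MIN_TRANSFORMERS = Version("4.52.1")
installed = Version(version("transformers"))
if installed < MIN_TRANSFORMERS:
    raise RuntimeError(
        f"transformers {installed} found, but the llava example needs >= {MIN_TRANSFORMERS}"
    )
```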