-
Notifications
You must be signed in to change notification settings - Fork 270
Open
Description
I am trying to do an inference using Llava Next
here is my code:
import habana_frameworks.torch as ht
import habana_frameworks.torch.core as htcore
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
from transformers import LlavaNextProcessor, AutoProcessor, AutoConfig, LlavaNextForConditionalGeneration
import torch
from PIL import Image
import requests
import os
from optimum.habana.transformers.models.llava_next import GaudiLlavaNextForConditionalGeneration
adapt_transformers_to_gaudi()
device = torch.device("hpu")
args_model_name_or_path = "/workspace/models/model_llava_v1_6_vicuna_7b"
model_type = AutoConfig.from_pretrained(args_model_name_or_path).model_type
print("Loading the processor")
processor = AutoProcessor.from_pretrained(args_model_name_or_path)
print("Loading the model")
model = LlavaNextForConditionalGeneration.from_pretrained(args_model_name_or_path,
torch_dtype=torch.bfloat16,
)
model.to("hpu")
print("hpu graph")
# prepare image and text prompt, using the appropriate prompt template
url = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
image = Image.open(requests.get(url, stream=True).raw)
conversation = [
{
"role": "user",
"content": [
{"type": "text", "text": "What is shown in this image?"},
{"type": "image"},
],
},
]
print("preparing prompt")
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
print("Prompt goes through processor")
model_dtype = torch.bfloat16
inputs = processor(images=image, text=prompt, return_tensors="pt").to("hpu", model_dtype)
# autoregressively complete prompt
print("generating output")
output = model.generate(**inputs, max_new_tokens=100)
print("printing the final result")
print(processor.decode(output[0], skip_special_tokens=True))
While running this, I get the following error:
TypeError: LlavaNextForConditionalGeneration.forward() got an unexpected keyword argument 'token_idx'
Any idea on what is causing this problem? Thanks in advance
Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
No labels