Skip to content

LlavaNextForConditionalGeneration.forward() got an unexpected keyword argument 'token_idx'Β #1708

@DavidAbrahamyan

Description

@DavidAbrahamyan

I am trying to do an inference using Llava Next
here is my code:

import habana_frameworks.torch as ht
import habana_frameworks.torch.core as htcore

from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi

from transformers import LlavaNextProcessor, AutoProcessor, AutoConfig, LlavaNextForConditionalGeneration
import torch
from PIL import Image
import requests
import os

from optimum.habana.transformers.models.llava_next import GaudiLlavaNextForConditionalGeneration
adapt_transformers_to_gaudi()
device = torch.device("hpu")
args_model_name_or_path = "/workspace/models/model_llava_v1_6_vicuna_7b"
model_type = AutoConfig.from_pretrained(args_model_name_or_path).model_type

print("Loading the processor")
processor = AutoProcessor.from_pretrained(args_model_name_or_path)

print("Loading the model")
model = LlavaNextForConditionalGeneration.from_pretrained(args_model_name_or_path,
    torch_dtype=torch.bfloat16,
    ) 
model.to("hpu")

print("hpu graph")

# prepare image and text prompt, using the appropriate prompt template
url = "https://github.com/haotian-liu/LLaVA/blob/1a91fc274d7c35a9b50b3cb29c4247ae5837ce39/images/llava_v1_5_radar.jpg?raw=true"
image = Image.open(requests.get(url, stream=True).raw)

conversation = [
    {

      "role": "user",
      "content": [
          {"type": "text", "text": "What is shown in this image?"},
          {"type": "image"},
        ],
    },
]
print("preparing prompt")
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

print("Prompt goes through processor")
model_dtype = torch.bfloat16
inputs = processor(images=image, text=prompt, return_tensors="pt").to("hpu", model_dtype)

# autoregressively complete prompt
print("generating output")
output = model.generate(**inputs, max_new_tokens=100)

print("printing the final result")
print(processor.decode(output[0], skip_special_tokens=True))

While running this, I get the following error:
TypeError: LlavaNextForConditionalGeneration.forward() got an unexpected keyword argument 'token_idx'

Any idea on what is causing this problem? Thanks in advance

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions