Skip to content

Commit 409c89a

Browse files
echarlaix and mvafin authored and committed
Fix auto_model_class for OVModelForVisualCausalLM (#1391)
* fix auto_model_class for OVModelForVisualCausalLM
* fix
* fix style
1 parent 91c3f5b commit 409c89a

File tree

1 file changed

+9
-20
lines changed

1 file changed

+9
-20
lines changed

optimum/intel/openvino/modeling_visual_language.py

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@
1919
from transformers import (
2020
AutoConfig,
2121
AutoImageProcessor,
22-
AutoModelForCausalLM,
23-
AutoModelForVision2Seq,
2422
GenerationConfig,
2523
GenerationMixin,
2624
PretrainedConfig,
@@ -44,15 +42,14 @@
4442
)
4543

4644

47-
try:
48-
from transformers import LlavaForConditionalGeneration
49-
except ImportError:
50-
LlavaForConditionalGeneration = None
45+
if is_transformers_version(">=", "4.46.0"):
46+
from transformers import AutoModelForImageTextToText
5147

52-
try:
53-
from transformers import LlavaNextForConditionalGeneration
54-
except ImportError:
55-
LlavaNextForConditionalGeneration = None
48+
transformers_auto_class = AutoModelForImageTextToText
49+
else:
50+
from transformers import AutoModelForVision2Seq
51+
52+
transformers_auto_class = AutoModelForVision2Seq
5653

5754

5855
if TYPE_CHECKING:
@@ -346,7 +343,7 @@ def forward(self, audio_feature, audio_mask):
346343
class OVModelForVisualCausalLM(OVBaseModel, GenerationMixin):
347344
export_feature = "image-text-to-text"
348345
additional_parts = []
349-
auto_model_class = AutoModelForCausalLM
346+
auto_model_class = transformers_auto_class
350347

351348
def __init__(
352349
self,
@@ -412,10 +409,7 @@ def __init__(
412409

413410
# Avoid warnings when creating a transformers pipeline
414411
AutoConfig.register(self.base_model_prefix, AutoConfig)
415-
try:
416-
self.auto_model_class.register(AutoConfig, self.__class__)
417-
except AttributeError:
418-
pass
412+
self.auto_model_class.register(AutoConfig, self.__class__)
419413

420414
def clear_requests(self):
421415
if self._compile_only:
@@ -931,8 +925,6 @@ def preprocess_inputs(
931925

932926

933927
class _OVLlavaForCausalLM(OVModelForVisualCausalLM):
934-
auto_model_class = LlavaForConditionalGeneration
935-
936928
def __init__(
937929
self,
938930
language_model: ov.Model,
@@ -1137,8 +1129,6 @@ def preprocess_inputs(
11371129

11381130

11391131
class _OVLlavaNextForCausalLM(_OVLlavaForCausalLM):
1140-
auto_model_class = LlavaNextForConditionalGeneration
1141-
11421132
# Adopted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llava_next/modeling_llava_next.py#L655
11431133
def pack_image_features(self, image_features, image_sizes, image_newline=None):
11441134
from transformers.models.llava_next.modeling_llava_next import get_anyres_image_grid_shape, unpad_image
@@ -1433,7 +1423,6 @@ def get_text_embeddings(self, input_ids, **kwargs):
14331423

14341424
class _OVLlavaNextVideoForCausalLM(_OVLlavaNextForCausalLM):
14351425
additional_parts = ["vision_resampler", "multi_modal_projector"]
1436-
auto_model_class = AutoModelForVision2Seq
14371426

14381427
def get_vision_embeddings(self, pixel_values, input_ids=None, **kwargs):
14391428
if input_ids is not None and input_ids.shape[1] == 1:

0 commit comments

Comments (0)