Commit f08e8b7

add support of nanollava model (#969)
1 parent 54a9727 commit f08e8b7

File tree

6 files changed: +200 -13 lines changed

optimum/exporters/openvino/model_configs.py
optimum/exporters/openvino/model_patcher.py
optimum/exporters/openvino/utils.py
optimum/intel/openvino/modeling_visual_language.py
tests/openvino/test_modeling.py
tests/openvino/utils_tests.py

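This commit wires nanoLLaVA (model_type "llava-qwen2") through the OpenVINO export and modeling stack. For context, a minimal usage sketch of the feature, assuming the public OVModelForVisualCausalLM API that this diff extends; the checkpoint id is an assumption, not taken from this commit:

# Hedged usage sketch; the checkpoint id is an assumption, not part of this commit.
from optimum.intel import OVModelForVisualCausalLM

model = OVModelForVisualCausalLM.from_pretrained(
    "qnguyen3/nanoLLaVA",    # assumed upstream nanollava checkpoint
    export=True,             # convert the PyTorch model to OpenVINO on load
    trust_remote_code=True,  # llava-qwen2 ships its modeling code in the model repo
)
# Generation then follows the usual transformers `generate` API with image inputs.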

optimum/exporters/openvino/model_configs.py

Lines changed: 161 additions & 1 deletion

@@ -17,7 +17,7 @@
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

 from packaging import version
-from transformers import PretrainedConfig, PreTrainedModel, TFPreTrainedModel
+from transformers import AutoConfig, PretrainedConfig, PreTrainedModel, TFPreTrainedModel
 from transformers.utils import is_tf_available

 from optimum.exporters.onnx.config import OnnxConfig, TextDecoderOnnxConfig, TextDecoderWithPositionIdsOnnxConfig
@@ -75,6 +75,7 @@
     JaisModelPatcher,
     LlamaModelPatcher,
     LlavaImageEmbeddingModelPatcher,
+    LlavaQwen2ImageEmbeddingsModelPatcher,
     MiniCPMVImageEmbeddingsModelPatcher,
     MiniCPMVResamplerModelPatcher,
     MistralModelPatcher,
@@ -1579,6 +1580,165 @@ def patch_model_for_export(
         return InternVLChatImageEmbeddingModelPatcher(self, model, model_kwargs)


+@register_in_tasks_manager(
+    "llava-qwen2", *["image-text-to-text", "text-generation", "text-generation-with-past"], library_name="transformers"
+)
+class LlavaQwen2OpenVINOConfig(OnnxConfig):
+    SUPPORTS_PAST = True
+    MIN_TRANSFORMERS_VERSION = version.parse("4.40.0")
+    SUPPORTED_BEHAVIORS = [model_type.value for model_type in LlavaConfigBehavior]
+    NORMALIZED_CONFIG_CLASS = NormalizedVisionConfig
+    DUMMY_INPUT_GENERATOR_CLASSES = (DummyVisionInputGenerator,)
+
+    def __init__(
+        self,
+        config: "PretrainedConfig",
+        task: str = "feature-extraction",
+        int_dtype: str = "int64",
+        float_dtype: str = "fp32",
+        behavior: LlavaConfigBehavior = LlavaConfigBehavior.VISION_EMBEDDINGS,
+        preprocessors: Optional[List[Any]] = None,
+        use_past: bool = False,
+    ):
+        self._behavior = behavior
+        self._orig_config = config
+        if self._behavior == LlavaConfigBehavior.VISION_EMBEDDINGS:
+            config = AutoConfig.from_pretrained(config.mm_vision_tower, trust_remote_code=True)
+            if hasattr(config, "vision_config"):
+                config = config.vision_config
+        super().__init__(
+            config=config,
+            task=task,
+            int_dtype=int_dtype,
+            float_dtype=float_dtype,
+            preprocessors=preprocessors,
+        )
+
+    @property
+    def inputs(self) -> Dict[str, Dict[int, str]]:
+        if not self._behavior == LlavaConfigBehavior.VISION_EMBEDDINGS:
+            return {}
+        return {"pixel_values": {0: "batch_size", 2: "height", 3: "width"}}
+
+    @property
+    def outputs(self) -> Dict[str, Dict[int, str]]:
+        if not self._behavior == LlavaConfigBehavior.VISION_EMBEDDINGS:
+            return {}
+        return {"last_hidden_state": {0: "batch_size"}}
+
+    def get_model_for_behavior(self, model, behavior: Union[str, LlavaConfigBehavior]):
+        if isinstance(behavior, str) and not isinstance(behavior, LlavaConfigBehavior):
+            behavior = LlavaConfigBehavior(behavior)
+
+        if behavior == LlavaConfigBehavior.LANGUAGE:
+            model.forward = super(type(model), model).forward
+            return model
+
+        if behavior == LlavaConfigBehavior.VISION_EMBEDDINGS:
+            return model
+
+        if behavior == LlavaConfigBehavior.TEXT_EMBEDDINGS:
+            text_embedding = model.model.embed_tokens
+            text_embedding.config = model.model.config
+            return text_embedding
+
+    def with_behavior(
+        self,
+        behavior: Union[str, LlavaConfigBehavior],
+    ):
+        """
+        Creates a config for different behaviour.
+        Args:
+            behavior ([`ConfigBehavior`]):
+                The behavior to use for the new instance.
+        """
+        if isinstance(behavior, str) and not isinstance(behavior, LlavaConfigBehavior):
+            behavior = LlavaConfigBehavior(behavior)
+
+        if behavior == LlavaConfigBehavior.TEXT_EMBEDDINGS:
+            model_type = self._orig_config.model_type.replace("llava-", "")
+            model_type = model_type.replace("_", "-")
+            if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
+                raise ValueError(
+                    f"Unsupported language model type provided `{model_type}`. Please define custom export config"
+                )
+
+            if "text-generation-with-past" not in TasksManager._SUPPORTED_MODEL_TYPE[model_type]["openvino"]:
+                raise ValueError(
+                    f"Export config for text generation for `{model_type}` is not available. Please define custom export config"
+                )
+            internal_export_config_class = TasksManager._SUPPORTED_MODEL_TYPE[model_type]["openvino"][
+                "text-generation-with-past"
+            ]
+            internal_export_config = internal_export_config_class(
+                self._orig_config,
+                use_past=True,
+                use_past_in_inputs=True,
+                int_dtype=self.int_dtype,
+                float_dtype=self.float_dtype,
+            )
+            InputEmbedOpenvVINOConfig.NORMALIZED_CONFIG_CLASS = internal_export_config.NORMALIZED_CONFIG_CLASS
+            export_config = InputEmbedOpenvVINOConfig(
+                self._orig_config,
+                task="feature-extraction",
+                int_dtype=self.int_dtype,
+                float_dtype=self.float_dtype,
+            )
+            return export_config
+
+        if behavior == LlavaConfigBehavior.LANGUAGE:
+            model_type = self._orig_config.model_type.replace("llava-", "")
+            model_type = model_type.replace("_", "-")
+
+            if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
+                raise ValueError(
+                    f"Unsupported language model type provided `{model_type}`. Please define custom export config"
+                )
+
+            if "text-generation-with-past" not in TasksManager._SUPPORTED_MODEL_TYPE[model_type]["openvino"]:
+                raise ValueError(
+                    f"Export config for text generation for `{model_type}` is not available. Please define custom export config"
+                )
+            internal_export_config_class = TasksManager._SUPPORTED_MODEL_TYPE[model_type]["openvino"][
+                "text-generation-with-past"
+            ]
+            internal_export_config = internal_export_config_class(
+                self._orig_config,
+                use_past=True,
+                use_past_in_inputs=True,
+                int_dtype=self.int_dtype,
+                float_dtype=self.float_dtype,
+            )
+            export_config = LMInputEmbedsConfigHelper(internal_export_config)
+            export_config._normalized_config = internal_export_config._normalized_config
+            return export_config
+
+        if behavior == LlavaConfigBehavior.VISION_EMBEDDINGS:
+            return self.__class__(
+                self._orig_config,
+                task=self.task,
+                int_dtype=self.int_dtype,
+                float_dtype=self.float_dtype,
+                behavior=behavior,
+                preprocessors=self._preprocessors,
+            )
+
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ):
+        model_kwargs = model_kwargs or {}
+        if self._behavior != LlavaConfigBehavior.VISION_EMBEDDINGS:
+            return super().patch_model_for_export(model, model_kwargs)
+        return LlavaQwen2ImageEmbeddingsModelPatcher(self, model, model_kwargs)
+
+    def rename_ambiguous_inputs(self, inputs):
+        if self._behavior == LlavaConfigBehavior.VISION_EMBEDDINGS:
+            model_inputs = {}
+            model_inputs["images"] = inputs["pixel_values"]
+            return model_inputs
+        return super().rename_ambiguous_inputs(inputs)
+
+
 class PooledProjectionsDummyInputGenerator(DummyInputGenerator):
     SUPPORTED_INPUT_NAMES = ["pooled_projections"]
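The config above splits nanoLLaVA into three submodels (vision embeddings, text embeddings, language model) via with_behavior, and rename_ambiguous_inputs maps the exporter's standard pixel_values name onto the vision tower's images argument. A rough sketch of how an export driver could exercise these hooks; names and the loop itself are illustrative, the real driver lives in main_export and the export utilities:

# Illustrative sketch, not the actual export loop: `config` is assumed to be a
# LlavaQwen2OpenVINOConfig and `model` the loaded nanollava PyTorch model.
for behavior in config.SUPPORTED_BEHAVIORS:
    sub_config = config.with_behavior(behavior)            # per-submodel export config
    submodel = config.get_model_for_behavior(model, behavior)
    with sub_config.patch_model_for_export(submodel):      # ModelPatcher is a context manager
        ...  # trace/convert `submodel` using sub_config.inputs and sub_config.outputs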

optimum/exporters/openvino/model_patcher.py

Lines changed: 18 additions & 0 deletions

@@ -2973,3 +2973,21 @@ def __exit__(self, exc_type, exc_value, traceback):
         if is_torch_version(">=", "2.0.0"):
             for layer in self._model.encoder.layers:
                 layer.self_attn.forward = layer.self_attn._orig_forward
+
+
+class LlavaQwen2ImageEmbeddingsModelPatcher(ModelPatcher):
+    def __init__(
+        self,
+        config: "OnnxConfig",
+        model: Union["PreTrainedModel", "TFPreTrainedModel"],
+        model_kwargs: Dict[str, Any],
+    ):
+        model.__orig_forward = model.forward
+        model.forward = model.encode_images
+        super().__init__(config, model, model_kwargs)
+        if not self._model.get_vision_tower().is_loaded:
+            self._model.get_vision_tower().load_model()
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        self._model.forward = self._model.__orig_forward
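The patcher follows the standard ModelPatcher contract: redirect forward to the method that should actually be traced (here encode_images) when the patcher is built, and restore the original in __exit__. A standalone sketch of the same swap-and-restore pattern, free of optimum types:

# Minimal sketch of the forward-swap pattern used above (no optimum dependencies;
# any object with a `forward` attribute and an alternative callable will do).
class ForwardSwap:
    def __init__(self, model, traced_fn):
        self._model = model
        self._orig_forward = model.forward
        model.forward = traced_fn  # e.g. model.encode_images

    def __enter__(self):
        return self._model

    def __exit__(self, exc_type, exc_value, traceback):
        # Always restore, even if tracing raised.
        self._model.forward = self._orig_forward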

optimum/exporters/openvino/utils.py

Lines changed: 1 addition & 1 deletion

@@ -208,4 +208,4 @@ def get_submodels(model):
     return custom_export, fn_get_submodels


-MULTI_MODAL_TEXT_GENERATION_MODELS = ["llava", "llava-next", "internvl-chat", "minicpmv"]
+MULTI_MODAL_TEXT_GENERATION_MODELS = ["llava", "llava-next", "llava-qwen2", "internvl-chat", "minicpmv"]

optimum/intel/openvino/modeling_visual_language.py

Lines changed: 17 additions & 8 deletions

@@ -14,7 +14,7 @@
 from transformers.modeling_outputs import BaseModelOutputWithPooling

 from ...exporters.openvino import main_export
-from ...exporters.openvino.stateful import ensure_stateful_is_available
+from ...exporters.openvino.stateful import ensure_stateful_is_available, model_has_input_output_name
 from .configuration import OVConfig, OVWeightQuantizationConfig
 from .modeling_base import OVBaseModel, OVModelPart
 from .modeling_decoder import CausalLMOutputWithPast, OVModelForCausalLM
@@ -122,8 +122,8 @@ def prepare_inputs(
         else:
             position_ids = np.cumsum(attention_mask, axis=1) - 1
             position_ids[attention_mask == 0] = 1
-            if past_key_values:
-                position_ids = position_ids[:, -input_ids.shape[1] :]
+            if past_len:
+                position_ids = position_ids[:, -inputs_embeds.shape[1] :]

         inputs["position_ids"] = position_ids

@@ -177,9 +177,11 @@ def __init__(self, model: ov.Model, parent_model: OVBaseModel) -> None:
         self.hidden_states_output_names = [
             key.get_any_name() for key in self.model.outputs[2:] if "hidden_states" in key.get_any_name()
         ]
+        self.input_names = {key.get_any_name(): idx for idx, key in enumerate(self.model.inputs)}
+        self._main_input = "images" if model_has_input_output_name(self.model, "images") else "pixel_values"

     def forward(self, pixel_values, **kwargs):
-        inputs = {"pixel_values": pixel_values}
+        inputs = {self._main_input: pixel_values}
         if len(self.input_names) > 1:
             for name in self.input_names:
                 if name in kwargs:
@@ -568,16 +570,19 @@ def half(self):
     def forward(
         self,
         input_ids,
-        pixel_values,
+        pixel_values=None,
         past_key_values=None,
         inputs_embeds=None,
         image_sizes=None,
         attention_mask=None,
         position_ids=None,
         image_bound=None,
         tgt_sizes=None,
+        images=None,
         **kwargs,
     ):
+        if pixel_values is None and images is not None:
+            pixel_values = images
         inputs_embeds, attention_mask, position_ids = self.get_multimodal_embeddings(
             input_ids,
             pixel_values,
@@ -629,6 +634,7 @@ def get_multimodal_embeddings(
         )
         return inputs_embeds, attention_mask, position_ids

+    # Adopted from https://github.com/huggingface/transformers/blob/v4.44.2/src/transformers/models/llava/modeling_llava.py#L521
     def prepare_inputs_for_generation(
         self,
         input_ids,
@@ -646,14 +652,15 @@ def prepare_inputs_for_generation(
             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
             # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
             # input)
-            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
-                input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+            if attention_mask is not None and past_length + 1 > input_ids.shape[1]:
+                input_discount = max(attention_mask.shape[1] - past_length, 1)
+                input_ids = input_ids[:, -input_discount:]
             # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
             # input_ids based on the past_length.llava
             elif past_length < input_ids.shape[1]:
                 input_ids = input_ids[:, past_length:]
             # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
-            elif getattr(self.config, "image_token_index", None) in input_ids:
+            elif getattr(self.config, "image_token_index", -1) in input_ids:
                 input_ids = input_ids[:, input_ids.shape[1] - 1 :]

         position_ids = kwargs.get("position_ids", None)
@@ -679,6 +686,7 @@ def prepare_inputs_for_generation(
                 "image_sizes": image_sizes,
                 "image_bound": kwargs.get("image_bound"),
                 "tgt_sizes": kwargs.get("tgt_sizes"),
+                "images": kwargs.get("images"),
             }
         )
         return model_inputs
@@ -1546,4 +1554,5 @@ def get_multimodal_embeddings(
     "llava_next": _OVLlavaNextForCausalLM,
     "internvl_chat": _OvInternVLForCausalLM,
     "minicpmv": _OVMiniCPMVForCausalLM,
+    "llava-qwen2": _OVNanoLlavaForCausalLM,
 }
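The runtime half mirrors the exporter's input rename: nanoLLaVA's compiled vision model exposes images rather than pixel_values, so the wrapper picks its main input by name at load time. A sketch of that detection with the plain openvino API, approximating what model_has_input_output_name checks; ov_model is assumed to be an already-loaded openvino.Model:

# Sketch of the input-name probe; `ov_model` is an assumed openvino.Model instance.
import openvino as ov

def has_input_name(ov_model: ov.Model, name: str) -> bool:
    # Each model input can carry several tensor names; check them all.
    return any(name in inp.get_names() for inp in ov_model.inputs)

main_input = "images" if has_input_name(ov_model, "images") else "pixel_values"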

tests/openvino/test_modeling.py

Lines changed: 2 additions & 3 deletions

@@ -1879,12 +1879,11 @@ def test_compare_with_and_without_past_key_values(self):
 class OVModelForVisualCausalLMIntegrationTest(unittest.TestCase):
     SUPPORTED_ARCHITECTURES = ["llava"]

-    REMOTE_CODE_MODELS = ["minicpmv"]
-
     if is_transformers_version(">=", "4.40.0"):
-        SUPPORTED_ARCHITECTURES += ["llava_next"]
+        SUPPORTED_ARCHITECTURES += ["llava_next", "nanollava"]
     if is_transformers_version(">=", "4.45.0"):
         SUPPORTED_ARCHITECTURES += ["minicpmv"]
+    REMOTE_CODE_MODELS = ["minicpmv", "nanollava"]
     TASK = "image-text-to-text"

     IMAGE = Image.open(

tests/openvino/utils_tests.py

Lines changed: 1 addition & 0 deletions

@@ -95,6 +95,7 @@
     "mpt": "hf-internal-testing/tiny-random-MptForCausalLM",
     "mpnet": "hf-internal-testing/tiny-random-MPNetModel",
     "mt5": "stas/mt5-tiny-random",
+    "nanollava": "katuni4ka/tiny-random-nanollava",
     "nystromformer": "hf-internal-testing/tiny-random-NystromformerModel",
     "olmo": "katuni4ka/tiny-random-olmo-hf",
     "orion": "katuni4ka/tiny-random-orion",
