
Commit bb1c68a

Support transformers 4.47 (#1088)
* test 4.47
* update optimum
* patch gemma attn functions
* style
* force attn model
* latest qwen2 vl position_ids formula
* latest qwen2 vl position_ids formula
* revert
1 parent: 753f84d

File tree: 4 files changed, 26 additions & 23 deletions


optimum/exporters/openvino/__main__.py

Lines changed: 1 addition & 1 deletion
@@ -49,7 +49,7 @@
 )
 
 
-FORCE_ATTN_MODEL_CLASSES = {"phi3-v": "eager"}
+FORCE_ATTN_MODEL_CLASSES = {"phi3-v": "eager", "gemma2": "sdpa"}
 
 if TYPE_CHECKING:
     from optimum.intel.openvino.configuration import OVConfig
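
For context, a minimal sketch of how a mapping like FORCE_ATTN_MODEL_CLASSES can be applied when a model is loaded for export. The helper name apply_forced_attn and the loading_kwargs dict are hypothetical illustrations, not the exporter's actual code path:

FORCE_ATTN_MODEL_CLASSES = {"phi3-v": "eager", "gemma2": "sdpa"}


def apply_forced_attn(model_type: str, loading_kwargs: dict) -> dict:
    # Pin the attention implementation only for model types that are
    # explicitly listed; everything else keeps transformers' default choice.
    if model_type in FORCE_ATTN_MODEL_CLASSES:
        loading_kwargs["attn_implementation"] = FORCE_ATTN_MODEL_CLASSES[model_type]
    return loading_kwargs


# e.g. apply_forced_attn("gemma2", {}) -> {"attn_implementation": "sdpa"}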

optimum/exporters/openvino/model_patcher.py

Lines changed: 0 additions & 20 deletions
@@ -2827,26 +2827,6 @@ def patched_forward(*args, **kwargs):
 
         self.patched_forward = patched_forward
 
-    def __enter__(self):
-        super().__enter__()
-        if is_transformers_version(">=", "4.45.0"):
-            from transformers.models.gemma2.modeling_gemma2 import GEMMA2_ATTENTION_CLASSES
-
-            sdpa_attn = GEMMA2_ATTENTION_CLASSES["sdpa"]
-            eager_attn = GEMMA2_ATTENTION_CLASSES["eager"]
-
-            for layer in self._model.model.layers:
-                if isinstance(layer.self_attn, eager_attn):
-                    layer.self_attn._orig_forward = layer.self_attn.forward
-                    layer.self_attn.forward = types.MethodType(sdpa_attn.forward, layer.self_attn)
-
-    def __exit__(self, exc_type, exc_value, traceback):
-        super().__exit__(exc_type, exc_value, traceback)
-        if is_transformers_version(">=", "4.45.0"):
-            for layer in self._model.model.layers:
-                if hasattr(layer.self_attn, "_orig_forward"):
-                    layer.self_attn.forward = layer.self_attn._orig_forward
-
 
 def _decilm_attn_forward(
     self,
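
The deleted __enter__/__exit__ pair implemented a patch-and-restore pattern: rebind each eager Gemma2 attention module's forward to the SDPA implementation on entry, then put the original back on exit. With gemma2 now forced to "sdpa" at export time (the __main__.py change above), that runtime swap appears no longer needed. A standalone sketch of the same pattern, with model, EagerAttention and SdpaAttention as hypothetical stand-ins:

import types


def patch_attn_to_sdpa(model, EagerAttention, SdpaAttention):
    # Rebind forward on each eager attention module, keeping a handle to the
    # original so it can be restored later.
    for layer in model.model.layers:
        if isinstance(layer.self_attn, EagerAttention):
            layer.self_attn._orig_forward = layer.self_attn.forward
            layer.self_attn.forward = types.MethodType(SdpaAttention.forward, layer.self_attn)


def restore_attn(model):
    # Undo the patch: restore the saved forward wherever one was stored.
    for layer in model.model.layers:
        if hasattr(layer.self_attn, "_orig_forward"):
            layer.self_attn.forward = layer.self_attn._orig_forward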

optimum/intel/openvino/modeling_visual_language.py

Lines changed: 24 additions & 1 deletion
@@ -53,7 +53,7 @@
 
 
 if TYPE_CHECKING:
-    from PIL import Image
+    from PIL.Image import Image
 
 
 logger = logging.getLogger(__name__)
@@ -2100,6 +2100,8 @@ def __init__(
             quantization_config=quantization_config,
             **kwargs,
         )
+        self.rope_deltas = None  # cache rope_deltas here
+
         if is_transformers_version(">=", "4.45.0"):
             from transformers.models.qwen2_vl.modeling_qwen2_vl import (
                 Qwen2VLForConditionalGeneration,
@@ -2197,6 +2199,7 @@ def get_multimodal_embeddings(
         pixel_values_videos=None,
         image_grid_thw=None,
         video_grid_thw=None,
+        cache_position=None,
         **kwargs,
     ):
         inputs_embeds = torch.from_numpy(self.get_text_embeddings(input_ids))
@@ -2209,6 +2212,26 @@ def get_multimodal_embeddings(
             video_embeds = torch.from_numpy(self.get_vision_embeddings(pixel_values_videos, video_grid_thw))
             video_mask = input_ids == self.config.video_token_id
             inputs_embeds[video_mask] = video_embeds
+
+        # if we get 4D attention mask we cannot calculate rope deltas anymore.
+        if position_ids is None and input_ids is not None and (attention_mask is None or attention_mask.ndim == 2):
+            # calculate RoPE index once per generation in the pre-fill stage only
+            if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
+                position_ids, rope_deltas = self.get_rope_index(
+                    input_ids, image_grid_thw, video_grid_thw, attention_mask
+                )
+                self.rope_deltas = rope_deltas
+            # then use the prev pre-calculated rope-deltas to get the correct position ids
+            else:
+                batch_size, seq_length, _ = inputs_embeds.shape
+                delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
+                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
+                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
+                if cache_position is not None:  # otherwise `deltas` is an int `0`
+                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
+                position_ids = position_ids.add(delta)
+                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
+
         return inputs_embeds, attention_mask, position_ids
 
     def forward(
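
The added block follows the latest Qwen2-VL position_ids formula: during pre-fill, position_ids and rope_deltas are computed once with get_rope_index and the deltas are cached on the model; on later decode steps, position_ids are rebuilt from cache_position plus the cached deltas and broadcast over the three RoPE axes. A toy illustration of the decode-step arithmetic only, with made-up tensor values and no model involved:

import torch

# Hypothetical decode-step values: one new token per sequence, the KV cache
# already holds 7 tokens, and rope_deltas was cached during pre-fill.
batch_size, seq_length = 2, 1
cache_position = torch.tensor([7])
rope_deltas = torch.tensor([[3], [5]])

delta = cache_position[0] + rope_deltas                      # (batch, 1)
position_ids = torch.arange(seq_length).view(1, -1).expand(batch_size, -1)
delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
position_ids = position_ids.add(delta)                       # (batch, seq)
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)   # (3, batch, seq) RoPE axes
print(position_ids[0])                                       # tensor([[10], [12]])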

setup.py

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@
 INSTALL_REQUIRE = [
     "torch>=1.11",
     "optimum@git+https://github.com/huggingface/optimum.git",
-    "transformers>=4.36,<4.47",
+    "transformers>=4.36,<4.48",
     "datasets>=1.4.0",
     "sentencepiece",
     "setuptools",
