18 | 18 | import math |
19 | 19 | import types |
20 | 20 | from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union |
21 | | -from optimum.exporters.onnx.base import OnnxConfig |
22 | 21 |
23 | 22 | import torch |
24 | 23 | import torch.nn.functional as F |
25 | 24 | from transformers import PreTrainedModel, TFPreTrainedModel |
26 | 25 | from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling |
27 | 26 | from transformers.utils import is_tf_available |
28 | 27 |
| 28 | +from optimum.exporters.onnx.base import OnnxConfig |
29 | 29 | from optimum.exporters.onnx.model_patcher import DecoderModelPatcher, ModelPatcher, override_arguments |
30 | 30 | from optimum.intel.utils.import_utils import ( |
31 | 31 | _openvino_version, |
@@ -423,9 +423,9 @@ def _llama_gemma_update_causal_mask_legacy(self, attention_mask, input_tensor, c |
423 | 423 | offset = 0 |
424 | 424 | mask_shape = attention_mask.shape |
425 | 425 | mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype |
426 | | - causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = ( |
427 | | - mask_slice |
428 | | - ) |
| 426 | + causal_mask[ |
| 427 | + : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3] |
| 428 | + ] = mask_slice |
429 | 429 |
430 | 430 | if ( |
431 | 431 | self.config._attn_implementation == "sdpa" |
@@ -2060,9 +2060,9 @@ def _dbrx_update_causal_mask_legacy( |
2060 | 2060 | offset = 0 |
2061 | 2061 | mask_shape = attention_mask.shape |
2062 | 2062 | mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype |
2063 | | - causal_mask[: mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]] = ( |
2064 | | - mask_slice |
2065 | | - ) |
| 2063 | + causal_mask[ |
| 2064 | + : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3] |
| 2065 | + ] = mask_slice |
2066 | 2066 |
2067 | 2067 | if ( |
2068 | 2068 | self.config._attn_implementation == "sdpa" |
@@ -3386,10 +3386,9 @@ class Qwen2VLLanguageModelPatcher(DecoderModelPatcher): |
3386 | 3386 | def __init__( |
3387 | 3387 | self, |
3388 | 3388 | config: OnnxConfig, |
3389 | | - model: PreTrainedModel | TFPreTrainedModel, |
3390 | | - model_kwargs: Dict[str, Any] | None = None, |
| 3389 | + model: Union[PreTrainedModel, TFPreTrainedModel], |
| 3390 | + model_kwargs: Optional[Dict[str, Any]] = None, |
3391 | 3391 | ): |
3392 | | - |
3393 | 3392 | model.__orig_forward = model.forward |
3394 | 3393 |
3395 | 3394 | def forward_wrap( |
@@ -3426,8 +3425,8 @@ class Qwen2VLVisionEmbMergerPatcher(ModelPatcher): |
3426 | 3425 | def __init__( |
3427 | 3426 | self, |
3428 | 3427 | config: OnnxConfig, |
3429 | | - model: PreTrainedModel | TFPreTrainedModel, |
3430 | | - model_kwargs: Dict[str, Any] | None = None, |
| 3428 | + model: Union[PreTrainedModel, TFPreTrainedModel], |
| 3429 | + model_kwargs: Optional[Dict[str, Any]] = None, |
3431 | 3430 | ): |
3432 | 3431 | model.__orig_forward = model.forward |
3433 | 3432 |