|
21 | 21 | from pathlib import Path |
22 | 22 | from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union |
23 | 23 |
|
| 24 | +from packaging.version import Version |
24 | 25 | from transformers.generation import GenerationMixin |
25 | 26 | from transformers.models.speecht5.modeling_speecht5 import SpeechT5HifiGan |
26 | 27 | from transformers.utils import is_tf_available, is_torch_available |
@@ -187,6 +188,27 @@ def export( |
187 | 188 | if "diffusers" in str(model.__class__) and not is_diffusers_available(): |
188 | 189 | raise ImportError("The package `diffusers` is required to export diffusion models to OpenVINO.") |
189 | 190 |
|
| 191 | + min_version = getattr(config, "MIN_TRANSFORMERS_VERSION", None) |
| 192 | + max_version = getattr(config, "MAX_TRANSFORMERS_VERSION", None) |
| 193 | + |
| 194 | + if min_version is not None: |
| 195 | + if isinstance(min_version, Version): |
| 196 | + min_version = min_version.base_version |
| 197 | + if is_transformers_version("<", min_version): |
| 198 | + raise ValueError( |
| 199 | + f"The current version of Transformers does not allow for the export of the model. Minimum required is " |
| 200 | + f"{config.MIN_TRANSFORMERS_VERSION}, got: {_transformers_version}" |
| 201 | + ) |
| 202 | + |
| 203 | + if max_version is not None: |
| 204 | + if isinstance(max_version, Version): |
| 205 | + max_version = max_version.base_version |
| 206 | + if is_transformers_version(">=", max_version): |
| 207 | + raise ValueError( |
| 208 | + f"The current version of Transformers does not allow for the export of the model. Maximum required is " |
| 209 | + f"{config.MAX_TRANSFORMERS_VERSION}, got: {_transformers_version}" |
| 210 | + ) |
| 211 | + |
190 | 212 | if stateful: |
191 | 213 | # This will be checked anyway after the model conversion, but checking it earlier will save time for a user if not suitable version is used |
192 | 214 | stateful = ensure_stateful_is_available() |
@@ -633,7 +655,11 @@ def export_from_model( |
633 | 655 | ensure_export_task_support_stateful(task) or ensure_model_type_support_stateful(model_type) |
634 | 656 | ) |
635 | 657 |
|
636 | | - if stateful and is_encoder_decoder and not getattr(model, "_supports_cache_class", False): |
| 658 | + if ( |
| 659 | + stateful |
| 660 | + and is_encoder_decoder |
| 661 | + and not getattr(model, "_supports_cache_class", is_transformers_version(">=", "4.54")) |
| 662 | + ): |
637 | 663 | stateful = False |
638 | 664 | # TODO: support onnx_config.py in the model repo |
639 | 665 | if custom_architecture and custom_export_configs is None: |
|
0 commit comments