|
| 1 | +# Copyright 2024 The HuggingFace Team. All rights reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | + |
| 16 | +from pathlib import Path |
| 17 | +from typing import Any, Dict, Optional |
| 18 | + |
| 19 | +import torch |
| 20 | +from sentence_transformers import SentenceTransformer |
| 21 | +from sentence_transformers.models import Transformer |
| 22 | +from sentence_transformers.models.Transformer import _save_pretrained_wrapper |
| 23 | +from sentence_transformers.util import import_from_string |
| 24 | +from transformers import MT5Config, T5Config |
| 25 | +from transformers.dynamic_module_utils import get_class_from_dynamic_module |
| 26 | + |
| 27 | +from .modeling_base import IPEXModel |
| 28 | + |
| 29 | + |
class IPEXTransformer(Transformer):
    """Sentence-Transformers ``Transformer`` module whose underlying HF model is
    loaded through :class:`IPEXModel`, so encoding runs on the IPEX backend."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Tag the module so sentence-transformers reports the active backend.
        self.backend = "ipex"

    def _load_model(self, model_name_or_path, config, cache_dir, backend, **model_args) -> None:
        # `backend` is part of the parent signature but is always "ipex" for this
        # subclass, so it is intentionally ignored here.
        self._load_ipex_model(model_name_or_path, config, cache_dir, **model_args)

    def _load_ipex_model(self, model_name_or_path, config, cache_dir, **model_args) -> None:
        """Load the checkpoint as an :class:`IPEXModel` and patch its saving.

        Raises:
            ValueError: if ``config`` belongs to a (m)T5 model, which the IPEX
                backend does not support yet.
        """
        # Idiomatic tuple form instead of two chained isinstance() calls.
        if isinstance(config, (T5Config, MT5Config)):
            raise ValueError("T5 models are not yet supported by the IPEX backend.")

        export = model_args.pop("export", None)
        if export is None:
            # A torchscript checkpoint is already exported; any other checkpoint
            # must be exported now.
            export = not getattr(config, "torchscript", False)

        load_path = Path(model_name_or_path)
        is_local = load_path.exists()

        self.auto_model = IPEXModel.from_pretrained(
            model_name_or_path,
            config=config,
            cache_dir=cache_dir,
            export=export,
            **model_args,
        )

        # Wrap save_pretrained so the exported model lands in the "ipex" subfolder.
        self.auto_model._save_pretrained = _save_pretrained_wrapper(self.auto_model._save_pretrained, "ipex")

        # Nudge the user to persist the exported model so export is not redone
        # on every load.
        if export:
            self._backend_warn_to_save(model_name_or_path, is_local, "IPEX")
| 64 | + |
| 65 | + |
class IPEXSentenceTransformer(SentenceTransformer):
    """``SentenceTransformer`` subclass that substitutes the stock transformer
    module with its IPEX-backed counterpart when resolving module classes."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.backend = "ipex"

    def _load_module_class_from_ref(
        self,
        class_ref: str,
        model_name_or_path: str,
        trust_remote_code: bool,
        revision: Optional[str] = None,
        model_kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.nn.Module:
        """Resolve ``class_ref`` to a module class, redirecting the built-in
        ``Transformer`` module to the IPEX implementation."""
        # Built-in sentence-transformers modules: only the stock Transformer is
        # swapped out; every other built-in module is imported unchanged.
        if class_ref.startswith("sentence_transformers."):
            if class_ref == "sentence_transformers.models.Transformer":
                return import_from_string("optimum.intel.ipex.modeling_sentence_transformers.IPEXTransformer")
            return import_from_string(class_ref)

        if trust_remote_code:
            code_rev = None
            if model_kwargs:
                code_rev = model_kwargs.pop("code_revision", None)
            try:
                return get_class_from_dynamic_module(
                    class_ref,
                    model_name_or_path,
                    revision=revision,
                    code_revision=code_rev,
                )
            except OSError:
                # No matching remote-code file in the repo — fall through to a
                # plain import of the reference.
                pass

        return import_from_string(class_ref)
0 commit comments