|
33 | 33 | from ..feature_extraction_utils import PreTrainedFeatureExtractor |
34 | 34 | from ..image_processing_utils import BaseImageProcessor |
35 | 35 | from ..modelcard import ModelCard |
36 | | -from ..models.auto.configuration_auto import AutoConfig |
| 36 | +from ..models.auto import AutoConfig, AutoTokenizer |
37 | 37 | from ..processing_utils import ProcessorMixin |
38 | 38 | from ..tokenization_utils import PreTrainedTokenizer |
39 | 39 | from ..utils import ( |
@@ -425,6 +425,62 @@ def get_default_model_and_revision( |
425 | 425 | return default_models[framework] |
426 | 426 |
427 | 427 |
| 428 | +def load_assistant_model( |
| 429 | + model: "PreTrainedModel", |
| 430 | + assistant_model: Optional[Union[str, "PreTrainedModel"]], |
| 431 | + assistant_tokenizer: Optional[PreTrainedTokenizer], |
| 432 | +) -> Tuple[Optional["PreTrainedModel"], Optional[PreTrainedTokenizer]]: |
| 433 | + """ |
| 434 | +    Prepares the assistant model and the assistant tokenizer for a pipeline whose model can call `generate`.
| 435 | +
| 436 | + Args: |
| 437 | + model ([`PreTrainedModel`]): |
| 438 | + The main model that will be used by the pipeline to make predictions. |
| 439 | + assistant_model (`str` or [`PreTrainedModel`], *optional*): |
| 440 | +            The assistant model used for assisted generation, or a checkpoint name from which to load it.
| 441 | + assistant_tokenizer ([`PreTrainedTokenizer`], *optional*): |
| 442 | +            The assistant model's tokenizer, required when it differs from the main model's tokenizer.
| 443 | +
| 444 | + Returns: |
| 445 | +        Tuple: The loaded assistant model and the assistant tokenizer, or `None` for whichever is not needed.
| 446 | + """ |
| 447 | + if not model.can_generate() or assistant_model is None: |
| 448 | + return None, None |
| 449 | + |
| 450 | + if not isinstance(model, PreTrainedModel): |
| 451 | + raise ValueError( |
| 452 | +            "Assisted generation, triggered by the `assistant_model` argument, is only available for "
| 453 | +            "PyTorch `PreTrainedModel` instances; TF and JAX models are not supported."
| 454 | + ) |
| 455 | + |
| 456 | +    # If the assistant model is passed as a string, load both it and its tokenizer
| 457 | + if isinstance(assistant_model, str): |
| 458 | + assistant_config = AutoConfig.from_pretrained(assistant_model) |
| 459 | + _, loaded_assistant_model = infer_framework_load_model(assistant_model, config=assistant_config) |
| 460 | + loaded_assistant_model = loaded_assistant_model.to(device=model.device, dtype=model.dtype) |
| 461 | + loaded_assistant_tokenizer = AutoTokenizer.from_pretrained(assistant_model) |
| 462 | + else: |
| 463 | + loaded_assistant_model = assistant_model |
| 464 | + loaded_assistant_tokenizer = assistant_tokenizer |
| 465 | + |
| 466 | + # Finally, let's check the tokenizers: if the two models have different tokenizers, we need to keep the assistant |
| 467 | + # tokenizer |
| 468 | + same_vocab_size = model.config.vocab_size == loaded_assistant_model.config.vocab_size |
| 469 | + same_special_tokens = all( |
| 470 | + getattr(model.config, token) == getattr(loaded_assistant_model.config, token) |
| 471 | + for token in ("eos_token_id", "pad_token_id", "bos_token_id") |
| 472 | + ) |
| 473 | + if same_vocab_size and same_special_tokens: |
| 474 | + loaded_assistant_tokenizer = None |
| 475 | + elif loaded_assistant_tokenizer is None: |
| 476 | + raise ValueError( |
| 477 | + "The assistant model has a different tokenizer than the main model. You should pass the assistant " |
| 478 | + "tokenizer." |
| 479 | + ) |
| 480 | + |
| 481 | + return loaded_assistant_model, loaded_assistant_tokenizer |
| 482 | + |
| 483 | + |
428 | 484 | class PipelineException(Exception): |
429 | 485 | """ |
430 | 486 | Raised by a [`Pipeline`] when handling __call__. |
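
Taken together, the branches above mean the helper returns `(None, None)` when no assistant applies, and drops the assistant tokenizer whenever the two models share a vocabulary and special tokens. A minimal sketch of calling it directly, assuming the helper is importable from `transformers.pipelines.base` and using `gpt2`/`gpt2-large` purely as an illustrative, vocabulary-compatible pair:

```python
from transformers import AutoModelForCausalLM
from transformers.pipelines.base import load_assistant_model  # assumed import path

# The main model must be a PyTorch model that can call `generate`.
main_model = AutoModelForCausalLM.from_pretrained("gpt2-large")

# Passing a checkpoint name: the helper loads the assistant onto the main
# model's device/dtype and fetches the assistant tokenizer alongside it.
assistant, assistant_tok = load_assistant_model(main_model, "gpt2", None)

# gpt2 and gpt2-large share vocab_size and special token ids, so the assistant
# tokenizer is discarded and the main tokenizer is reused at generation time.
assert assistant_tok is None

# No assistant requested -> nothing to load.
assert load_assistant_model(main_model, None, None) == (None, None)
```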
@@ -925,8 +981,13 @@ def __init__( |
925 | 981 | ): |
926 | 982 | self.model.to(self.device) |
927 | 983 |
928 | | - # If the model can generate, create a local generation config. This is done to avoid side-effects on the model |
929 | | - # as we apply local tweaks to the generation config. |
| 984 | + # If the model can generate: |
| 985 | + # 1 - create a local generation config. This is done to avoid side-effects on the model as we apply local |
| 986 | + # tweaks to the generation config. |
| 987 | + # 2 - load the assistant model if it is passed. |
| 988 | + self.assistant_model, self.assistant_tokenizer = load_assistant_model( |
| 989 | + self.model, kwargs.pop("assistant_model", None), kwargs.pop("assistant_tokenizer", None) |
| 990 | + ) |
930 | 991 | if self.model.can_generate(): |
931 | 992 | self.prefix = self.model.config.prefix if hasattr(self.model.config, "prefix") else None |
932 | 993 | self.generation_config = copy.deepcopy(self.model.generation_config) |
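
With loading wired into `Pipeline.__init__`, the two new keyword arguments flow straight through the `pipeline()` factory. A hedged usage sketch, assuming the rest of this change forwards `self.assistant_model` and `self.assistant_tokenizer` to `generate`; checkpoint names (`gpt2`, `gpt2-large`) are only illustrative:

```python
from transformers import AutoModelForCausalLM, pipeline

# Same tokenizer family: `assistant_model` alone is enough; the smaller model
# drafts tokens that gpt2-large then verifies (assisted generation).
pipe = pipeline("text-generation", model="gpt2-large", assistant_model="gpt2")
print(pipe("Assisted generation works by", max_new_tokens=20)[0]["generated_text"])

# When a model *instance* is passed instead of a name, no tokenizer is loaded
# automatically: if its tokenizer differs from the main model's, it must be
# supplied via `assistant_tokenizer`, or load_assistant_model raises.
draft = AutoModelForCausalLM.from_pretrained("gpt2")
pipe = pipeline("text-generation", model="gpt2-large", assistant_model=draft)
```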
|