
Commit 76da6ca

gante and Rocketknight1 authored
Pipeline: simple API for assisted generation (#34504)
Co-authored-by: Matt <[email protected]>
1 parent 3f483be commit 76da6ca

14 files changed: +172 -18 lines

docs/source/en/generation_strategies.md

Lines changed: 22 additions & 0 deletions
@@ -441,6 +441,28 @@ To enable assisted decoding, set the `assistant_model` argument with a model.
 ['Alice and Bob are sitting in a bar. Alice is drinking a beer and Bob is drinking a']
 ```
 
+<Tip>
+
+If you're using a `pipeline` object, all you need to do is pass the assistant checkpoint under `assistant_model`
+
+```python
+>>> from transformers import pipeline
+>>> import torch
+
+>>> pipe = pipeline(
+...     "text-generation",
+...     model="meta-llama/Llama-3.1-8B",
+...     assistant_model="meta-llama/Llama-3.2-1B",  # This extra line is all that's needed, also works with UAD
+...     torch_dtype=torch.bfloat16
+... )
+>>> pipe_output = pipe("Once upon a time, ", max_new_tokens=50, do_sample=False)
+>>> pipe_output[0]["generated_text"]
+'Once upon a time, 3D printing was a niche technology that was only'
+```
+
+</Tip>
+
 When using assisted decoding with sampling methods, you can use the `temperature` argument to control the randomness,
 just like in multinomial sampling. However, in assisted decoding, reducing the temperature may help improve the latency.

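The comment in the added example mentions UAD (universal assisted decoding), i.e. an assistant whose tokenizer is incompatible with the main model's. Judging from `load_assistant_model` in `base.py` below, a checkpoint string gets its tokenizer loaded automatically, while a model object needs an explicit `assistant_tokenizer`. A minimal sketch of the latter case; the Qwen/Llama pairing is an illustrative assumption, not taken from this commit:

```python
>>> from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
>>> import torch

>>> # Assistant from a different model family, so its tokenizer differs from
>>> # the main model's (illustrative pairing)
>>> assistant = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B", torch_dtype=torch.bfloat16)
>>> assistant_tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")

>>> pipe = pipeline(
...     "text-generation",
...     model="meta-llama/Llama-3.1-8B",
...     assistant_model=assistant,                # passed as an object, so the tokenizer
...     assistant_tokenizer=assistant_tokenizer,  # is not auto-loaded and must be given
...     torch_dtype=torch.bfloat16,
... )
```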
src/transformers/generation/flax_utils.py

Lines changed: 0 additions & 1 deletion
@@ -347,7 +347,6 @@ def generate(
             eos_token_id = generation_config.eos_token_id
             if isinstance(eos_token_id, list):
                 eos_token_id = eos_token_id[0]
-            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
             generation_config.pad_token_id = eos_token_id
 
         if generation_config.decoder_start_token_id is None and self.config.is_encoder_decoder:

src/transformers/generation/tf_utils.py

Lines changed: 0 additions & 1 deletion
@@ -773,7 +773,6 @@ def generate(
             eos_token_id = generation_config.eos_token_id
             if isinstance(eos_token_id, list):
                 eos_token_id = eos_token_id[0]
-            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.")
             generation_config.pad_token_id = eos_token_id
 
         use_xla = not tf.executing_eagerly()

src/transformers/pipelines/automatic_speech_recognition.py

Lines changed: 6 additions & 0 deletions
@@ -348,6 +348,12 @@ def _sanitize_parameters(
                 raise ValueError("Only Whisper can return language for now.")
             postprocess_params["return_language"] = return_language
 
+        if self.assistant_model is not None:
+            forward_params["assistant_model"] = self.assistant_model
+            if self.assistant_tokenizer is not None:
+                forward_params["tokenizer"] = self.tokenizer
+                forward_params["assistant_tokenizer"] = self.assistant_tokenizer
+
         return preprocess_params, forward_params, postprocess_params
 
     def preprocess(self, inputs, chunk_length_s=0, stride_length_s=None):
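With the assistant now threaded through `forward_params`, the same one-argument API should cover speech recognition. A minimal sketch, assuming a Whisper main model with a Distil-Whisper assistant (checkpoints are illustrative; the two share a tokenizer, so no `assistant_tokenizer` is needed):

```python
>>> from transformers import pipeline
>>> import torch

>>> asr = pipeline(
...     "automatic-speech-recognition",
...     model="openai/whisper-large-v2",
...     assistant_model="distil-whisper/distil-large-v2",  # drafts tokens for the main model
...     torch_dtype=torch.float16,
... )
>>> asr("sample.flac")  # hypothetical local audio file
```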

src/transformers/pipelines/base.py

Lines changed: 64 additions & 3 deletions
@@ -33,7 +33,7 @@
 from ..feature_extraction_utils import PreTrainedFeatureExtractor
 from ..image_processing_utils import BaseImageProcessor
 from ..modelcard import ModelCard
-from ..models.auto.configuration_auto import AutoConfig
+from ..models.auto import AutoConfig, AutoTokenizer
 from ..processing_utils import ProcessorMixin
 from ..tokenization_utils import PreTrainedTokenizer
 from ..utils import (

@@ -425,6 +425,62 @@ def get_default_model_and_revision(
     return default_models[framework]
 
 
+def load_assistant_model(
+    model: "PreTrainedModel",
+    assistant_model: Optional[Union[str, "PreTrainedModel"]],
+    assistant_tokenizer: Optional[PreTrainedTokenizer],
+) -> Tuple[Optional["PreTrainedModel"], Optional[PreTrainedTokenizer]]:
+    """
+    Prepares the assistant model and the assistant tokenizer for a pipeline whose model can call `generate`.
+
+    Args:
+        model ([`PreTrainedModel`]):
+            The main model that will be used by the pipeline to make predictions.
+        assistant_model (`str` or [`PreTrainedModel`], *optional*):
+            The assistant model that will be used by the pipeline to make predictions.
+        assistant_tokenizer ([`PreTrainedTokenizer`], *optional*):
+            The assistant tokenizer that will be used by the pipeline to encode data for the model.
+
+    Returns:
+        Tuple: The loaded assistant model and (optionally) the loaded tokenizer.
+    """
+    if not model.can_generate() or assistant_model is None:
+        return None, None
+
+    if not isinstance(model, PreTrainedModel):
+        raise ValueError(
+            "Assisted generation, triggered by the `assistant_model` argument, is only available for "
+            "`PreTrainedModel` model instances. For instance, TF or JAX models are not supported."
+        )
+
+    # If the model is passed as a string, load the model and the corresponding tokenizer
+    if isinstance(assistant_model, str):
+        assistant_config = AutoConfig.from_pretrained(assistant_model)
+        _, loaded_assistant_model = infer_framework_load_model(assistant_model, config=assistant_config)
+        loaded_assistant_model = loaded_assistant_model.to(device=model.device, dtype=model.dtype)
+        loaded_assistant_tokenizer = AutoTokenizer.from_pretrained(assistant_model)
+    else:
+        loaded_assistant_model = assistant_model
+        loaded_assistant_tokenizer = assistant_tokenizer
+
+    # Finally, let's check the tokenizers: if the two models have different tokenizers, we need to keep the
+    # assistant tokenizer
+    same_vocab_size = model.config.vocab_size == loaded_assistant_model.config.vocab_size
+    same_special_tokens = all(
+        getattr(model.config, token) == getattr(loaded_assistant_model.config, token)
+        for token in ("eos_token_id", "pad_token_id", "bos_token_id")
+    )
+    if same_vocab_size and same_special_tokens:
+        loaded_assistant_tokenizer = None
+    elif loaded_assistant_tokenizer is None:
+        raise ValueError(
+            "The assistant model has a different tokenizer than the main model. You should pass the assistant "
+            "tokenizer."
+        )
+
+    return loaded_assistant_model, loaded_assistant_tokenizer
+
+
 class PipelineException(Exception):
     """
     Raised by a [`Pipeline`] when handling __call__.

@@ -925,8 +981,13 @@ def __init__(
         ):
             self.model.to(self.device)
 
-        # If the model can generate, create a local generation config. This is done to avoid side-effects on the model
-        # as we apply local tweaks to the generation config.
+        # If the model can generate:
+        # 1 - create a local generation config. This is done to avoid side-effects on the model as we apply local
+        # tweaks to the generation config.
+        # 2 - load the assistant model if it is passed.
+        self.assistant_model, self.assistant_tokenizer = load_assistant_model(
+            self.model, kwargs.pop("assistant_model", None), kwargs.pop("assistant_tokenizer", None)
+        )
         if self.model.can_generate():
             self.prefix = self.model.config.prefix if hasattr(self.model.config, "prefix") else None
             self.generation_config = copy.deepcopy(self.model.generation_config)
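For reference, a minimal sketch of how the new helper resolves the assistant tokenizer. Checkpoint names are assumptions for illustration, and the first case assumes the two Llama checkpoints share a vocabulary and special tokens:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.pipelines.base import load_assistant_model

main_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")

# Same family: vocab size and special tokens match, so the helper drops the
# assistant tokenizer (returns None) and plain assisted decoding is used.
assistant, assistant_tok = load_assistant_model(main_model, "meta-llama/Llama-3.2-1B", None)
print(assistant_tok)  # expected: None

# Different family, passed as a model object: the tokenizers differ, so an
# `assistant_tokenizer` must be supplied, otherwise a ValueError is raised.
other = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")
assistant, assistant_tok = load_assistant_model(
    main_model, other, AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
)
print(assistant_tok)  # expected: the Qwen tokenizer, kept for universal assisted decoding
```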

src/transformers/pipelines/document_question_answering.py

Lines changed: 8 additions & 1 deletion
@@ -189,7 +189,14 @@ def _sanitize_parameters(
         if handle_impossible_answer is not None:
             postprocess_params["handle_impossible_answer"] = handle_impossible_answer
 
-        return preprocess_params, {}, postprocess_params
+        forward_params = {}
+        if self.assistant_model is not None:
+            forward_params["assistant_model"] = self.assistant_model
+            if self.assistant_tokenizer is not None:
+                forward_params["tokenizer"] = self.tokenizer
+                forward_params["assistant_tokenizer"] = self.assistant_tokenizer
+
+        return preprocess_params, forward_params, postprocess_params
 
     def __call__(
         self,

src/transformers/pipelines/image_to_text.py

Lines changed: 6 additions & 0 deletions
@@ -92,6 +92,12 @@ def _sanitize_parameters(self, max_new_tokens=None, generate_kwargs=None, prompt
             )
         forward_params.update(generate_kwargs)
 
+        if self.assistant_model is not None:
+            forward_params["assistant_model"] = self.assistant_model
+            if self.assistant_tokenizer is not None:
+                forward_params["tokenizer"] = self.tokenizer
+                forward_params["assistant_tokenizer"] = self.assistant_tokenizer
+
         return preprocess_params, forward_params, {}
 
     def __call__(self, inputs: Union[str, List[str], "Image.Image", List["Image.Image"]] = None, **kwargs):

src/transformers/pipelines/table_question_answering.py

Lines changed: 7 additions & 0 deletions
@@ -358,6 +358,13 @@ def _sanitize_parameters(self, sequential=None, padding=None, truncation=None, *
         forward_params = {}
         if sequential is not None:
             forward_params["sequential"] = sequential
+
+        if self.assistant_model is not None:
+            forward_params["assistant_model"] = self.assistant_model
+            if self.assistant_tokenizer is not None:
+                forward_params["tokenizer"] = self.tokenizer
+                forward_params["assistant_tokenizer"] = self.assistant_tokenizer
+
         return preprocess_params, forward_params, {}
 
     def preprocess(self, pipeline_input, sequential=None, padding=True, truncation=None):

src/transformers/pipelines/text2text_generation.py

Lines changed: 6 additions & 0 deletions
@@ -106,6 +106,12 @@ def _sanitize_parameters(
                 )
             generate_kwargs["eos_token_id"] = stop_sequence_ids[0]
 
+        if self.assistant_model is not None:
+            forward_params["assistant_model"] = self.assistant_model
+            if self.assistant_tokenizer is not None:
+                forward_params["tokenizer"] = self.tokenizer
+                forward_params["assistant_tokenizer"] = self.assistant_tokenizer
+
         return preprocess_params, forward_params, postprocess_params
 
     def check_inputs(self, input_length: int, min_length: int, max_length: int):

src/transformers/pipelines/text_generation.py

Lines changed: 7 additions & 7 deletions
@@ -1,7 +1,6 @@
 import enum
 import itertools
 import types
-import warnings
 from typing import Dict
 
 from ..utils import add_end_docstrings, is_tf_available, is_torch_available

@@ -194,12 +193,13 @@ def _sanitize_parameters(
 
         if stop_sequence is not None:
             stop_sequence_ids = self.tokenizer.encode(stop_sequence, add_special_tokens=False)
-            if len(stop_sequence_ids) > 1:
-                warnings.warn(
-                    "Stopping on a multiple token sequence is not yet supported on transformers. The first token of"
-                    " the stop sequence will be used as the stop sequence string in the interim."
-                )
-            generate_kwargs["eos_token_id"] = stop_sequence_ids[0]
+            generate_kwargs["eos_token_id"] = stop_sequence_ids
+
+        if self.assistant_model is not None:
+            forward_params["assistant_model"] = self.assistant_model
+            if self.assistant_tokenizer is not None:
+                forward_params["tokenizer"] = self.tokenizer
+                forward_params["assistant_tokenizer"] = self.assistant_tokenizer
 
         return preprocess_params, forward_params, postprocess_params
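Note the behavior change in the second hunk: the encoded `stop_sequence` is no longer truncated to its first token; every resulting id is forwarded to `generate` as an `eos_token_id` candidate, and the interim warning is gone. A minimal usage sketch (the model choice is illustrative):

```python
>>> from transformers import pipeline

>>> pipe = pipeline("text-generation", model="gpt2")
>>> # All token ids of the encoded stop string are passed as `eos_token_id`,
>>> # so generation halts as soon as any one of them is produced.
>>> out = pipe("The three primary colors are", stop_sequence=" red,", max_new_tokens=30)
```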
