diff --git a/docs/source/en/generation_strategies.md b/docs/source/en/generation_strategies.md
index b32b26a281a0..ba59c1228609 100644
--- a/docs/source/en/generation_strategies.md
+++ b/docs/source/en/generation_strategies.md
@@ -504,6 +504,44 @@ Recommended practices:
 - Add self-contained examples to enable quick experimentation.
 - Describe soft-requirements such as if the method only works well with a certain family of models.
 
+### Reusing `generate`’s input preparation
+
+If you're adding a new decoding loop, you may want to keep the input preparation that `generate` already performs (batch expansion, attention masks, logits processors, stopping criteria, etc.). In that case, pass a **callable** to `custom_generate` to reuse [`~GenerationMixin.generate`]’s full preparation pipeline while overriding only the decoding loop.
+
+```py
+import torch
+
+# `synced_gpus` and `streamer` are accepted explicitly so they don't end up in `model_kwargs`
+def custom_loop(model, input_ids, attention_mask, logits_processor, stopping_criteria, generation_config, synced_gpus=False, streamer=None, **model_kwargs):
+    next_tokens = input_ids
+    while input_ids.shape[1] < stopping_criteria[0].max_length:  # the first prepared criterion is the max-length criterion
+        # Feed the full prompt on the first step, then only the newest token (past context is carried by the cache in `model_kwargs`)
+        logits = model(next_tokens, attention_mask=attention_mask, **model_kwargs).logits
+        # Greedy selection over the processed logits of the last position
+        next_token_logits = logits_processor(input_ids, logits[:, -1, :])
+        next_tokens = torch.argmax(next_token_logits, dim=-1)[:, None]
+        input_ids = torch.cat((input_ids, next_tokens), dim=-1)
+        attention_mask = torch.cat((attention_mask, torch.ones_like(next_tokens)), dim=-1)
+    return input_ids
+
+output = model.generate(
+    **inputs,
+    custom_generate=custom_loop,
+    max_new_tokens=10,
+)
+```
+
+> [!TIP]
+> If you publish a `custom_generate` repository, your `generate` implementation can itself define a callable and pass it to `model.generate()`. This lets you customize the decoding loop while still benefiting from Transformers’ built-in input preparation logic.
+
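+For example, a repository’s `custom_generate/generate.py` can be a thin wrapper that delegates back to `model.generate` with such a callable (a minimal sketch reusing the `custom_loop` defined above):
+
+```py
+def generate(model, *args, **kwargs):
+    # Let the built-in `generate` handle input preparation; only the decoding loop is custom
+    return model.generate(*args, custom_generate=custom_loop, **kwargs)
+```
+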
 ### Finding custom generation methods
 
 You can find all custom generation methods by [searching for their custom tag.](https://huggingface.co/models?other=custom_generate), `custom_generate`. In addition to the tag, we curate two collections of `custom_generate` methods:
diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py
index b83d5a973398..0ae0f333c4a6 100644
--- a/src/transformers/generation/utils.py
+++ b/src/transformers/generation/utils.py
@@ -2165,7 +2165,7 @@ def generate(
         negative_prompt_ids: Optional[torch.Tensor] = None,
         negative_prompt_attention_mask: Optional[torch.Tensor] = None,
         use_model_defaults: Optional[bool] = None,
-        custom_generate: Optional[str] = None,
+        custom_generate: Optional[Union[str, Callable]] = None,
         **kwargs,
     ) -> Union[GenerateOutput, torch.LongTensor]:
         r"""
@@ -2235,11 +2235,15 @@
                 generation configuration (`model.generation_config`), as opposed to the global defaults
                 (`GenerationConfig()`). If unset, models saved starting from `v4.50` will consider this flag to be
                 `True`.
            custom_generate (`str`, *optional*):
-                A string containing the name of a huggingface.co repository. If provided, the custom `generate`
-                function defined in that reposity's `custom_generate/generate.py` file will be executed instead of the
-                standard `generate` method. Note that the logic is for generation is entirely defined in that
-                repository, and the return type may be different from the standard `generate` method.
+            custom_generate (`str` or `Callable`, *optional*):
+                One of the following:
+                - `str` (Hugging Face Hub repository name): runs the custom `generate` function defined at
+                  `custom_generate/generate.py` in that repository instead of the standard `generate` method. The
+                  repository fully replaces the generation logic, and the return type may differ.
+                - `str` (local repository path): same as above, but loaded from a local path; `trust_remote_code` is not required.
+                - `Callable`: `generate` performs the usual input preparation steps, then calls the provided callable to
+                  run the decoding loop.
+                For more information, see [the docs](../../generation_strategies#custom-generation-methods).
            kwargs (`dict[str, Any]`, *optional*):
                Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
                forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
@@ -2263,7 +2267,7 @@
         """
         # 0. If requested, load an arbitrary generation recipe from the Hub and run it instead
         trust_remote_code = kwargs.pop("trust_remote_code", None)
-        if custom_generate is not None:
+        if isinstance(custom_generate, str):
             # Get all `generate` arguments in a single variable. Custom functions are responsible for handling them:
             # they receive the same inputs as `generate`, with `model` instead of `self` and excluding the arguments to
             # trigger the custom generation. They can access to methods from `GenerationMixin` through `model`.
@@ -2360,6 +2364,14 @@
         else:
             input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
 
+        # Expand inputs depending on the generation mode
+        input_ids, model_kwargs = self._expand_inputs_for_generation(
+            input_ids=input_ids,
+            expand_size=max(generation_config.num_beams, generation_config.num_return_sequences),
+            is_encoder_decoder=self.config.is_encoder_decoder,
+            **model_kwargs,
+        )
+
         if generation_config.token_healing:
             input_ids = self.heal_tokens(input_ids, tokenizer)
 
@@ -2441,7 +2453,19 @@
         model_kwargs["use_cache"] = generation_config.use_cache
 
         # 10. go into different generation modes
-        if generation_mode == GenerationMode.ASSISTED_GENERATION:
+        if isinstance(custom_generate, Callable):
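+            # All shared input preparation has already run; the callable implements only the decoding loop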
+            result = custom_generate(
+                self,
+                input_ids,
+                logits_processor=prepared_logits_processor,
+                stopping_criteria=prepared_stopping_criteria,
+                generation_config=generation_config,
+                synced_gpus=synced_gpus,
+                streamer=streamer,
+                **model_kwargs,
+            )
+        elif generation_mode == GenerationMode.ASSISTED_GENERATION:
             if generation_config.num_return_sequences > 1:
                 raise ValueError(
                     "num_return_sequences has to be 1 when doing assisted generate, "
@@ -2530,15 +2553,7 @@
             )
 
         elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
-            # 11. expand input_ids with `num_return_sequences` additional sequences per batch
-            input_ids, model_kwargs = self._expand_inputs_for_generation(
-                input_ids=input_ids,
-                expand_size=generation_config.num_return_sequences,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-                **model_kwargs,
-            )
-
-            # 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
+            # 11. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
             result = self._sample(
                 input_ids,
                 logits_processor=prepared_logits_processor,
@@ -2550,14 +2565,7 @@
             )
 
         elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
-            # 11. interleave input_ids with `num_beams` additional sequences per batch
-            input_ids, model_kwargs = self._expand_inputs_for_generation(
-                input_ids=input_ids,
-                expand_size=generation_config.num_beams,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-                **model_kwargs,
-            )
-            # 12. run beam sample
+            # 11. run beam sample
             result = self._beam_search(
                 input_ids,
                 logits_processor=prepared_logits_processor,
@@ -2583,14 +2591,7 @@
                 num_beam_groups=generation_config.num_beam_groups,
                 max_length=generation_config.max_length,
             )
-            # 12. interleave input_ids with `num_beams` additional sequences per batch
-            input_ids, model_kwargs = self._expand_inputs_for_generation(
-                input_ids=input_ids,
-                expand_size=generation_config.num_beams,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-                **model_kwargs,
-            )
-            # 13. run beam search
+            # 12. run group beam search
             result = self._group_beam_search(
                 input_ids,
                 beam_scorer,
@@ -2657,14 +2657,7 @@ def typeerror():
                 num_beam_hyps_to_keep=generation_config.num_return_sequences,
                 max_length=generation_config.max_length,
             )
-            # 12. interleave input_ids with `num_beams` additional sequences per batch
-            input_ids, model_kwargs = self._expand_inputs_for_generation(
-                input_ids=input_ids,
-                expand_size=generation_config.num_beams,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-                **model_kwargs,
-            )
-            # 13. run beam search
+            # 12. run constrained beam search
             result = self._constrained_beam_search(
                 input_ids,
                 constrained_beam_scorer=constrained_beam_scorer,
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index c6376617341a..c9d20e692e92 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -5044,6 +5044,26 @@ def test_custom_generate_local_directory(self):
             )
             assert value == "success"
 
+    def test_custom_generate_callable(self):
+        """Tests that passing a callable to `custom_generate` runs that callable as the decoding loop"""
+        model = AutoModelForCausalLM.from_pretrained(
+            "hf-internal-testing/tiny-random-MistralForCausalLM", device_map="auto"
+        )
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
+        model_inputs = tokenizer("Hello, world!", return_tensors="pt").to(model.device)
+
+        def custom_loop(model, input_ids, logits_processor, stopping_criteria, generation_config, **model_kwargs):
+            # Check that `generate` prepared the stopping criteria (max_length = prompt length + max_new_tokens)
+            assert stopping_criteria[0].max_length == input_ids.shape[1] + 3
+            return "callable_success"
+
+        value = model.generate(
+            **model_inputs,
+            max_new_tokens=3,
+            custom_generate=custom_loop,
+        )
+        self.assertEqual(value, "callable_success")
+
     @pytest.mark.generate
     def test_generate_custom_cache_position(self):
         """