
Commit e78571f

manueldeprada and gante authored
decoding_method argument in generate (#40085)
* factor out expand inputs
* callable arg
* improve docs, add test
* Update docs/source/en/generation_strategies.md

Co-authored-by: Joao Gante <[email protected]>
1 parent 8d19231 commit e78571f

File tree

3 files changed: +79 -41 lines changed

docs/source/en/generation_strategies.md

Lines changed: 25 additions & 0 deletions
@@ -504,6 +504,31 @@ Recommended practices:
 - Add self-contained examples to enable quick experimentation.
 - Describe soft-requirements such as if the method only works well with a certain family of models.
 
+### Reusing `generate`’s input preparation
+
+If you're adding a new decoding loop, you might want to preserve the input preparation present in `generate` (batch expansion, attention masks, logits processors, stopping criteria, etc.). You can also pass a **callable** to `custom_generate` to reuse [`~GenerationMixin.generate`]’s full preparation pipeline while overriding only the decoding loop.
+
+```py
+def custom_loop(model, input_ids, attention_mask, logits_processor, stopping_criteria, generation_config, **model_kwargs):
+    next_tokens = input_ids
+    while input_ids.shape[1] < stopping_criteria[0].max_length:
+        logits = model(next_tokens, attention_mask=attention_mask, **model_kwargs).logits
+        next_token_logits = logits_processor(input_ids, logits[:, -1, :])
+        next_tokens = torch.argmax(next_token_logits, dim=-1)[:, None]
+        input_ids = torch.cat((input_ids, next_tokens), dim=-1)
+        attention_mask = torch.cat((attention_mask, torch.ones_like(next_tokens)), dim=-1)
+    return input_ids
+
+output = model.generate(
+    **inputs,
+    custom_generate=custom_loop,
+    max_new_tokens=10,
+)
+```
+
+> [!TIP]
+> If you publish a `custom_generate` repository, your `generate` implementation can itself define a callable and pass it to `model.generate()`. This lets you customize the decoding loop while still benefiting from Transformers’ built-in input preparation logic.
+
 ### Finding custom generation methods
 
 You can find all custom generation methods by [searching for their custom tag.](https://huggingface.co/models?other=custom_generate), `custom_generate`. In addition to the tag, we curate two collections of `custom_generate` methods:
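To make the added TIP concrete, here is a minimal sketch, not part of the commit, of what a published repository's `custom_generate/generate.py` might contain: it defines a cache-free greedy loop as a callable and re-enters `model.generate()` with it, so the built-in input preparation still runs. The `_greedy_loop` name and the pass-through `generate` signature are illustrative assumptions; only the argument names documented in the diff come from the source.

```py
# Hypothetical custom_generate/generate.py for a Hub repository (illustrative sketch).
import torch


def _greedy_loop(model, input_ids, attention_mask, logits_processor, stopping_criteria, generation_config, **model_kwargs):
    # `generate` hands over already-prepared (and batch-expanded) inputs.
    # As in the docs example above, assume the first stopping criterion carries `max_length`.
    while input_ids.shape[1] < stopping_criteria[0].max_length:
        # Full forward pass each step (no KV cache) to keep the sketch short.
        logits = model(input_ids, attention_mask=attention_mask).logits
        next_token_logits = logits_processor(input_ids, logits[:, -1, :])
        next_tokens = torch.argmax(next_token_logits, dim=-1)[:, None]
        input_ids = torch.cat((input_ids, next_tokens), dim=-1)
        attention_mask = torch.cat((attention_mask, torch.ones_like(next_tokens)), dim=-1)
    return input_ids


def generate(model, *args, **kwargs):
    # Entry point executed when users load this repository through `custom_generate`.
    # Re-entering model.generate() with the callable reuses the built-in preparation pipeline.
    return model.generate(*args, custom_generate=_greedy_loop, **kwargs)
```

Passing the callable, rather than running the loop directly, is what keeps batch expansion, attention-mask handling, logits processors, and stopping criteria consistent with the standard `generate` path.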

src/transformers/generation/utils.py

Lines changed: 34 additions & 41 deletions
@@ -2165,7 +2165,7 @@ def generate(
         negative_prompt_ids: Optional[torch.Tensor] = None,
         negative_prompt_attention_mask: Optional[torch.Tensor] = None,
         use_model_defaults: Optional[bool] = None,
-        custom_generate: Optional[str] = None,
+        custom_generate: Optional[Union[str, Callable]] = None,
         **kwargs,
     ) -> Union[GenerateOutput, torch.LongTensor]:
         r"""
@@ -2235,11 +2235,15 @@
                 generation configuration (`model.generation_config`), as opposed to the global defaults
                 (`GenerationConfig()`). If unset, models saved starting from `v4.50` will consider this flag to be
                 `True`.
-            custom_generate (`str`, *optional*):
-                A string containing the name of a huggingface.co repository. If provided, the custom `generate`
-                function defined in that reposity's `custom_generate/generate.py` file will be executed instead of the
-                standard `generate` method. Note that the logic is for generation is entirely defined in that
-                repository, and the return type may be different from the standard `generate` method.
+            custom_generate (`str` or `Callable`, *optional*):
+                One of the following:
+                - `str` (Hugging Face Hub repository name): runs the custom `generate` function defined at
+                  `custom_generate/generate.py` in that repository instead of the standard `generate` method. The
+                  repository fully replaces the generation logic, and the return type may differ.
+                - `str` (local repository path): same as above but from a local path, `trust_remote_code` not required.
+                - `Callable`: `generate` will perform the usual input preparation steps, then call the provided callable to
+                  run the decoding loop.
+                For more information, see [the docs](../../generation_strategies#custom-generation-methods).
             kwargs (`dict[str, Any]`, *optional*):
                 Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
                 forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
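For quick reference, the three accepted forms described in the updated docstring can be used like this; the repository name, local path, and `my_decoding_loop` below are placeholders rather than identifiers from the commit.

```py
# Placeholder names throughout; this only illustrates the three argument forms.
output = model.generate(**inputs, custom_generate="someuser/some-recipe", trust_remote_code=True)  # Hub repository
output = model.generate(**inputs, custom_generate="./path/to/local_recipe")  # local repository path
output = model.generate(**inputs, custom_generate=my_decoding_loop, max_new_tokens=10)  # callable decoding loop
```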
@@ -2263,7 +2267,7 @@
         """
         # 0. If requested, load an arbitrary generation recipe from the Hub and run it instead
         trust_remote_code = kwargs.pop("trust_remote_code", None)
-        if custom_generate is not None:
+        if custom_generate is not None and isinstance(custom_generate, str):
            # Get all `generate` arguments in a single variable. Custom functions are responsible for handling them:
            # they receive the same inputs as `generate`, with `model` instead of `self` and excluding the arguments to
            # trigger the custom generation. They can access to methods from `GenerationMixin` through `model`.
@@ -2360,6 +2364,14 @@
         else:
             input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
 
+        # Expand inputs depending on the generation mode
+        input_ids, model_kwargs = self._expand_inputs_for_generation(
+            input_ids=input_ids,
+            expand_size=max(generation_config.num_beams, generation_config.num_return_sequences),
+            is_encoder_decoder=self.config.is_encoder_decoder,
+            **model_kwargs,
+        )
+
         if generation_config.token_healing:
             input_ids = self.heal_tokens(input_ids, tokenizer)
 
@@ -2441,7 +2453,18 @@
         model_kwargs["use_cache"] = generation_config.use_cache
 
         # 10. go into different generation modes
-        if generation_mode == GenerationMode.ASSISTED_GENERATION:
+        if isinstance(custom_generate, Callable):
+            result = custom_generate(
+                self,
+                input_ids,
+                logits_processor=prepared_logits_processor,
+                stopping_criteria=prepared_stopping_criteria,
+                generation_config=generation_config,
+                synced_gpus=synced_gpus,
+                streamer=streamer,
+                **model_kwargs,
+            )
+        elif generation_mode == GenerationMode.ASSISTED_GENERATION:
             if generation_config.num_return_sequences > 1:
                 raise ValueError(
                     "num_return_sequences has to be 1 when doing assisted generate, "
@@ -2530,15 +2553,7 @@
             )
 
         elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
-            # 11. expand input_ids with `num_return_sequences` additional sequences per batch
-            input_ids, model_kwargs = self._expand_inputs_for_generation(
-                input_ids=input_ids,
-                expand_size=generation_config.num_return_sequences,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-                **model_kwargs,
-            )
-
-            # 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
+            # 11. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
             result = self._sample(
                 input_ids,
                 logits_processor=prepared_logits_processor,
@@ -2550,14 +2565,7 @@
             )
 
         elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
-            # 11. interleave input_ids with `num_beams` additional sequences per batch
-            input_ids, model_kwargs = self._expand_inputs_for_generation(
-                input_ids=input_ids,
-                expand_size=generation_config.num_beams,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-                **model_kwargs,
-            )
-            # 12. run beam sample
+            # 11. run beam sample
             result = self._beam_search(
                 input_ids,
                 logits_processor=prepared_logits_processor,
@@ -2583,14 +2591,6 @@
                 num_beam_groups=generation_config.num_beam_groups,
                 max_length=generation_config.max_length,
             )
-            # 12. interleave input_ids with `num_beams` additional sequences per batch
-            input_ids, model_kwargs = self._expand_inputs_for_generation(
-                input_ids=input_ids,
-                expand_size=generation_config.num_beams,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-                **model_kwargs,
-            )
-            # 13. run beam search
             result = self._group_beam_search(
                 input_ids,
                 beam_scorer,
@@ -2657,14 +2657,7 @@ def typeerror():
                 num_beam_hyps_to_keep=generation_config.num_return_sequences,
                 max_length=generation_config.max_length,
             )
-            # 12. interleave input_ids with `num_beams` additional sequences per batch
-            input_ids, model_kwargs = self._expand_inputs_for_generation(
-                input_ids=input_ids,
-                expand_size=generation_config.num_beams,
-                is_encoder_decoder=self.config.is_encoder_decoder,
-                **model_kwargs,
-            )
-            # 13. run beam search
+            # 12. run beam search
             result = self._constrained_beam_search(
                 input_ids,
                 constrained_beam_scorer=constrained_beam_scorer,
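Since `_expand_inputs_for_generation` now runs once before the mode dispatch (with `max(num_beams, num_return_sequences)` as the expand size), a callable passed to `custom_generate` receives `input_ids` that are already expanded, together with the keywords shown in the dispatch above. A minimal sketch of that contract, assuming the tiny checkpoint from the test below and a placeholder prompt:

```py
# Sketch of the calling contract established by the dispatch above; the checkpoint
# and prompt are placeholders, any small causal LM behaves the same way.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "hf-internal-testing/tiny-random-MistralForCausalLM"
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer(["Hello, world!"], return_tensors="pt")


def inspect_loop(model, input_ids, logits_processor, stopping_criteria, generation_config, **model_kwargs):
    # With a single prompt, the batch has already been expanded to
    # `num_return_sequences` rows before the callable is invoked.
    assert input_ids.shape[0] == generation_config.num_return_sequences
    # `synced_gpus`, `streamer`, and prepared model kwargs (e.g. `attention_mask`)
    # arrive through the remaining keywords captured in **model_kwargs.
    return input_ids  # the return value is handed back to the caller unchanged


out = model.generate(
    **inputs,
    custom_generate=inspect_loop,
    do_sample=True,
    num_return_sequences=2,
    max_new_tokens=4,
)
print(out.shape)  # torch.Size([2, prompt_length])
```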

tests/generation/test_utils.py

Lines changed: 20 additions & 0 deletions
@@ -5044,6 +5044,26 @@ def test_custom_generate_local_directory(self):
         )
         assert value == "success"
 
+    def test_custom_generate_callable(self):
+        """Tests that passing a callable to `custom_generate` executes the callable decoding loop"""
+        model = AutoModelForCausalLM.from_pretrained(
+            "hf-internal-testing/tiny-random-MistralForCausalLM", device_map="auto"
+        )
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
+        model_inputs = tokenizer("Hello, world!", return_tensors="pt").to(model.device)
+
+        def custom_loop(model, input_ids, logits_processor, stopping_criteria, generation_config, **model_kwargs):
+            # Check that generate() correctly prepares the stopping criteria
+            assert stopping_criteria[0].max_length == input_ids.shape[1] + 3
+            return "callable_success"
+
+        value = model.generate(
+            **model_inputs,
+            max_new_tokens=3,
+            custom_generate=custom_loop,
+        )
+        self.assertEqual(value, "callable_success")
+
     @pytest.mark.generate
     def test_generate_custom_cache_position(self):
         """
