Merged

Commits (21 total; the diff below shows changes from 15 commits):
- ceaa4b8: factor out expand inputs (manueldeprada, Aug 11, 2025)
- dd4b1b6: Merge branch 'main' of github.com:huggingface/transformers into pass_… (manueldeprada, Aug 11, 2025)
- d988bfa: Merge branch 'main' into pass_decoding_method (manueldeprada, Aug 12, 2025)
- a1e8726: callable arg (manueldeprada, Aug 12, 2025)
- 2412d10: Merge branch 'pass_decoding_method' of https://github.com/manueldepra… (manueldeprada, Aug 12, 2025)
- 896bf6f: improve docs, add test (manueldeprada, Aug 12, 2025)
- 14cf301: ops squeezy file (manueldeprada, Aug 12, 2025)
- f86531e: Update docs/source/en/generation_strategies.md (manueldeprada, Aug 12, 2025)
- be24820: Update docs/source/en/generation_strategies.md (manueldeprada, Aug 12, 2025)
- 6ea93a4: Update docs/source/en/generation_strategies.md (manueldeprada, Aug 12, 2025)
- d651f0a: Apply suggestions from code review (manueldeprada, Aug 12, 2025)
- 98ca846: Update src/transformers/generation/utils.py (manueldeprada, Aug 12, 2025)
- 945a6dc: joao review (manueldeprada, Aug 12, 2025)
- 8d60fb1: Merge branch 'pass_decoding_method' of https://github.com/manueldepra… (manueldeprada, Aug 12, 2025)
- 9089347: ruff (manueldeprada, Aug 12, 2025)
- 2759a48: Update src/transformers/generation/utils.py (manueldeprada, Aug 13, 2025)
- e5ba4eb: Merge branch 'main' into pass_decoding_method (manueldeprada, Aug 13, 2025)
- 0fdb754: Update src/transformers/generation/utils.py (manueldeprada, Aug 13, 2025)
- b8cc256: Update src/transformers/generation/utils.py (manueldeprada, Aug 13, 2025)
- d6ae375: Apply suggestions from code review (manueldeprada, Aug 13, 2025)
- 31c4559: ruff (manueldeprada, Aug 13, 2025)
24 changes: 24 additions & 0 deletions docs/source/en/generation_strategies.md
@@ -504,6 +504,30 @@ Recommended practices:
- Add self-contained examples to enable quick experimentation.
- Describe soft-requirements such as if the method only works well with a certain family of models.

### Reusing `generate`’s preparation steps by passing a callable

If you're adding a new decoding loop, you might want to keep the input preparation that `generate` already performs (batch expansion, attention masks, logits processors, stopping criteria, etc.). To do so, pass a **callable** to `custom_generate`: [`~GenerationMixin.generate`] runs its full preparation pipeline and then calls your callable in place of the standard decoding loop.

```py
def custom_loop(model, input_ids, attention_mask, logits_processor, stopping_criteria, generation_config, **model_kwargs):
    next_tokens = input_ids
    while input_ids.shape[1] < stopping_criteria[0].max_length:
        logits = model(next_tokens, attention_mask=attention_mask, **model_kwargs).logits
        next_token_logits = logits_processor(input_ids, logits[:, -1, :])
        next_tokens = torch.argmax(next_token_logits, dim=-1)[:, None]
        input_ids = torch.cat((input_ids, next_tokens), dim=-1)
        attention_mask = torch.cat((attention_mask, torch.ones_like(next_tokens)), dim=-1)
    return input_ids

output = model.generate(
    **inputs,
    custom_generate=custom_loop,
    max_new_tokens=10,
)
```
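
The snippet above assumes `torch`, `model`, and `inputs` are already defined. For quick experimentation, they could be set up along these lines (the checkpoint name below is only an example, not something this PR depends on):

```py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)
```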

> [!TIP]
> If you publish a `custom_generate` repository, your `generate` implementation can itself define a callable and pass it to `model.generate()`. This lets you customize the decoding loop while still benefiting from Transformers’ built-in input preparation logic.
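
For illustration, a published repository's `custom_generate/generate.py` could follow this pattern (a minimal sketch; `my_decoding_loop` and its delegation are illustrative, not code from this PR):

```py
# custom_generate/generate.py
def my_decoding_loop(model, input_ids, logits_processor, stopping_criteria, generation_config, **model_kwargs):
    # Any decoding loop, e.g. the `custom_loop` shown above.
    ...


def generate(model, **kwargs):
    # Reuse the standard `generate` preparation (masks, processors, stopping criteria, ...)
    # and override only the decoding loop by passing the callable through.
    return model.generate(**kwargs, custom_generate=my_decoding_loop)
```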

## Resources

76 changes: 35 additions & 41 deletions src/transformers/generation/utils.py
@@ -2157,7 +2157,7 @@ def generate(
negative_prompt_ids: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
use_model_defaults: Optional[bool] = None,
- custom_generate: Optional[str] = None,
+ custom_generate: Optional[Union[str, Callable]] = None,
**kwargs,
) -> Union[GenerateOutput, torch.LongTensor]:
r"""
@@ -2227,11 +2227,16 @@
generation configuration (`model.generation_config`), as opposed to the global defaults
(`GenerationConfig()`). If unset, models saved starting from `v4.50` will consider this flag to be
`True`.
- custom_generate (`str`, *optional*):
-     A string containing the name of a huggingface.co repository. If provided, the custom `generate`
-     function defined in that reposity's `custom_generate/generate.py` file will be executed instead of the
-     standard `generate` method. Note that the logic is for generation is entirely defined in that
-     repository, and the return type may be different from the standard `generate` method.
+ custom_generate (`str` or `Callable`, *optional*):
+     One of the following:
+
+     - `str` (Hugging Face Hub repository name): runs the custom `generate` function defined at
+       `custom_generate/generate.py` in that repository instead of the standard `generate` method. The
+       repository fully replaces the generation logic, and the return type may differ.
+     - `str` (local repository path): same as above but from a local path, `trust_remote_code` not required.
+     - `Callable`: `generate` will perform the usual preparation steps, then call the provided callable to
+       run the decoding loop. For more information, see
+       [Reusing generate's preparation steps by passing a callable](https://huggingface.co/docs/transformers/en/generation_strategies#reusing-generate-s-preparation-steps-by-passing-a-callable).
kwargs (`dict[str, Any]`, *optional*):
Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
@@ -2255,7 +2260,7 @@
"""
# 0. If requested, load an arbitrary generation recipe from the Hub and run it instead
trust_remote_code = kwargs.pop("trust_remote_code", None)
- if custom_generate is not None:
+ if custom_generate is not None and isinstance(custom_generate, str):
# Get all `generate` arguments in a single variable. Custom functions are responsible for handling them:
# they receive the same inputs as `generate`, with `model` instead of `self` and excluding the arguments to
# trigger the custom generation. They can access to methods from `GenerationMixin` through `model`.
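
As a quick illustration of the two string forms handled by this branch (the repository names below are placeholders, not real Hub repositories):

```py
# Hub repository: runs custom_generate/generate.py from that repo instead of the standard method.
output = model.generate(**inputs, custom_generate="your-org/your-decoding-method", trust_remote_code=True)

# Local directory with the same layout: `trust_remote_code` is not required.
output = model.generate(**inputs, custom_generate="./your_decoding_method")
```
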
@@ -2352,6 +2357,14 @@
else:
input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")

+ # Expand inputs depending on the generation mode
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
+     input_ids=input_ids,
+     expand_size=max(generation_config.num_beams, generation_config.num_return_sequences),
+     is_encoder_decoder=self.config.is_encoder_decoder,
+     **model_kwargs,
+ )
+
if generation_config.token_healing:
input_ids = self.heal_tokens(input_ids, tokenizer)

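The hoisted call above replaces the per-branch expansions removed further down: non-beam modes have `num_beams == 1`, and beam modes require `num_return_sequences` to be at most `num_beams`, so `max(generation_config.num_beams, generation_config.num_return_sequences)` yields the same expansion size each branch used before. Conceptually, the expansion repeats each batch row `expand_size` times (a simplified illustration of what `_expand_inputs_for_generation` does; the real method also expands the tensors inside `model_kwargs`):

```py
import torch

input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])  # batch of 2 prompts
expand_size = 3                                    # e.g. num_beams or num_return_sequences
expanded = input_ids.repeat_interleave(expand_size, dim=0)
print(expanded.shape)  # torch.Size([6, 3]) -- each prompt repeated 3 times, contiguously
```
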
@@ -2433,7 +2446,18 @@
model_kwargs["use_cache"] = generation_config.use_cache

# 10. go into different generation modes
- if generation_mode == GenerationMode.ASSISTED_GENERATION:
+ if isinstance(custom_generate, Callable):
+     result = custom_generate(
+         self,
+         input_ids,
+         logits_processor=prepared_logits_processor,
+         stopping_criteria=prepared_stopping_criteria,
+         generation_config=generation_config,
+         synced_gpus=synced_gpus,
+         streamer=streamer,
+         **model_kwargs,
+     )
+ elif generation_mode == GenerationMode.ASSISTED_GENERATION:
if generation_config.num_return_sequences > 1:
raise ValueError(
"num_return_sequences has to be 1 when doing assisted generate, "
@@ -2522,15 +2546,7 @@
)

elif generation_mode in (GenerationMode.SAMPLE, GenerationMode.GREEDY_SEARCH):
- # 11. expand input_ids with `num_return_sequences` additional sequences per batch
- input_ids, model_kwargs = self._expand_inputs_for_generation(
-     input_ids=input_ids,
-     expand_size=generation_config.num_return_sequences,
-     is_encoder_decoder=self.config.is_encoder_decoder,
-     **model_kwargs,
- )
-
- # 12. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
+ # 11. run sample (it degenerates to greedy search when `generation_config.do_sample=False`)
result = self._sample(
input_ids,
logits_processor=prepared_logits_processor,
@@ -2542,14 +2558,7 @@
)

elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
- # 11. interleave input_ids with `num_beams` additional sequences per batch
- input_ids, model_kwargs = self._expand_inputs_for_generation(
-     input_ids=input_ids,
-     expand_size=generation_config.num_beams,
-     is_encoder_decoder=self.config.is_encoder_decoder,
-     **model_kwargs,
- )
- # 12. run beam sample
+ # 11. run beam sample
result = self._beam_search(
input_ids,
logits_processor=prepared_logits_processor,
@@ -2575,14 +2584,6 @@
num_beam_groups=generation_config.num_beam_groups,
max_length=generation_config.max_length,
)
- # 12. interleave input_ids with `num_beams` additional sequences per batch
- input_ids, model_kwargs = self._expand_inputs_for_generation(
-     input_ids=input_ids,
-     expand_size=generation_config.num_beams,
-     is_encoder_decoder=self.config.is_encoder_decoder,
-     **model_kwargs,
- )
- # 13. run beam search
result = self._group_beam_search(
input_ids,
beam_scorer,
@@ -2649,14 +2650,7 @@ def typeerror():
num_beam_hyps_to_keep=generation_config.num_return_sequences,
max_length=generation_config.max_length,
)
- # 12. interleave input_ids with `num_beams` additional sequences per batch
- input_ids, model_kwargs = self._expand_inputs_for_generation(
-     input_ids=input_ids,
-     expand_size=generation_config.num_beams,
-     is_encoder_decoder=self.config.is_encoder_decoder,
-     **model_kwargs,
- )
- # 13. run beam search
+ # 12. run beam search
result = self._constrained_beam_search(
input_ids,
constrained_beam_scorer=constrained_beam_scorer,
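
Taken together, the dispatch added above means a callable passed as `custom_generate` is invoked with `model` and `input_ids` positionally, plus the prepared logits processors, stopping criteria, generation config, and model kwargs. A compatible skeleton (inferred from the call site, not a documented API contract) looks roughly like:

```py
def my_loop(
    model,
    input_ids,
    logits_processor=None,
    stopping_criteria=None,
    generation_config=None,
    synced_gpus=False,
    streamer=None,
    **model_kwargs,  # attention_mask, cache-related kwargs, etc., as prepared by `generate`
):
    # Whatever this returns is what `generate` hands back to the caller (see the new test below).
    ...
```
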
20 changes: 20 additions & 0 deletions tests/generation/test_utils.py
@@ -5044,6 +5044,26 @@ def test_custom_generate_local_directory(self):
)
assert value == "success"

def test_custom_generate_callable(self):
"""Tests that passing a callable to `custom_generate` executes the callable decoding loop"""
model = AutoModelForCausalLM.from_pretrained(
"hf-internal-testing/tiny-random-MistralForCausalLM", device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
model_inputs = tokenizer("Hello, world!", return_tensors="pt").to(model.device)

def custom_loop(model, input_ids, logits_processor, stopping_criteria, generation_config, **model_kwargs):
# Check that generate() correctly prepares the stopping criteria
assert stopping_criteria[0].max_length == input_ids.shape[1] + 3
return "callable_success"

value = model.generate(
**model_inputs,
max_new_tokens=3,
custom_generate=custom_loop,
)
self.assertEqual(value, "callable_success")

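A natural follow-up (sketched here, not part of this PR) would be an end-to-end check that a real decoding loop passed as a callable produces tokens and respects the prepared stopping criteria:

```py
def test_custom_generate_callable_end_to_end(self):
    """Illustrative sketch only: greedy loop through the callable path, recomputing logits each step."""
    model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-MistralForCausalLM")
    model_inputs = tokenizer("Hello, world!", return_tensors="pt").to(model.device)

    def greedy_loop(model, input_ids, logits_processor, stopping_criteria, generation_config, **model_kwargs):
        attention_mask = model_kwargs.pop("attention_mask")
        while input_ids.shape[1] < stopping_criteria[0].max_length:
            logits = model(input_ids, attention_mask=attention_mask).logits
            next_tokens = torch.argmax(logits_processor(input_ids, logits[:, -1, :]), dim=-1)[:, None]
            input_ids = torch.cat((input_ids, next_tokens), dim=-1)
            attention_mask = torch.cat((attention_mask, torch.ones_like(next_tokens)), dim=-1)
        return input_ids

    output = model.generate(**model_inputs, max_new_tokens=3, custom_generate=greedy_loop)
    self.assertEqual(output.shape[1], model_inputs["input_ids"].shape[1] + 3)
```
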
@pytest.mark.generate
def test_generate_custom_cache_position(self):
"""