Skip to content

Commit 54cd867

Browse files
authored
[custom_generate] don't forward custom_generate and trust_remote_code (#38304)
* prevent infinite loops * docs * more links to custom generation methods
1 parent 135163e commit 54cd867

File tree

3 files changed

+15
-7
lines changed

3 files changed

+15
-7
lines changed

docs/source/en/generation_strategies.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,6 @@ We enable custom decoding methods through model repositories, assuming a specifi
327327

328328
If a model repository holds a custom decoding method, the easiest way to try it out is to load the model and generate with it:
329329

330-
<!-- TODO before merging: 1) better repo name (use a `generate-community` org?) 2) prettify the repo -->
331330
```py
332331
from transformers import AutoModelForCausalLM, AutoTokenizer
333332

@@ -430,7 +429,7 @@ This is the core of your decoding method. It *must* contain a method named `gene
430429
> [!WARNING]
431430
> `generate.py` must be placed in a folder named `custom_generate`, and not at the root level of the repository. The file paths for this feature are hardcoded.
432431
433-
Under the hood, when the base [`~GenerationMixin.generate`] method is called with a `custom_generate` argument, it first checks its Python requirements (if any), then locates the custom `generate` method in `generate.py`, and finally calls the custom `generate`. All received arguments and `model` are forwarded to your custom `generate` method.
432+
Under the hood, when the base [`~GenerationMixin.generate`] method is called with a `custom_generate` argument, it first checks its Python requirements (if any), then locates the custom `generate` method in `generate.py`, and finally calls the custom `generate`. All received arguments and `model` are forwarded to your custom `generate` method, with the exception of the arguments used to trigger the custom generation (`trust_remote_code` and `custom_generate`).
434433

435434
This means your `generate` can have a mix of original and custom arguments (as well as a different output type) as shown below.
436435

docs/source/en/llm_tutorial.md

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,14 +84,17 @@ GenerationConfig {
8484
}
8585
```
8686
87-
You can customize [`~GenerationMixin.generate`] by overriding the parameters and values in [`GenerationConfig`]. Some of the most commonly adjusted parameters are [max_new_tokens](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig.max_new_tokens), [num_beams](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig.num_beams), [do_sample](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig.do_sample), and [num_return_sequences](https://huggingface.co/docs/transformers/main_classes/text_generation#transformers.GenerationConfig.num_return_sequences).
87+
You can customize [`~GenerationMixin.generate`] by overriding the parameters and values in [`GenerationConfig`]. See [this section below](#common-options) for commonly adjusted parameters.
8888
8989
```py
9090
# enable beam search sampling strategy
9191
model.generate(**inputs, num_beams=4, do_sample=True)
9292
```
9393
94-
[`~GenerationMixin.generate`] can also be extended with external libraries or custom code. The `logits_processor` parameter accepts custom [`LogitsProcessor`] instances for manipulating the next token probability distribution. `stopping_criteria` supports custom [`StoppingCriteria`] to stop text generation. Check out the [logits-processor-zoo](https://github.com/NVIDIA/logits-processor-zoo) for more examples of external [`~GenerationMixin.generate`]-compatible extensions.
94+
[`~GenerationMixin.generate`] can also be extended with external libraries or custom code:
95+
1. the `logits_processor` parameter accepts custom [`LogitsProcessor`] instances for manipulating the next token probability distribution;
96+
2. the `stopping_criteria` parameters supports custom [`StoppingCriteria`] to stop text generation;
97+
3. other custom generation methods can be loaded through the `custom_generate` flag ([docs](generation_strategies.md/#custom-decoding-methods)).
9598
9699
Refer to the [Generation strategies](./generation_strategies) guide to learn more about search, sampling, and decoding strategies.
97100

src/transformers/generation/utils.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2347,9 +2347,15 @@ def generate(
23472347
if custom_generate is not None:
23482348
trust_remote_code = kwargs.pop("trust_remote_code", None)
23492349
# Get all `generate` arguments in a single variable. Custom functions are responsible for handling them:
2350-
# they receive the same inputs as `generate`, only with `model` instead of `self`. They can access to
2351-
# methods from `GenerationMixin` through `model`.
2352-
global_keys_to_exclude = {"self", "kwargs"}
2350+
# they receive the same inputs as `generate`, with `model` instead of `self` and excluding the arguments to
2351+
# trigger the custom generation. They can access to methods from `GenerationMixin` through `model`.
2352+
global_keys_to_exclude = {
2353+
"self",
2354+
"kwargs",
2355+
"global_keys_to_exclude",
2356+
"trust_remote_code",
2357+
"custom_generate",
2358+
}
23532359
generate_arguments = {key: value for key, value in locals().items() if key not in global_keys_to_exclude}
23542360
generate_arguments.update(kwargs)
23552361

0 commit comments

Comments
 (0)