
Commit 0dc1d8b

seonglae and xenova authored
Add an example and type enhancement for TextStreamer (#1066)
* typing: GenerationConfig option for TextStreamer
* docs: streaming example following the existing style
* docs: streaming description from @xenova's suggestion (Co-authored-by: Joshua Lochner <[email protected]>)
* fix: streaming example from @xenova's suggestion (Co-authored-by: Joshua Lochner <[email protected]>)
* fix: <pre> tag by wrapping it in a <details> tag
* fix: remove newlines for proper rendering

Co-authored-by: Joshua Lochner <[email protected]>
1 parent 31ce759 commit 0dc1d8b

3 files changed: +78 -2 lines

docs/source/pipelines.md

Lines changed: 64 additions & 0 deletions
@@ -148,6 +148,70 @@ Cheddar is my go-to for any occasion or mood;
It adds depth and richness without being overpowering its taste buds alone
```

### Streaming

Some pipelines, such as `text-generation` or `automatic-speech-recognition`, support streaming output. This is achieved using the `TextStreamer` class. For example, when using a chat model like `Qwen2.5-Coder-0.5B-Instruct`, you can specify a callback function that will be called with the text of each generated token (if unset, new tokens will be printed to the console).
```js
import { pipeline, TextStreamer } from "@huggingface/transformers";

// Create a text generation pipeline
const generator = await pipeline(
  "text-generation",
  "onnx-community/Qwen2.5-Coder-0.5B-Instruct",
  { dtype: "q4" },
);

// Define the list of messages
const messages = [
  { role: "system", content: "You are a helpful assistant." },
  { role: "user", content: "Write a quick sort algorithm." },
];

// Create text streamer
const streamer = new TextStreamer(generator.tokenizer, {
  skip_prompt: true,
  // Optionally, do something with the text (e.g., write to a textbox)
  // callback_function: (text) => { /* Do something with text */ },
});

// Generate a response
const result = await generator(messages, { max_new_tokens: 512, do_sample: false, streamer });
```
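
For instance, here is a minimal sketch (reusing `generator` and `messages` from the example above) of a `callback_function` that accumulates the streamed text rather than printing it to the console:

```js
// A minimal sketch: collect streamed chunks into a buffer instead of
// printing them. Assumes `generator` and `messages` from the example above.
let buffer = "";
const collectingStreamer = new TextStreamer(generator.tokenizer, {
  skip_prompt: true,
  callback_function: (text) => {
    buffer += text; // e.g., append to a textbox for a live "typing" effect
  },
});
await generator(messages, { max_new_tokens: 512, do_sample: false, streamer: collectingStreamer });
console.log(buffer); // the full response, assembled from streamed chunks
```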
Logging `result[0].generated_text` to the console gives:

<details>
<summary>Click to view the console output</summary>
<pre>
Here's a simple implementation of the quick sort algorithm in Python:
```python
def quick_sort(arr):
    if len(arr) <= 1:
        return arr
    pivot = arr[len(arr) // 2]
    left = [x for x in arr if x < pivot]
    middle = [x for x in arr if x == pivot]
    right = [x for x in arr if x > pivot]
    return quick_sort(left) + middle + quick_sort(right)

# Example usage:
arr = [3, 6, 8, 10, 1, 2]
sorted_arr = quick_sort(arr)
print(sorted_arr)
```
### Explanation:
- **Base Case**: If the array has less than or equal to one element (i.e., `len(arr)` is less than or equal to `1`), it is already sorted and can be returned as is.
- **Pivot Selection**: The pivot is chosen as the middle element of the array.
- **Partitioning**: The array is partitioned into three parts: elements less than the pivot (`left`), elements equal to the pivot (`middle`), and elements greater than the pivot (`right`). These partitions are then recursively sorted.
- **Recursive Sorting**: The subarrays are sorted recursively using `quick_sort`.
This approach ensures that each recursive call reduces the problem size by half until it reaches a base case.
</pre>
</details>
This streaming feature lets you process the output as it is generated, rather than waiting for the entire response before processing it.
For more information on the available options for each pipeline, refer to the [API Reference](./api/pipelines).
If you would like more control over the inference process, you can use the [`AutoModel`](./api/models), [`AutoTokenizer`](./api/tokenizers), or [`AutoProcessor`](./api/processors) classes instead.

src/generation/configuration_utils.js

Lines changed: 7 additions & 0 deletions
@@ -259,6 +259,13 @@ export class GenerationConfig {
     */
    suppress_tokens = null;

    /**
     * A streamer that will be used to stream the generation.
     * @type {import('./streamers.js').TextStreamer}
     * @default null
     */
    streamer = null;

    /**
     * A list of tokens that will be suppressed at the beginning of the generation.
     * The `SuppressBeginTokens` logit processor will set their log probs to `-inf` so that they are not sampled.
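
Since `streamer` is now a documented field on `GenerationConfig`, it travels with the other generation options. A minimal sketch, reusing the pipeline and streamer from the docs example above:

```js
// A minimal sketch: `streamer` is passed alongside the other generation
// options, which the library merges into its generation config.
// Assumes `generator`, `messages`, and `streamer` from the docs example.
const output = await generator(messages, {
  max_new_tokens: 256,
  do_sample: false,
  streamer, // now a typed TextStreamer field on GenerationConfig
});
```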

src/generation/streamers.js

Lines changed: 7 additions & 2 deletions
@@ -34,7 +34,12 @@ const stdout_write = apis.IS_PROCESS_AVAILABLE
export class TextStreamer extends BaseStreamer {
    /**
     *
-    * @param {import('../tokenizers.js').PreTrainedTokenizer} tokenizer
+    * @param {import('../tokenizers.js').PreTrainedTokenizer} tokenizer
+    * @param {Object} options
+    * @param {boolean} [options.skip_prompt=false] Whether to skip the prompt tokens
+    * @param {function(string): void} [options.callback_function=null] Function to call when a piece of text is ready to display
+    * @param {function(bigint[]): void} [options.token_callback_function=null] Function to call when a new token is generated
+    * @param {Object} [options.decode_kwargs={}] Additional keyword arguments to pass to the tokenizer's decode method
     */
    constructor(tokenizer, {
        skip_prompt = false,
@@ -143,7 +148,7 @@ export class WhisperTextStreamer extends TextStreamer {
     * @param {Object} options
     * @param {boolean} [options.skip_prompt=false] Whether to skip the prompt tokens
     * @param {function(string): void} [options.callback_function=null] Function to call when a piece of text is ready to display
-    * @param {function(string): void} [options.token_callback_function=null] Function to call when a new token is generated
+    * @param {function(bigint[]): void} [options.token_callback_function=null] Function to call when a new token is generated
     * @param {function(number): void} [options.on_chunk_start=null] Function to call when a new chunk starts
     * @param {function(number): void} [options.on_chunk_end=null] Function to call when a chunk ends
     * @param {function(): void} [options.on_finalize=null] Function to call when the stream is finalized
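
The corrected signature means `token_callback_function` receives the newly generated token ids as a `bigint[]` rather than a string; the same applies to `WhisperTextStreamer`, which extends `TextStreamer`. A minimal sketch of both callbacks, assuming the `generator` pipeline from the docs example and a Node.js environment for `process.stdout`:

```js
// A minimal sketch: token_callback_function receives new token ids as a
// bigint[] (per the corrected JSDoc above); callback_function receives
// decoded text. Assumes `generator` from the docs example and Node.js.
const tokenStreamer = new TextStreamer(generator.tokenizer, {
  skip_prompt: true,
  token_callback_function: (tokens) => {
    // tokens is a bigint[], so convert for display
    console.log("new token ids:", tokens.map((t) => Number(t)));
  },
  callback_function: (text) => process.stdout.write(text),
});
```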

0 commit comments
