
Commit bea807a

SBrandeis, Mishig, and gary149 authored
✨ [Widgets] Enable streaming in the conversational widget (#486)
Linked to #360 and #410. Should unlock the conversational widget on Mistral, if I'm not mistaken.

TL;DR:
- Leverage inference types from `@huggingface/tasks` to type the input and output of the inference client
- Use the inference client to call the serverless Inference API
- Use the streaming API when the model supports it

Co-authored-by: Mishig <[email protected]>
Co-authored-by: Victor Mustar <[email protected]>
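As a rough illustration of the last bullet, here is a minimal standalone sketch of stream-with-fallback logic using the `@huggingface/inference` client. The token, model id, prompt, and the `supportsStreaming` flag are placeholders for illustration, not the widget's actual wiring:

```ts
import { HfInference } from "@huggingface/inference";

// Placeholders for illustration only; the widget derives these from its props.
const hf = new HfInference("hf_..."); // your access token
const model = "mistralai/Mistral-7B-Instruct-v0.2";
const inputs = "[INST] Why is streaming nicer for chat UIs? [/INST]";

// Assumption: `supportsStreaming` stands in for however the caller decides
// that the backend can stream for this model.
const supportsStreaming = true;

if (supportsStreaming) {
	// Append tokens to the output as they arrive.
	for await (const output of hf.textGenerationStream({ model, inputs, parameters: { max_new_tokens: 256 } })) {
		if (!output.token.special) {
			process.stdout.write(output.token.text);
		}
	}
} else {
	// No streaming support: request the full completion in one call.
	const { generated_text } = await hf.textGeneration({ model, inputs, parameters: { max_new_tokens: 256 } });
	console.log(generated_text);
}
```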
1 parent bab7c35 commit bea807a

File tree

12 files changed: +255 −159 lines

packages/inference/src/tasks/nlp/textGeneration.ts

Lines changed: 5 additions & 61 deletions
@@ -1,71 +1,15 @@
+import type { TextGenerationInput, TextGenerationOutput } from "@huggingface/tasks/src/tasks/text-generation/inference";
 import { InferenceOutputError } from "../../lib/InferenceOutputError";
 import type { BaseArgs, Options } from "../../types";
 import { request } from "../custom/request";
 
-export type TextGenerationArgs = BaseArgs & {
-	/**
-	 * A string to be generated from
-	 */
-	inputs: string;
-	parameters?: {
-		/**
-		 * (Optional: True). Bool. Whether or not to use sampling, use greedy decoding otherwise.
-		 */
-		do_sample?: boolean;
-		/**
-		 * (Default: None). Int (0-250). The amount of new tokens to be generated, this does not include the input length it is a estimate of the size of generated text you want. Each new tokens slows down the request, so look for balance between response times and length of text generated.
-		 */
-		max_new_tokens?: number;
-		/**
-		 * (Default: None). Float (0-120.0). The amount of time in seconds that the query should take maximum. Network can cause some overhead so it will be a soft limit. Use that in combination with max_new_tokens for best results.
-		 */
-		max_time?: number;
-		/**
-		 * (Default: 1). Integer. The number of proposition you want to be returned.
-		 */
-		num_return_sequences?: number;
-		/**
-		 * (Default: None). Float (0.0-100.0). The more a token is used within generation the more it is penalized to not be picked in successive generation passes.
-		 */
-		repetition_penalty?: number;
-		/**
-		 * (Default: True). Bool. If set to False, the return results will not contain the original query making it easier for prompting.
-		 */
-		return_full_text?: boolean;
-		/**
-		 * (Default: 1.0). Float (0.0-100.0). The temperature of the sampling operation. 1 means regular sampling, 0 means always take the highest score, 100.0 is getting closer to uniform probability.
-		 */
-		temperature?: number;
-		/**
-		 * (Default: None). Integer to define the top tokens considered within the sample operation to create new text.
-		 */
-		top_k?: number;
-		/**
-		 * (Default: None). Float to define the tokens that are within the sample operation of text generation. Add tokens in the sample for more probable to least probable until the sum of the probabilities is greater than top_p.
-		 */
-		top_p?: number;
-		/**
-		 * (Default: None). Integer. The maximum number of tokens from the input.
-		 */
-		truncate?: number;
-		/**
-		 * (Default: []) List of strings. The model will stop generating text when one of the strings in the list is generated.
-		 * **/
-		stop_sequences?: string[];
-	};
-};
-
-export interface TextGenerationOutput {
-	/**
-	 * The continuated string
-	 */
-	generated_text: string;
-}
-
 /**
  * Use to continue text from a prompt. This is a very generic task. Recommended model: gpt2 (it’s a simple model, but fun to play with).
  */
-export async function textGeneration(args: TextGenerationArgs, options?: Options): Promise<TextGenerationOutput> {
+export async function textGeneration(
+	args: BaseArgs & TextGenerationInput,
+	options?: Options
+): Promise<TextGenerationOutput> {
 	const res = await request<TextGenerationOutput[]>(args, {
 		...options,
 		taskHint: "text-generation",
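After this change, callers pass the spec-generated `TextGenerationInput` alongside `BaseArgs` and get a typed `TextGenerationOutput` back. A minimal sketch of the new call shape (the token and model are examples, not defaults):

```ts
import { textGeneration } from "@huggingface/inference";

// `args` is BaseArgs & TextGenerationInput: `inputs` and `parameters`
// now come from the generated spec types instead of hand-written ones.
const output = await textGeneration({
	accessToken: "hf_...", // assumption: your own token
	model: "gpt2",
	inputs: "The answer to the universe is",
	parameters: { max_new_tokens: 20, temperature: 0.7 },
});

console.log(output.generated_text); // TextGenerationOutput is typed, too
```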

packages/inference/src/tasks/nlp/textGenerationStream.ts

Lines changed: 4 additions & 3 deletions
@@ -1,6 +1,7 @@
-import type { Options } from "../../types";
+import type { BaseArgs, Options } from "../../types";
 import { streamingRequest } from "../custom/streamingRequest";
-import type { TextGenerationArgs } from "./textGeneration";
+
+import type { TextGenerationInput } from "@huggingface/tasks/src/tasks/text-generation/inference";
 
 export interface TextGenerationStreamToken {
 	/** Token ID from the model tokenizer */
@@ -85,7 +86,7 @@ export interface TextGenerationStreamOutput {
  * Use to continue text from a prompt. Same as `textGeneration` but returns generator that can be read one token at a time
  */
 export async function* textGenerationStream(
-	args: TextGenerationArgs,
+	args: BaseArgs & TextGenerationInput,
 	options?: Options
 ): AsyncGenerator<TextGenerationStreamOutput> {
 	yield* streamingRequest<TextGenerationStreamOutput>(args, {
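The streaming variant takes the same `BaseArgs & TextGenerationInput` arguments but yields `TextGenerationStreamOutput` chunks as an async generator. A minimal consumption sketch, again with a placeholder token and model:

```ts
import { textGenerationStream } from "@huggingface/inference";

let text = "";
for await (const output of textGenerationStream({
	accessToken: "hf_...", // assumption: your own token
	model: "gpt2",
	inputs: "Once upon a time,",
	parameters: { max_new_tokens: 50 },
})) {
	// Each yielded chunk carries the newly generated token; skip special
	// tokens (e.g. end-of-sequence) when accumulating display text.
	if (!output.token.special) {
		text += output.token.text;
	}
}
console.log(text);
```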

packages/tasks/src/tasks/index.ts

Lines changed: 58 additions & 0 deletions
@@ -36,6 +36,64 @@ import zeroShotClassification from "./zero-shot-classification/data";
 import zeroShotImageClassification from "./zero-shot-image-classification/data";
 import zeroShotObjectDetection from "./zero-shot-object-detection/data";
 
+export type * from "./audio-classification/inference";
+export type * from "./automatic-speech-recognition/inference";
+export type * from "./document-question-answering/inference";
+export type * from "./feature-extraction/inference";
+export type * from "./fill-mask/inference";
+export type {
+	ImageClassificationInput,
+	ImageClassificationOutput,
+	ImageClassificationOutputElement,
+	ImageClassificationParameters,
+} from "./image-classification/inference";
+export type * from "./image-to-image/inference";
+export type { ImageToTextInput, ImageToTextOutput, ImageToTextParameters } from "./image-to-text/inference";
+export type * from "./image-segmentation/inference";
+export type * from "./object-detection/inference";
+export type * from "./depth-estimation/inference";
+export type * from "./question-answering/inference";
+export type * from "./sentence-similarity/inference";
+export type * from "./summarization/inference";
+export type * from "./table-question-answering/inference";
+export type { TextToImageInput, TextToImageOutput, TextToImageParameters } from "./text-to-image/inference";
+export type { TextToAudioParameters, TextToSpeechInput, TextToSpeechOutput } from "./text-to-speech/inference";
+export type * from "./token-classification/inference";
+export type {
+	Text2TextGenerationParameters,
+	Text2TextGenerationTruncationStrategy,
+	TranslationInput,
+	TranslationOutput,
+} from "./translation/inference";
+export type {
+	ClassificationOutputTransform,
+	TextClassificationInput,
+	TextClassificationOutput,
+	TextClassificationOutputElement,
+	TextClassificationParameters,
+} from "./text-classification/inference";
+export type {
+	FinishReason,
+	PrefillToken,
+	TextGenerationInput,
+	TextGenerationOutput,
+	TextGenerationOutputDetails,
+	TextGenerationParameters,
+	TextGenerationSequenceDetails,
+	Token,
+} from "./text-generation/inference";
+export type * from "./video-classification/inference";
+export type * from "./visual-question-answering/inference";
+export type * from "./zero-shot-classification/inference";
+export type * from "./zero-shot-image-classification/inference";
+export type {
+	BoundingBox,
+	ZeroShotObjectDetectionInput,
+	ZeroShotObjectDetectionInputData,
+	ZeroShotObjectDetectionOutput,
+	ZeroShotObjectDetectionOutputElement,
+} from "./zero-shot-object-detection/inference";
+
 import type { ModelLibraryKey } from "../model-libraries";
 
 /**
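With these re-exports in place, downstream packages can pull the spec types straight from the package entry point rather than from deep source paths, assuming the package's root index re-exports `./tasks` as usual in this repo. A sketch of how a consumer (e.g. the widgets package) might use them; both helpers are hypothetical:

```ts
import type { TextGenerationInput, TextGenerationOutput } from "@huggingface/tasks";

// Hypothetical helper: build a typed request payload from a prompt.
function buildRequest(prompt: string): TextGenerationInput {
	return { inputs: prompt, parameters: { max_new_tokens: 100 } };
}

// Hypothetical helper: pull the reply text out of a typed response.
function extractReply(output: TextGenerationOutput): string {
	return output.generated_text;
}
```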

packages/widgets/package.json

Lines changed: 3 additions & 2 deletions
@@ -46,7 +46,8 @@
 	],
 	"dependencies": {
 		"@huggingface/tasks": "workspace:^",
-		"@huggingface/jinja": "workspace:^"
+		"@huggingface/jinja": "workspace:^",
+		"@huggingface/inference": "workspace:^"
 	},
 	"peerDependencies": {
 		"svelte": "^3.59.2"
@@ -69,7 +70,7 @@
 		"svelte": "^3.59.2",
 		"svelte-check": "^3.6.0",
 		"svelte-preprocess": "^5.1.1",
-		"tailwindcss": "^3.3.5",
+		"tailwindcss": "^3.4.1",
 		"tslib": "^2.4.1",
 		"vite": "^4.5.0",
 		"vite-plugin-dts": "^3.6.4"
