
Commit b033ad1

vvmnnnkv and coyotte508 authored
Add token streaming for text generation (#130)
Co-authored-by: coyotte508 <[email protected]>
1 parent 4edddd2 commit b033ad1

File tree

8 files changed: 849 additions & 7 deletions

packages/inference/.eslintignore

Lines changed: 2 additions & 1 deletion
@@ -1,2 +1,3 @@
 dist
-tapes.json
+tapes.json
+src/vendor

packages/inference/.prettierignore

Lines changed: 2 additions & 1 deletion
@@ -2,4 +2,5 @@ pnpm-lock.yaml
 # In order to avoid code samples to have tabs, they don't display well on npm
 README.md
 dist
-tapes.json
+tapes.json
+src/vendor

packages/inference/README.md

Lines changed: 7 additions & 0 deletions
@@ -76,6 +76,13 @@ await hf.textGeneration({
   inputs: 'The answer to the universe is'
 })
 
+for await (const output of hf.textGenerationStream({
+  model: "google/flan-t5-xxl",
+  inputs: 'repeat "one two three four"'
+})) {
+  console.log(output.token.text, output.generated_text);
+}
+
 await hf.tokenClassification({
   model: 'dbmdz/bert-large-cased-finetuned-conll03-english',
   inputs: 'My name is Sarah Jessica Parker but you can call me Jessica'

packages/inference/src/HfInference.ts

Lines changed: 182 additions & 5 deletions
@@ -1,4 +1,8 @@
 import { toArray } from "./utils/to-array";
+import type { EventSourceMessage } from "./vendor/fetch-event-source/parse";
+import { getLines, getMessages } from "./vendor/fetch-event-source/parse";
+
+const HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co/models/";
 
 export interface Options {
   /**
@@ -223,6 +227,86 @@ export interface TextGenerationReturn {
   generated_text: string;
 }
 
+export interface TextGenerationStreamToken {
+  /** Token ID from the model tokenizer */
+  id: number;
+  /** Token text */
+  text: string;
+  /** Logprob */
+  logprob: number;
+  /**
+   * Is the token a special token
+   * Can be used to ignore tokens when concatenating
+   */
+  special: boolean;
+}
+
+export interface TextGenerationStreamPrefillToken {
+  /** Token ID from the model tokenizer */
+  id: number;
+  /** Token text */
+  text: string;
+  /**
+   * Logprob
+   * Optional since the logprob of the first token cannot be computed
+   */
+  logprob?: number;
+}
+
+export interface TextGenerationStreamBestOfSequence {
+  /** Generated text */
+  generated_text: string;
+  /** Generation finish reason */
+  finish_reason: TextGenerationStreamFinishReason;
+  /** Number of generated tokens */
+  generated_tokens: number;
+  /** Sampling seed if sampling was activated */
+  seed?: number;
+  /** Prompt tokens */
+  prefill: TextGenerationStreamPrefillToken[];
+  /** Generated tokens */
+  tokens: TextGenerationStreamToken[];
+}
+
+export enum TextGenerationStreamFinishReason {
+  /** number of generated tokens == `max_new_tokens` */
+  Length = "length",
+  /** the model generated its end of sequence token */
+  EndOfSequenceToken = "eos_token",
+  /** the model generated a text included in `stop_sequences` */
+  StopSequence = "stop_sequence",
+}
+
+export interface TextGenerationStreamDetails {
+  /** Generation finish reason */
+  finish_reason: TextGenerationStreamFinishReason;
+  /** Number of generated tokens */
+  generated_tokens: number;
+  /** Sampling seed if sampling was activated */
+  seed?: number;
+  /** Prompt tokens */
+  prefill: TextGenerationStreamPrefillToken[];
+  /** Generated tokens */
+  tokens: TextGenerationStreamToken[];
+  /** Additional sequences when using the `best_of` parameter */
+  best_of_sequences?: TextGenerationStreamBestOfSequence[];
+}
+
+export interface TextGenerationStreamReturn {
+  /** Generated token, one at a time */
+  token: TextGenerationStreamToken;
+  /**
+   * Complete generated text
+   * Only available when the generation is finished
+   */
+  generated_text?: string;
+  /**
+   * Generation details
+   * Only available when the generation is finished
+   */
+  details?: TextGenerationStreamDetails;
+}
+
 export type TokenClassificationArgs = Args & {
   /**
    * A string to be classified
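The interfaces above mirror the messages emitted while streaming: every message carries a `token`, while `generated_text` and `details` are only populated on the final one. A minimal consumption sketch (assumes an instantiated `HfInference` client named `hf`, an async context, and an illustrative model and prompt):

let fullText = "";
for await (const output of hf.textGenerationStream({
  model: "google/flan-t5-xxl",
  inputs: 'repeat "one two three four"',
})) {
  // Skip special tokens (e.g. end-of-sequence markers) when concatenating.
  if (!output.token.special) {
    fullText += output.token.text;
  }
  // `details` is only present on the last message of the stream.
  if (output.details) {
    console.log(output.details.finish_reason, output.details.generated_tokens);
  }
}
console.log(fullText);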
@@ -615,6 +699,16 @@ export class HfInference {
     return res?.[0];
   }
 
+  /**
+   * Use to continue text from a prompt. Same as `textGeneration`, but returns a generator that can be read one token at a time.
+   */
+  public async *textGenerationStream(
+    args: TextGenerationArgs,
+    options?: Options
+  ): AsyncGenerator<TextGenerationStreamReturn> {
+    yield* this.streamingRequest<TextGenerationStreamReturn>(args, options);
+  }
+
   /**
    * Usually used for sentence parsing, either grammatical, or Named Entity Recognition (NER) to understand keywords contained within text. Recommended model: dbmdz/bert-large-cased-finetuned-conll03-english
    */
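Since `textGenerationStream` returns an async generator, it can also be driven manually with `next()` instead of `for await`; ending it early runs the generator's `finally` block, which releases the reader lock. A small sketch (assumes the same hypothetical `hf` client; the model, prompt, and 20-token cutoff are illustrative):

const stream = hf.textGenerationStream({
  model: "google/flan-t5-xxl",
  inputs: "The answer to the universe is",
});
// Pull messages one at a time and stop after 20 tokens.
for (let i = 0; i < 20; i++) {
  const { done, value } = await stream.next();
  if (done) break;
  console.log(value.token.text);
}
// End the generator early so its finally block runs and the reader lock is released.
await stream.return(undefined);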
@@ -834,15 +928,21 @@ export class HfInference {
     return res;
   }
 
-  public async request<T>(
-    args: Args & { data?: Blob | ArrayBuffer },
+  /**
+   * Helper that prepares request arguments
+   */
+  private makeRequestOptions(
+    args: Args & {
+      data?: Blob | ArrayBuffer;
+      stream?: boolean;
+    },
     options?: Options & {
       binary?: boolean;
       blob?: boolean;
       /** For internal HF use, which is why it's not exposed in {@link Options} */
       includeCredentials?: boolean;
     }
-  ): Promise<T> {
+  ) {
     const mergedOptions = { ...this.defaultOptions, ...options };
     const { model, ...otherArgs } = args;
 
@@ -867,7 +967,8 @@
       }
     }
 
-    const response = await fetch(`https://api-inference.huggingface.co/models/${model}`, {
+    const url = `${HF_INFERENCE_API_BASE_URL}${model}`;
+    const info: RequestInit = {
       headers,
       method: "POST",
       body: options?.binary
@@ -877,7 +978,22 @@
           options: mergedOptions,
         }),
       credentials: options?.includeCredentials ? "include" : "same-origin",
-    });
+    };
+
+    return { url, info, mergedOptions };
+  }
+
+  public async request<T>(
+    args: Args & { data?: Blob | ArrayBuffer },
+    options?: Options & {
+      binary?: boolean;
+      blob?: boolean;
+      /** For internal HF use, which is why it's not exposed in {@link Options} */
+      includeCredentials?: boolean;
+    }
+  ): Promise<T> {
+    const { url, info, mergedOptions } = this.makeRequestOptions(args, options);
+    const response = await fetch(url, info);
 
     if (mergedOptions.retry_on_error !== false && response.status === 503 && !mergedOptions.wait_for_model) {
       return this.request(args, {
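Both `request` and the new `streamingRequest` now go through `makeRequestOptions`. Based on the visible branches, the non-binary path appears to spread the remaining `args` (everything except `model`) into a JSON body, so a streaming call would POST to `HF_INFERENCE_API_BASE_URL + model` with `stream: true` in the payload. An illustrative sketch, not part of the diff, with assumed option values:

// URL built from the new constant, e.g. for "google/flan-t5-xxl":
//   https://api-inference.huggingface.co/models/google/flan-t5-xxl
const body = JSON.stringify({
  inputs: "The answer to the universe is",
  stream: true, // injected by streamingRequest via { ...args, stream: true }
  options: { wait_for_model: true }, // stands in for mergedOptions; concrete defaults are assumptions
});
console.log(body);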
@@ -899,4 +1015,65 @@
     }
     return output;
   }
+
+  /**
+   * Make request that uses server-sent events and returns response as a generator
+   */
+  public async *streamingRequest<T>(
+    args: Args & { data?: Blob | ArrayBuffer },
+    options?: Options & {
+      binary?: boolean;
+      blob?: boolean;
+      /** For internal HF use, which is why it's not exposed in {@link Options} */
+      includeCredentials?: boolean;
+    }
+  ): AsyncGenerator<T> {
+    const { url, info, mergedOptions } = this.makeRequestOptions({ ...args, stream: true }, options);
+    const response = await fetch(url, info);
+
+    if (mergedOptions.retry_on_error !== false && response.status === 503 && !mergedOptions.wait_for_model) {
+      return yield* this.streamingRequest<T>(args, {
+        ...mergedOptions,
+        wait_for_model: true,
+      });
+    }
+    if (!response.ok) {
+      throw new Error(`Server response contains error: ${response.status}`);
+    }
+    if (response.headers.get("content-type") !== "text/event-stream") {
+      throw new Error(`Server does not support event stream content type`);
+    }
+
+    const reader = response.body.getReader();
+    const events: EventSourceMessage[] = [];
+
+    const onEvent = (event: EventSourceMessage) => {
+      // accumulate events in array
+      events.push(event);
+    };
+
+    const onChunk = getLines(
+      getMessages(
+        () => {},
+        () => {},
+        onEvent
+      )
+    );
+
+    try {
+      while (true) {
+        const { done, value } = await reader.read();
+        if (done) return;
+        onChunk(value);
+        while (events.length > 0) {
+          const event = events.shift();
+          if (event.data.length > 0) {
+            yield JSON.parse(event.data) as T;
+          }
+        }
+      }
+    } finally {
+      reader.releaseLock();
+    }
+  }
 }
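`streamingRequest` parses the `text/event-stream` body with the vendored `getLines`/`getMessages` helpers. A standalone sketch of that parsing pipeline, assuming the vendored module behaves like the upstream @microsoft/fetch-event-source `parse` module it is presumably copied from (the sample payload is illustrative):

import { getLines, getMessages } from "./vendor/fetch-event-source/parse";

const payloads: string[] = [];
const onChunk = getLines(
  getMessages(
    () => {}, // onId callback (unused here)
    () => {}, // onRetry callback (unused here)
    (msg) => payloads.push(msg.data) // called once per complete SSE event
  )
);

// Feed raw bytes as they would come from reader.read(); a blank line terminates an event.
onChunk(new TextEncoder().encode('data: {"token":{"id":1,"text":"one","logprob":0,"special":false}}\n\n'));
console.log(payloads); // ['{"token":{"id":1,"text":"one","logprob":0,"special":false}}']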
