 import { toArray } from "./utils/to-array";
+import type { EventSourceMessage } from "./vendor/fetch-event-source/parse";
+import { getLines, getMessages } from "./vendor/fetch-event-source/parse";
+
+const HF_INFERENCE_API_BASE_URL = "https://api-inference.huggingface.co/models/";
 
 export interface Options {
 	/**
@@ -223,6 +227,86 @@ export interface TextGenerationReturn {
 	generated_text: string;
 }
 
+export interface TextGenerationStreamToken {
+	/** Token ID from the model tokenizer */
+	id: number;
+	/** Token text */
+	text: string;
+	/** Logprob */
+	logprob: number;
+	/**
+	 * Is the token a special token
+	 * Can be used to ignore tokens when concatenating
+	 */
+	special: boolean;
+}
+
+export interface TextGenerationStreamPrefillToken {
+	/** Token ID from the model tokenizer */
+	id: number;
+	/** Token text */
+	text: string;
+	/**
+	 * Logprob
+	 * Optional since the logprob of the first token cannot be computed
+	 */
+	logprob?: number;
+}
+
+export interface TextGenerationStreamBestOfSequence {
+	/** Generated text */
+	generated_text: string;
+	/** Generation finish reason */
+	finish_reason: TextGenerationStreamFinishReason;
+	/** Number of generated tokens */
+	generated_tokens: number;
+	/** Sampling seed if sampling was activated */
+	seed?: number;
+	/** Prompt tokens */
+	prefill: TextGenerationStreamPrefillToken[];
+	/** Generated tokens */
+	tokens: TextGenerationStreamToken[];
+}
+
+export enum TextGenerationStreamFinishReason {
+	/** number of generated tokens == `max_new_tokens` */
+	Length = "length",
+	/** the model generated its end of sequence token */
+	EndOfSequenceToken = "eos_token",
+	/** the model generated a text included in `stop_sequences` */
+	StopSequence = "stop_sequence",
+}
+
+export interface TextGenerationStreamDetails {
+	/** Generation finish reason */
+	finish_reason: TextGenerationStreamFinishReason;
+	/** Number of generated tokens */
+	generated_tokens: number;
+	/** Sampling seed if sampling was activated */
+	seed?: number;
+	/** Prompt tokens */
+	prefill: TextGenerationStreamPrefillToken[];
+	/** Generated tokens */
+	tokens: TextGenerationStreamToken[];
+	/** Additional sequences when using the `best_of` parameter */
+	best_of_sequences?: TextGenerationStreamBestOfSequence[];
+}
+
+export interface TextGenerationStreamReturn {
+	/** Generated token, one at a time */
+	token: TextGenerationStreamToken;
+	/**
+	 * Complete generated text
+	 * Only available when the generation is finished
+	 */
+	generated_text?: string;
+	/**
+	 * Generation details
+	 * Only available when the generation is finished
+	 */
+	details?: TextGenerationStreamDetails;
+}
+
 export type TokenClassificationArgs = Args & {
 	/**
 	 * A string to be classified
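
For orientation (not part of the diff): with the interfaces added above, every streamed event carries a `token`, while `generated_text` and `details` are only filled on the last event of a generation. A minimal sketch of such values, with hypothetical ids, texts and logprobs:

const intermediateEvent: TextGenerationStreamReturn = {
	token: { id: 1234, text: " mountain", logprob: -0.42, special: false },
};

const finalEvent: TextGenerationStreamReturn = {
	token: { id: 2, text: "</s>", logprob: -0.01, special: true },
	generated_text: "I went to the mountain",
	details: {
		finish_reason: TextGenerationStreamFinishReason.EndOfSequenceToken,
		generated_tokens: 2,
		seed: 42,
		prefill: [{ id: 5, text: "I went to the" }],
		tokens: [
			{ id: 1234, text: " mountain", logprob: -0.42, special: false },
			{ id: 2, text: "</s>", logprob: -0.01, special: true },
		],
	},
};
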
@@ -615,6 +699,16 @@ export class HfInference {
 		return res?.[0];
 	}
 
+	/**
+	 * Use to continue text from a prompt. Same as `textGeneration`, but returns a generator that can be read one token at a time.
+	 */
+	public async *textGenerationStream(
+		args: TextGenerationArgs,
+		options?: Options
+	): AsyncGenerator<TextGenerationStreamReturn> {
+		yield* this.streamingRequest<TextGenerationStreamReturn>(args, options);
+	}
+
 	/**
 	 * Usually used for sentence parsing, either grammatical, or Named Entity Recognition (NER) to understand keywords contained within text. Recommended model: dbmdz/bert-large-cased-finetuned-conll03-english
 	 */
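
A minimal consumption sketch for the `textGenerationStream` method added above (not part of the diff; the access token, model name and prompt are placeholders, and it assumes an `HfInference` instance constructed with an API key as elsewhere in this file):

const hf = new HfInference("hf_xxx"); // hypothetical access token

let story = "";
for await (const output of hf.textGenerationStream({
	model: "bigscience/bloom", // illustrative text-generation model
	inputs: "Once upon a time,",
})) {
	// each yielded value holds exactly one newly generated token
	if (!output.token.special) {
		story += output.token.text;
	}
	// generated_text and details only appear on the final event
	if (output.details) {
		console.log("finished:", output.details.finish_reason, output.details.generated_tokens);
	}
}
console.log(story);
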
@@ -834,15 +928,21 @@ export class HfInference {
 		return res;
 	}
 
-	public async request<T>(
-		args: Args & { data?: Blob | ArrayBuffer },
+	/**
+	 * Helper that prepares request arguments
+	 */
+	private makeRequestOptions(
+		args: Args & {
+			data?: Blob | ArrayBuffer;
+			stream?: boolean;
+		},
 		options?: Options & {
 			binary?: boolean;
 			blob?: boolean;
 			/** For internal HF use, which is why it's not exposed in {@link Options} */
 			includeCredentials?: boolean;
 		}
-	): Promise<T> {
+	) {
 		const mergedOptions = { ...this.defaultOptions, ...options };
 		const { model, ...otherArgs } = args;
@@ -867,7 +967,8 @@ export class HfInference {
 			}
 		}
 
-		const response = await fetch(`https://api-inference.huggingface.co/models/${model}`, {
+		const url = `${HF_INFERENCE_API_BASE_URL}${model}`;
+		const info: RequestInit = {
 			headers,
 			method: "POST",
 			body: options?.binary
@@ -877,7 +978,22 @@ export class HfInference {
 					options: mergedOptions,
 				}),
 			credentials: options?.includeCredentials ? "include" : "same-origin",
-		});
+		};
+
+		return { url, info, mergedOptions };
+	}
+
+	public async request<T>(
+		args: Args & { data?: Blob | ArrayBuffer },
+		options?: Options & {
+			binary?: boolean;
+			blob?: boolean;
+			/** For internal HF use, which is why it's not exposed in {@link Options} */
+			includeCredentials?: boolean;
+		}
+	): Promise<T> {
+		const { url, info, mergedOptions } = this.makeRequestOptions(args, options);
+		const response = await fetch(url, info);
 
 		if (mergedOptions.retry_on_error !== false && response.status === 503 && !mergedOptions.wait_for_model) {
 			return this.request(args, {
@@ -899,4 +1015,65 @@ export class HfInference {
 		}
 		return output;
 	}
+
+	/**
+	 * Make request that uses server-sent events and returns response as a generator
+	 */
+	public async *streamingRequest<T>(
+		args: Args & { data?: Blob | ArrayBuffer },
+		options?: Options & {
+			binary?: boolean;
+			blob?: boolean;
+			/** For internal HF use, which is why it's not exposed in {@link Options} */
+			includeCredentials?: boolean;
+		}
+	): AsyncGenerator<T> {
+		const { url, info, mergedOptions } = this.makeRequestOptions({ ...args, stream: true }, options);
+		const response = await fetch(url, info);
+
+		if (mergedOptions.retry_on_error !== false && response.status === 503 && !mergedOptions.wait_for_model) {
+			return this.streamingRequest(args, {
+				...mergedOptions,
+				wait_for_model: true,
+			});
+		}
+		if (!response.ok) {
+			throw new Error(`Server response contains error: ${response.status}`);
+		}
+		if (response.headers.get("content-type") !== "text/event-stream") {
+			throw new Error(`Server does not support event stream content type`);
+		}
+
+		const reader = response.body.getReader();
+		const events: EventSourceMessage[] = [];
+
+		const onEvent = (event: EventSourceMessage) => {
+			// accumulate events in array
+			events.push(event);
+		};
+
+		const onChunk = getLines(
+			getMessages(
+				() => {},
+				() => {},
+				onEvent
+			)
+		);
+
+		try {
+			while (true) {
+				const { done, value } = await reader.read();
+				if (done) return;
+				onChunk(value);
+				while (events.length > 0) {
+					const event = events.shift();
+					if (event.data.length > 0) {
+						yield JSON.parse(event.data) as T;
+					}
+				}
+			}
+		} finally {
+			reader.releaseLock();
+		}
+	}
 }
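
A standalone sketch of the vendored SSE parsing used by `streamingRequest` (the wiring mirrors the code above; the payloads are hypothetical): `getLines` splits raw bytes into SSE lines, and `getMessages` groups those lines into `EventSourceMessage` objects, so one network chunk can yield zero or more events.

import type { EventSourceMessage } from "./vendor/fetch-event-source/parse";
import { getLines, getMessages } from "./vendor/fetch-event-source/parse";

const events: EventSourceMessage[] = [];
const onChunk = getLines(
	getMessages(
		() => {}, // `id:` fields ignored, as in streamingRequest
		() => {}, // `retry:` fields ignored
		(event) => events.push(event)
	)
);

// one network chunk containing two complete server-sent events
const chunk = new TextEncoder().encode(
	'data: {"token":{"id":7,"text":"Hello","logprob":-0.1,"special":false}}\n\n' +
		'data: {"token":{"id":8,"text":" world","logprob":-0.2,"special":false}}\n\n'
);
onChunk(chunk);

for (const event of events) {
	if (event.data.length > 0) {
		console.log(JSON.parse(event.data)); // TextGenerationStreamReturn-shaped objects
	}
}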