@@ -35,6 +35,9 @@ export interface Args {
 	model?: string;
 }
 
+export type RequestArgs = Args &
+	({ data?: Blob | ArrayBuffer } | { inputs: unknown }) & { parameters?: Record<string, unknown> };
+
 export type FillMaskArgs = Args & {
 	inputs: string;
 };
@@ -909,10 +912,7 @@ export class HfInference {
 		args: AutomaticSpeechRecognitionArgs,
 		options?: Options
 	): Promise<AutomaticSpeechRecognitionReturn> {
-		const res = await this.request<AutomaticSpeechRecognitionReturn>(args, {
-			...options,
-			binary: true,
-		});
+		const res = await this.request<AutomaticSpeechRecognitionReturn>(args, options);
 		const isValidOutput = typeof res.text === "string";
 		if (!isValidOutput) {
 			throw new TypeError("Invalid inference output: output must be of type <text: string>");
@@ -928,10 +928,7 @@ export class HfInference {
 		args: AudioClassificationArgs,
 		options?: Options
 	): Promise<AudioClassificationReturn> {
-		const res = await this.request<AudioClassificationReturn>(args, {
-			...options,
-			binary: true,
-		});
+		const res = await this.request<AudioClassificationReturn>(args, options);
 		const isValidOutput =
 			Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
 		if (!isValidOutput) {
@@ -948,10 +945,7 @@ export class HfInference {
 		args: ImageClassificationArgs,
 		options?: Options
 	): Promise<ImageClassificationReturn> {
-		const res = await this.request<ImageClassificationReturn>(args, {
-			...options,
-			binary: true,
-		});
+		const res = await this.request<ImageClassificationReturn>(args, options);
 		const isValidOutput =
 			Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
 		if (!isValidOutput) {
@@ -965,10 +959,7 @@ export class HfInference {
 	 * Recommended model: facebook/detr-resnet-50
 	 */
 	public async objectDetection(args: ObjectDetectionArgs, options?: Options): Promise<ObjectDetectionReturn> {
-		const res = await this.request<ObjectDetectionReturn>(args, {
-			...options,
-			binary: true,
-		});
+		const res = await this.request<ObjectDetectionReturn>(args, options);
 		const isValidOutput =
 			Array.isArray(res) &&
 			res.every(
@@ -993,10 +984,7 @@ export class HfInference {
 	 * Recommended model: facebook/detr-resnet-50-panoptic
 	 */
 	public async imageSegmentation(args: ImageSegmentationArgs, options?: Options): Promise<ImageSegmentationReturn> {
-		const res = await this.request<ImageSegmentationReturn>(args, {
-			...options,
-			binary: true,
-		});
+		const res = await this.request<ImageSegmentationReturn>(args, options);
 		const isValidOutput =
 			Array.isArray(res) &&
 			res.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number");
@@ -1013,10 +1001,7 @@ export class HfInference {
 	 * Recommended model: stabilityai/stable-diffusion-2
 	 */
 	public async textToImage(args: TextToImageArgs, options?: Options): Promise<TextToImageReturn> {
-		const res = await this.request<TextToImageReturn>(args, {
-			...options,
-			blob: true,
-		});
+		const res = await this.request<TextToImageReturn>(args, options);
 		const isValidOutput = res && res instanceof Blob;
 		if (!isValidOutput) {
 			throw new TypeError("Invalid inference output: output must be of type object & of instance Blob");
@@ -1028,25 +1013,18 @@ export class HfInference {
 	 * This task reads some image input and outputs the text caption.
 	 */
 	public async imageToText(args: ImageToTextArgs, options?: Options): Promise<ImageToTextReturn> {
-		return (
-			await this.request<[ImageToTextReturn]>(args, {
-				...options,
-				binary: true,
-			})
-		)?.[0];
+		return (await this.request<[ImageToTextReturn]>(args, options))?.[0];
 	}
 
 	/**
 	 * Helper that prepares request arguments
 	 */
 	private makeRequestOptions(
-		args: Args & {
+		args: RequestArgs & {
 			data?: Blob | ArrayBuffer;
 			stream?: boolean;
 		},
 		options?: Options & {
-			binary?: boolean;
-			blob?: boolean;
 			/** For internal HF use, which is why it's not exposed in {@link Options} */
 			includeCredentials?: boolean;
 		}
@@ -1059,11 +1037,11 @@ export class HfInference {
 			headers["Authorization"] = `Bearer ${this.apiKey}`;
 		}
 
-		if (!options?.binary) {
-			headers["Content-Type"] = "application/json";
-		}
+		const binary = "data" in args && !!args.data;
 
-		if (options?.binary) {
+		if (!binary) {
+			headers["Content-Type"] = "application/json";
+		} else {
 			if (mergedOptions.wait_for_model) {
 				headers["X-Wait-For-Model"] = "true";
 			}
@@ -1082,7 +1060,7 @@ export class HfInference {
 		const info: RequestInit = {
 			headers,
 			method: "POST",
-			body: options?.binary
+			body: binary
 				? args.data
 				: JSON.stringify({
 						...otherArgs,
@@ -1094,11 +1072,12 @@ export class HfInference {
 		return { url, info, mergedOptions };
 	}
 
+	/**
+	 * Primitive to make custom calls to the inference API
+	 */
 	public async request<T>(
-		args: Args & { data?: Blob | ArrayBuffer },
+		args: RequestArgs,
 		options?: Options & {
-			binary?: boolean;
-			blob?: boolean;
 			/** For internal HF use, which is why it's not exposed in {@link Options} */
 			includeCredentials?: boolean;
 		}
@@ -1113,34 +1092,29 @@ export class HfInference {
 			});
 		}
 
-		if (options?.blob) {
-			if (!response.ok) {
-				if (response.headers.get("Content-Type")?.startsWith("application/json")) {
-					const output = await response.json();
-					if (output.error) {
-						throw new Error(output.error);
-					}
+		if (!response.ok) {
+			if (response.headers.get("Content-Type")?.startsWith("application/json")) {
+				const output = await response.json();
+				if (output.error) {
+					throw new Error(output.error);
 				}
-				throw new Error("An error occurred while fetching the blob");
 			}
-			return (await response.blob()) as T;
+			throw new Error("An error occurred while fetching the blob");
 		}
 
-		const output = await response.json();
-		if (output.error) {
-			throw new Error(output.error);
+		if (response.headers.get("Content-Type")?.startsWith("application/json")) {
+			return await response.json();
 		}
-		return output;
+
+		return (await response.blob()) as T;
 	}
 
 	/**
-	 * Make request that uses server-sent events and returns response as a generator
+	 * Primitive to make custom inference calls that expect server-sent events, and returns the response through a generator
 	 */
 	public async *streamingRequest<T>(
-		args: Args & { data?: Blob | ArrayBuffer },
+		args: RequestArgs,
 		options?: Options & {
-			binary?: boolean;
-			blob?: boolean;
 			/** For internal HF use, which is why it's not exposed in {@link Options} */
 			includeCredentials?: boolean;
 		}
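
Usage sketch of the refactored surface (the access token, the wav2vec2 and bert model names below are illustrative placeholders, not part of this change; stabilityai/stable-diffusion-2 comes from the doc comment above): task helpers now infer a binary request from the presence of `data` in the args, and `request()` picks JSON vs. Blob output from the response Content-Type, so the old `binary`/`blob` options disappear from every call site.

import { HfInference } from "@huggingface/inference";

const hf = new HfInference("hf_xxx"); // placeholder token

// Binary task: `data` is set, so makeRequestOptions sends the raw bytes
// and omits the JSON Content-Type header.
const transcription = await hf.automaticSpeechRecognition({
	model: "facebook/wav2vec2-large-960h-lv60-self", // placeholder model
	data: new Blob([/* audio bytes */]),
});
console.log(transcription.text);

// JSON-in / binary-out task: `inputs` is set, the body is JSON-serialized,
// and a Blob comes back because the response is not application/json.
const image = await hf.textToImage({
	model: "stabilityai/stable-diffusion-2",
	inputs: "An astronaut riding a horse",
});

// Low-level escape hatch: RequestArgs accepts either `data` or `inputs`,
// plus optional `parameters`, with no binary/blob flags.
const raw = await hf.request({
	model: "bert-base-uncased", // placeholder model
	inputs: "The goal of life is [MASK].",
});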