Commit ce9bfd0

📝 Document streamingRequest and request (#147)
1 parent 945e604 commit ce9bfd0

3 files changed: 52 additions & 59 deletions

README.md

Lines changed: 0 additions & 1 deletion
@@ -100,7 +100,6 @@ await inference.imageToText({
   model: 'nlpconnect/vit-gpt2-image-captioning',
 })
 
-
 // Using your own inference endpoint: https://hf.co/docs/inference-endpoints/
 const gpt2 = inference.endpoint('https://xyz.eu-west-1.aws.endpoints.huggingface.cloud/gpt2');
 const { generated_text } = await gpt2.textGeneration({inputs: 'The answer to the universe is'});

packages/inference/README.md

Lines changed: 20 additions & 0 deletions
@@ -170,6 +170,26 @@ await hf.imageToText({
   model: 'nlpconnect/vit-gpt2-image-captioning'
 })
 
+// Custom call, for models with custom parameters / outputs
+await inference.request({
+  model: 'my-custom-model',
+  inputs: 'hello world',
+  parameters: {
+    custom_param: 'some magic',
+  }
+})
+
+// Custom streaming call, for models with custom parameters / outputs
+for await (const output of inference.streamingRequest({
+  model: 'my-custom-model',
+  inputs: 'hello world',
+  parameters: {
+    custom_param: 'some magic',
+  }
+})) {
+  ...
+}
+
 // Using your own inference endpoint: https://hf.co/docs/inference-endpoints/
 const gpt2 = hf.endpoint('https://xyz.eu-west-1.aws.endpoints.huggingface.cloud/gpt2');
 const { generated_text } = await gpt2.textGeneration({inputs: 'The answer to the universe is'});
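
Since the loop body above is intentionally elided, here is a minimal end-to-end sketch of consuming both primitives; the model name, parameter, and chunk shape are hypothetical, as custom models define their own payloads:

import { HfInference } from "@huggingface/inference";

const inference = new HfInference("hf_xxx"); // your access token

// One-shot custom call: resolves to parsed JSON, or to a Blob when the
// endpoint responds with a non-JSON Content-Type (see HfInference.ts below).
const output = await inference.request({
  model: "my-custom-model",
  inputs: "hello world",
  parameters: { custom_param: "some magic" },
});

// Streaming call: each iteration yields one parsed server-sent event.
for await (const chunk of inference.streamingRequest({
  model: "my-custom-model",
  inputs: "hello world",
})) {
  console.log(chunk); // shape is model-defined
}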

packages/inference/src/HfInference.ts

Lines changed: 32 additions & 58 deletions
@@ -35,6 +35,9 @@ export interface Args {
   model?: string;
 }
 
+export type RequestArgs = Args &
+  ({ data?: Blob | ArrayBuffer } | { inputs: unknown }) & { parameters?: Record<string, unknown> };
+
 export type FillMaskArgs = Args & {
   inputs: string;
 };
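
The new RequestArgs union is what both public primitives now accept: either raw bytes under `data` or arbitrary JSON under `inputs`, plus free-form `parameters`. A quick sketch of the two shapes it covers (values hypothetical):

// JSON-style call: payload travels through `inputs` / `parameters`.
const jsonArgs: RequestArgs = {
  model: "my-custom-model",
  inputs: "hello world",
  parameters: { custom_param: "some magic" },
};

// Binary-style call: raw bytes travel through `data` (e.g. an audio Blob).
const binaryArgs: RequestArgs = {
  model: "my-custom-model",
  data: new Blob(["...audio bytes..."]), // stand-in payload
};
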
@@ -909,10 +912,7 @@ export class HfInference {
     args: AutomaticSpeechRecognitionArgs,
     options?: Options
   ): Promise<AutomaticSpeechRecognitionReturn> {
-    const res = await this.request<AutomaticSpeechRecognitionReturn>(args, {
-      ...options,
-      binary: true,
-    });
+    const res = await this.request<AutomaticSpeechRecognitionReturn>(args, options);
     const isValidOutput = typeof res.text === "string";
     if (!isValidOutput) {
       throw new TypeError("Invalid inference output: output must be of type <text: string>");
@@ -928,10 +928,7 @@ export class HfInference {
     args: AudioClassificationArgs,
     options?: Options
   ): Promise<AudioClassificationReturn> {
-    const res = await this.request<AudioClassificationReturn>(args, {
-      ...options,
-      binary: true,
-    });
+    const res = await this.request<AudioClassificationReturn>(args, options);
     const isValidOutput =
       Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
     if (!isValidOutput) {
@@ -948,10 +945,7 @@ export class HfInference {
     args: ImageClassificationArgs,
     options?: Options
   ): Promise<ImageClassificationReturn> {
-    const res = await this.request<ImageClassificationReturn>(args, {
-      ...options,
-      binary: true,
-    });
+    const res = await this.request<ImageClassificationReturn>(args, options);
     const isValidOutput =
       Array.isArray(res) && res.every((x) => typeof x.label === "string" && typeof x.score === "number");
     if (!isValidOutput) {
@@ -965,10 +959,7 @@ export class HfInference {
    * Recommended model: facebook/detr-resnet-50
    */
   public async objectDetection(args: ObjectDetectionArgs, options?: Options): Promise<ObjectDetectionReturn> {
-    const res = await this.request<ObjectDetectionReturn>(args, {
-      ...options,
-      binary: true,
-    });
+    const res = await this.request<ObjectDetectionReturn>(args, options);
     const isValidOutput =
       Array.isArray(res) &&
       res.every(
@@ -993,10 +984,7 @@ export class HfInference {
    * Recommended model: facebook/detr-resnet-50-panoptic
    */
   public async imageSegmentation(args: ImageSegmentationArgs, options?: Options): Promise<ImageSegmentationReturn> {
-    const res = await this.request<ImageSegmentationReturn>(args, {
-      ...options,
-      binary: true,
-    });
+    const res = await this.request<ImageSegmentationReturn>(args, options);
     const isValidOutput =
       Array.isArray(res) &&
       res.every((x) => typeof x.label === "string" && typeof x.mask === "string" && typeof x.score === "number");
@@ -1013,10 +1001,7 @@ export class HfInference {
    * Recommended model: stabilityai/stable-diffusion-2
    */
   public async textToImage(args: TextToImageArgs, options?: Options): Promise<TextToImageReturn> {
-    const res = await this.request<TextToImageReturn>(args, {
-      ...options,
-      blob: true,
-    });
+    const res = await this.request<TextToImageReturn>(args, options);
     const isValidOutput = res && res instanceof Blob;
     if (!isValidOutput) {
       throw new TypeError("Invalid inference output: output must be of type object & of instance Blob");
@@ -1028,25 +1013,18 @@ export class HfInference {
    * This task reads some image input and outputs the text caption.
    */
   public async imageToText(args: ImageToTextArgs, options?: Options): Promise<ImageToTextReturn> {
-    return (
-      await this.request<[ImageToTextReturn]>(args, {
-        ...options,
-        binary: true,
-      })
-    )?.[0];
+    return (await this.request<[ImageToTextReturn]>(args, options))?.[0];
   }
 
   /**
    * Helper that prepares request arguments
    */
   private makeRequestOptions(
-    args: Args & {
+    args: RequestArgs & {
       data?: Blob | ArrayBuffer;
       stream?: boolean;
     },
     options?: Options & {
-      binary?: boolean;
-      blob?: boolean;
       /** For internal HF use, which is why it's not exposed in {@link Options} */
       includeCredentials?: boolean;
     }
@@ -1059,11 +1037,11 @@ export class HfInference {
       headers["Authorization"] = `Bearer ${this.apiKey}`;
     }
 
-    if (!options?.binary) {
-      headers["Content-Type"] = "application/json";
-    }
+    const binary = "data" in args && !!args.data;
 
-    if (options?.binary) {
+    if (!binary) {
+      headers["Content-Type"] = "application/json";
+    } else {
       if (mergedOptions.wait_for_model) {
         headers["X-Wait-For-Model"] = "true";
       }
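
With the binary option gone, whether a request is sent as raw bytes is derived from the arguments themselves, which is why the task helpers above now forward `options` untouched. A caller-side sketch of the rule (model name hypothetical):

// A truthy `data` field selects the binary path: the payload is sent as-is
// (no JSON Content-Type) and wait_for_model is promoted to an HTTP header.
// Without `data`, the args are JSON-serialized exactly as before.
const audioBlob = new Blob(["...audio bytes..."]); // stand-in payload
await hf.request({
  model: "my-audio-model",
  data: audioBlob,
});
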
@@ -1082,7 +1060,7 @@ export class HfInference {
     const info: RequestInit = {
       headers,
       method: "POST",
-      body: options?.binary
+      body: binary
         ? args.data
         : JSON.stringify({
             ...otherArgs,
@@ -1094,11 +1072,12 @@ export class HfInference {
     return { url, info, mergedOptions };
   }
 
+  /**
+   * Primitive to make custom calls to the inference API
+   */
   public async request<T>(
-    args: Args & { data?: Blob | ArrayBuffer },
+    args: RequestArgs,
     options?: Options & {
-      binary?: boolean;
-      blob?: boolean;
       /** For internal HF use, which is why it's not exposed in {@link Options} */
       includeCredentials?: boolean;
     }
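
Note that the type parameter on the now-public request<T> is a compile-time assertion only; the body below casts whatever JSON (or Blob) comes back, without runtime validation. A hypothetical typed call:

interface MyModelOutput {
  label: string;
  score: number;
}

// `out` is typed as MyModelOutput[], but the shape is only as reliable as
// the model behind the hypothetical "my-custom-model" identifier.
const out = await hf.request<MyModelOutput[]>({
  model: "my-custom-model",
  inputs: "hello world",
});
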
@@ -1113,34 +1092,29 @@ export class HfInference {
       });
     }
 
-    if (options?.blob) {
-      if (!response.ok) {
-        if (response.headers.get("Content-Type")?.startsWith("application/json")) {
-          const output = await response.json();
-          if (output.error) {
-            throw new Error(output.error);
-          }
+    if (!response.ok) {
+      if (response.headers.get("Content-Type")?.startsWith("application/json")) {
+        const output = await response.json();
+        if (output.error) {
+          throw new Error(output.error);
         }
-        throw new Error("An error occurred while fetching the blob");
       }
-      return (await response.blob()) as T;
+      throw new Error("An error occurred while fetching the blob");
     }
 
-    const output = await response.json();
-    if (output.error) {
-      throw new Error(output.error);
+    if (response.headers.get("Content-Type")?.startsWith("application/json")) {
+      return await response.json();
     }
-    return output;
+
+    return (await response.blob()) as T;
   }
 
   /**
-   * Make request that uses server-sent events and returns response as a generator
+   * Primitive to make custom inference calls that expect server-sent events, and returns the response through a generator
    */
   public async *streamingRequest<T>(
-    args: Args & { data?: Blob | ArrayBuffer },
+    args: RequestArgs,
     options?: Options & {
-      binary?: boolean;
-      blob?: boolean;
       /** For internal HF use, which is why it's not exposed in {@link Options} */
       includeCredentials?: boolean;
     }
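
With the blob option also removed, a single request call now dispatches on the response Content-Type: JSON responses are parsed, anything else is returned as a Blob. A sketch of both outcomes (model names hypothetical):

// JSON endpoint → parsed object
const labels = await hf.request<{ label: string; score: number }[]>({
  model: "my-classifier-model",
  inputs: "hello world",
});

// Binary endpoint (e.g. image/png) → Blob
const image = await hf.request<Blob>({
  model: "my-image-model",
  inputs: "an astronaut riding a horse",
});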
