Skip to content

Commit aa7d1ca

Browse files
committed
support for replicate
1 parent d96a18e commit aa7d1ca

File tree

6 files changed

+76
-16
lines changed

6 files changed

+76
-16
lines changed

packages/inference/src/lib/makeRequestOptions.ts

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import { REPLICATE_API_BASE_URL, REPLICATE_MODEL_IDS } from "../providers/replicate";
12
import { SAMBANOVA_API_BASE_URL, SAMBANOVA_MODEL_IDS } from "../providers/sambanova";
23
import { TOGETHER_API_BASE_URL, TOGETHER_MODEL_IDS } from "../providers/together";
34
import { INFERENCE_PROVIDERS, type InferenceTask, type Options, type RequestArgs } from "../types";
@@ -64,6 +65,9 @@ export async function makeRequestOptions(
6465
throw new Error("Specifying an Inference provider requires an accessToken");
6566
}
6667
switch (provider) {
68+
case "replicate":
69+
model = REPLICATE_MODEL_IDS[model];
70+
break;
6771
case "sambanova":
6872
model = SAMBANOVA_MODEL_IDS[model];
6973
break;
@@ -90,6 +94,9 @@ export async function makeRequestOptions(
9094
if (dont_load_model) {
9195
headers["X-Load-Model"] = "0";
9296
}
97+
if (provider === "replicate") {
98+
headers["Prefer"] = "wait";
99+
}
93100

94101
let url = (() => {
95102
if (endpointUrl && isUrl(model)) {
@@ -115,6 +122,8 @@ export async function makeRequestOptions(
115122
} else {
116123
/// This is an external key
117124
switch (provider) {
125+
case "replicate":
126+
return `${REPLICATE_API_BASE_URL}/v1/models/${model}/predictions`;
118127
case "sambanova":
119128
return SAMBANOVA_API_BASE_URL;
120129
case "together":
@@ -151,7 +160,9 @@ export async function makeRequestOptions(
151160
body: binary
152161
? args.data
153162
: JSON.stringify({
154-
...(otherArgs.model && isUrl(otherArgs.model) ? omit(otherArgs, "model") : { ...otherArgs, model }),
163+
...((otherArgs.model && isUrl(otherArgs.model)) || provider === "replicate"
164+
? omit(otherArgs, "model")
165+
: { ...otherArgs, model }),
155166
}),
156167
...(credentials ? { credentials } : undefined),
157168
signal: options?.signal,
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import type { ModelId } from "../types";
2+
3+
export const REPLICATE_API_BASE_URL = "https://api.replicate.com";
4+
5+
/**
6+
* Same comment as in sambanova.ts
7+
*/
8+
type ReplicateId = string;
9+
10+
/**
11+
* curl -s \
12+
* -H "Authorization: Bearer $REPLICATE_API_TOKEN" \
13+
* https://api.replicate.com/v1/models
14+
*/
15+
export const REPLICATE_MODEL_IDS: Record<ModelId, ReplicateId> = {
16+
"black-forest-labs/FLUX.1-schnell": "black-forest-labs/flux-schnell",
17+
"ByteDance/SDXL-Lightning": "bytedance/sdxl-lightning-4step",
18+
};

packages/inference/src/tasks/custom/request.ts

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,19 @@ export async function request<T>(
2626
}
2727

2828
if (!response.ok) {
29-
if (response.headers.get("Content-Type")?.startsWith("application/json")) {
29+
if (
30+
["application/json", "application/problem+json"].some(
31+
(contentType) => response.headers.get("Content-Type")?.startsWith(contentType)
32+
)
33+
) {
3034
const output = await response.json();
3135
if ([400, 422, 404, 500].includes(response.status) && options?.chatCompletion) {
3236
throw new Error(`Server ${args.model} does not seem to support chat completion. Error: ${output.error}`);
3337
}
3438
if (output.error) {
3539
throw new Error(JSON.stringify(output.error));
40+
} else {
41+
throw new Error(output);
3642
}
3743
}
3844
throw new Error("An error occurred while fetching the blob");

packages/inference/src/tasks/cv/textToImage.ts

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,13 @@ export type TextToImageArgs = BaseArgs & {
99
inputs: string;
1010

1111
/**
12-
* Same param but for external providers like Together
12+
* Same param as `inputs`, but for external providers like Together or Replicate
1313
*/
1414
prompt?: string;
1515
response_format?: "base64";
16+
input?: {
17+
prompt: string;
18+
};
1619

1720
parameters?: {
1821
/**
@@ -38,14 +41,16 @@ export type TextToImageArgs = BaseArgs & {
3841
};
3942
};
4043

44+
export type TextToImageOutput = Blob;
45+
4146
interface Base64ImageGeneration {
42-
id: string;
43-
model: string;
4447
data: Array<{
4548
b64_json: string;
4649
}>;
4750
}
48-
export type TextToImageOutput = Blob;
51+
interface OutputUrlImageGeneration {
52+
output: string[];
53+
}
4954

5055
/**
5156
* This task reads some text input and outputs an image.
@@ -56,16 +61,26 @@ export async function textToImage(args: TextToImageArgs, options?: Options): Pro
5661
args.prompt = args.inputs;
5762
args.inputs = "";
5863
args.response_format = "base64";
64+
} else if (args.provider === "replicate") {
65+
args.input = { prompt: args.inputs };
66+
delete (args as unknown as { inputs: unknown }).inputs;
5967
}
60-
const res = await request<TextToImageOutput | Base64ImageGeneration>(args, {
68+
const res = await request<TextToImageOutput | Base64ImageGeneration | OutputUrlImageGeneration>(args, {
6169
...options,
6270
taskHint: "text-to-image",
6371
});
64-
if (res && typeof res === "object" && Array.isArray(res.data) && res.data[0].b64_json) {
65-
const base64Data = res.data[0].b64_json;
66-
const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
67-
const blob = await base64Response.blob();
68-
return blob;
72+
if (res && typeof res === "object") {
73+
if ("data" in res && Array.isArray(res.data) && res.data[0].b64_json) {
74+
const base64Data = res.data[0].b64_json;
75+
const base64Response = await fetch(`data:image/jpeg;base64,${base64Data}`);
76+
const blob = await base64Response.blob();
77+
return blob;
78+
}
79+
if ("output" in res && Array.isArray(res.output)) {
80+
const urlResponse = await fetch(res.output[0]);
81+
const blob = await urlResponse.blob();
82+
return blob;
83+
}
6984
}
7085
const isValidOutput = res && res instanceof Blob;
7186
if (!isValidOutput) {

packages/inference/src/types.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ export interface Options {
4545

4646
export type InferenceTask = Exclude<PipelineType, "other">;
4747

48-
export const INFERENCE_PROVIDERS = ["sambanova", "together", "hf-inference"] as const;
48+
export const INFERENCE_PROVIDERS = ["replicate", "sambanova", "together", "hf-inference"] as const;
4949
export type InferenceProvider = (typeof INFERENCE_PROVIDERS)[number];
5050

5151
export interface BaseArgs {
@@ -54,19 +54,19 @@ export interface BaseArgs {
5454
*
5555
* Can be created for free in hf.co/settings/token
5656
*
57-
* You can also pass an external Inference provider's key if you intend to call a compatible provider like Sambanova, Together...
57+
* You can also pass an external Inference provider's key if you intend to call a compatible provider like Sambanova, Together, Replicate...
5858
*/
5959
accessToken?: string;
6060

6161
/**
62-
* The model to use.
62+
* The HF model to use.
6363
*
6464
* If not specified, will call huggingface.co/api/tasks to get the default model for the task.
6565
*
6666
* /!\ Legacy behavior allows this to be a URL, but this is deprecated and will be removed in the future.
6767
* Use the `endpointUrl` parameter instead.
6868
*/
69-
model?: string;
69+
model?: ModelId;
7070

7171
/**
7272
* The URL of the endpoint to use. If not specified, will call huggingface.co/api/tasks to get the default endpoint for the task.

packages/inference/test/HfInference.spec.ts

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -826,6 +826,16 @@ describe.concurrent(
826826
});
827827
expect(res).toBeInstanceOf(Blob);
828828
});
829+
830+
it("textToImage replicate", async () => {
831+
const hf = new HfInference(env.REPLICATE_KEY);
832+
const res = await hf.textToImage({
833+
model: "black-forest-labs/FLUX.1-schnell",
834+
provider: "replicate",
835+
inputs: "black forest gateau cake spelling out the words FLUX SCHNELL, tasty, food photography, dynamic shot",
836+
});
837+
expect(res).toBeInstanceOf(Blob);
838+
});
829839
},
830840
TIMEOUT
831841
);

0 commit comments

Comments
 (0)