Skip to content

Commit c6a1c7e

Browse files
Vaibhavs10, hanouticelina, SBrandeis
authored
add support for Kokoro via Replicate (huggingface#1153)
Co-authored-by: Celina Hanouti <[email protected]> Co-authored-by: Simon Brandeis <[email protected]>
1 parent 7ee542d commit c6a1c7e

File tree

4 files changed

+30
-10
lines changed

4 files changed

+30
-10
lines changed

packages/inference/src/providers/replicate.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,8 @@ export const REPLICATE_SUPPORTED_MODEL_IDS: ProviderMapping<ReplicateId> = {
2121
"stability-ai/sdxl:7762fd07cf82c948538e41f63f77d685e02b063e37e496e96eefd46c929f9bdc",
2222
},
2323
"text-to-speech": {
24-
"OuteAI/OuteTTS-0.3-500M": "jbilcke/oute-tts:39a59319327b27327fa3095149c5a746e7f2aee18c75055c3368237a6503cd26",
24+
"OuteAI/OuteTTS-0.3-500M": "jbilcke/oute-tts:3c645149db020c85d080e2f8cfe482a0e68189a922cde964fa9e80fb179191f3",
25+
"hexgrad/Kokoro-82M": "jaaari/kokoro-82m:dfdf537ba482b029e0a761699e6f55e9162cfd159270bfe0e44857caa5f275a6",
2526
},
2627
"text-to-video": {
2728
"genmo/mochi-1-preview": "genmoai/mochi-1:1944af04d098ef69bed7f9d335d102e652203f268ec4aaa2d836f6217217e460",

packages/inference/src/tasks/audio/textToSpeech.ts

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import type { TextToSpeechInput } from "@huggingface/tasks";
22
import { InferenceOutputError } from "../../lib/InferenceOutputError";
33
import type { BaseArgs, Options } from "../../types";
4+
import { omit } from "../../utils/omit";
45
import { request } from "../custom/request";
5-
66
type TextToSpeechArgs = BaseArgs & TextToSpeechInput;
77

88
interface OutputUrlTextToSpeechGeneration {
@@ -13,7 +13,16 @@ interface OutputUrlTextToSpeechGeneration {
1313
* Recommended model: espnet/kan-bayashi_ljspeech_vits
1414
*/
1515
export async function textToSpeech(args: TextToSpeechArgs, options?: Options): Promise<Blob> {
16-
const res = await request<Blob | OutputUrlTextToSpeechGeneration>(args, {
16+
// Replicate models expect "text" instead of "inputs"
17+
const payload =
18+
args.provider === "replicate"
19+
? {
20+
...omit(args, ["inputs", "parameters"]),
21+
...args.parameters,
22+
text: args.inputs,
23+
}
24+
: args;
25+
const res = await request<Blob | OutputUrlTextToSpeechGeneration>(payload, {
1726
...options,
1827
taskHint: "text-to-speech",
1928
});

packages/inference/src/types.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
import type { PipelineType } from "@huggingface/tasks";
2-
import type { ChatCompletionInput } from "@huggingface/tasks";
1+
import type { ChatCompletionInput, PipelineType } from "@huggingface/tasks";
32

43
/**
54
* HF model id, like "meta-llama/Llama-3.3-70B-Instruct"
@@ -88,6 +87,7 @@ export type RequestArgs = BaseArgs &
8887
| { data: Blob | ArrayBuffer }
8988
| { inputs: unknown }
9089
| { prompt: string }
90+
| { text: string }
9191
| { audio_url: string }
9292
| ChatCompletionInput
9393
) & {

packages/inference/test/HfInference.spec.ts

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
import { expect, it, describe, assert } from "vitest";
1+
import { assert, describe, expect, it } from "vitest";
22

33
import type { ChatCompletionStreamOutput } from "@huggingface/tasks";
44

55
import { chatCompletion, FAL_AI_SUPPORTED_MODEL_IDS, HfInference } from "../src";
6-
import "./vcr";
7-
import { readTestFile } from "./test-files";
86
import { textToVideo } from "../src/tasks/cv/textToVideo";
7+
import { readTestFile } from "./test-files";
8+
import "./vcr";
99

1010
const TIMEOUT = 60000 * 3;
1111
const env = import.meta.env;
@@ -939,11 +939,21 @@ describe.concurrent("HfInference", () => {
939939
expect(res).toBeInstanceOf(Blob);
940940
});
941941

942-
it("textToSpeech OuteTTS", async () => {
942+
it.skip("textToSpeech OuteTTS - usually Cold", async () => {
943943
const res = await client.textToSpeech({
944944
model: "OuteAI/OuteTTS-0.3-500M",
945945
provider: "replicate",
946-
inputs: "OuteTTS is a frontier TTS model for its size of 1 Billion parameters",
946+
text: "OuteTTS is a frontier TTS model for its size of 1 Billion parameters",
947+
});
948+
949+
expect(res).toBeInstanceOf(Blob);
950+
});
951+
952+
it("textToSpeech Kokoro", async () => {
953+
const res = await client.textToSpeech({
954+
model: "hexgrad/Kokoro-82M",
955+
provider: "replicate",
956+
text: "Kokoro is a frontier TTS model for its size of 1 Billion parameters",
947957
});
948958

949959
expect(res).toBeInstanceOf(Blob);

0 commit comments

Comments
 (0)