diff --git a/packages/tasks/src/model-libraries-snippets.ts b/packages/tasks/src/model-libraries-snippets.ts index 15e1774fe7..0a6485d254 100644 --- a/packages/tasks/src/model-libraries-snippets.ts +++ b/packages/tasks/src/model-libraries-snippets.ts @@ -664,23 +664,42 @@ export const gliner = (model: ModelData): string[] => [ model = GLiNER.from_pretrained("${model.id}")`, ]; -export const indextts = (model: ModelData): string[] => [ - `# Download model -from huggingface_hub import snapshot_download - -snapshot_download(${model.id}, local_dir="checkpoints") - -from indextts.infer import IndexTTS - -# Ensure config.yaml is present in the checkpoints directory -tts = IndexTTS(model_dir="checkpoints", cfg_path="checkpoints/config.yaml") - -voice = "path/to/your/reference_voice.wav" # Path to the voice reference audio file -text = "Hello, how are you?" -output_path = "output_index.wav" - -tts.infer(voice, text, output_path)`, -]; +export const indextts = (model: ModelData): string[] => { + if (model.id === "IndexTeam/IndexTTS-2") { + return [ + `from huggingface_hub import snapshot_download + snapshot_download("${model.id}", local_dir="checkpoints") + + from indextts.infer_v2 import IndexTTS2 + tts = IndexTTS2(cfg_path="checkpoints/config.yaml", model_dir="checkpoints", use_fp16=False, use_cuda_kernel=False, use_deepspeed=False) + + text = "Translate for me, what is a surprise!" + tts.infer(spk_audio_prompt="voice_sample.wav", text=text, output_path="gen.wav", verbose=True)`, + ]; + } else if (model.id === "IndexTeam/Index-TTS") { + return [ + `from huggingface_hub import snapshot_download + snapshot_download("${model.id}", local_dir="checkpoints") + + from indextts.infer import IndexTTS + tts = IndexTTS(model_dir="checkpoints", cfg_path="checkpoints/config.yaml") + + voice = "voice_sample.wav" + text = "Translate for me, what is a surprise!" + tts.infer(voice, text, "gen.wav")`, + ]; + } + + return [ + `from huggingface_hub import snapshot_download + snapshot_download("${model.id}", local_dir="checkpoints") + + from indextts.infer import IndexTTS + tts = IndexTTS(model_dir="checkpoints", cfg_path="checkpoints/config.yaml") + tts.infer("path/to/voice.wav", "Hello!", "gen.wav")`, + ]; + }; + export const htrflow = (model: ModelData): string[] => [ `# CLI usage