diff --git a/packages/tasks/src/model-libraries-snippets.ts b/packages/tasks/src/model-libraries-snippets.ts index e03256f7d3..35f139bcd5 100644 --- a/packages/tasks/src/model-libraries-snippets.ts +++ b/packages/tasks/src/model-libraries-snippets.ts @@ -331,6 +331,19 @@ output = model.generate(text) sf.write("simple.mp3", output, 44100)`, ]; +export const dia2 = (model: ModelData): string[] => [ + `from dia2 import Dia2, GenerationConfig, SamplingConfig + +dia = Dia2.from_repo("${model.id}", device="cuda", dtype="bfloat16") +config = GenerationConfig( + cfg_scale=2.0, + audio=SamplingConfig(temperature=0.8, top_k=50), + use_cuda_graph=True, +) +result = dia.generate("[S1] Hello Dia2!", config=config, output_wav="hello.wav", verbose=True) +`, +]; + export const describe_anything = (model: ModelData): string[] => [ `# pip install git+https://github.com/NVlabs/describe-anything from huggingface_hub import snapshot_download diff --git a/packages/tasks/src/model-libraries.ts b/packages/tasks/src/model-libraries.ts index 16ec332a2a..d6ab1b3fff 100644 --- a/packages/tasks/src/model-libraries.ts +++ b/packages/tasks/src/model-libraries.ts @@ -293,6 +293,13 @@ export const MODEL_LIBRARIES_UI_ELEMENTS = { snippets: snippets.dia, filter: false, }, + dia2: { + prettyLabel: "Dia2", + repoName: "Dia2", + repoUrl: "https://github.com/nari-labs/dia2", + snippets: snippets.dia2, + filter: false, + }, "diff-interpretation-tuning": { prettyLabel: "Diff Interpretation Tuning", repoName: "Diff Interpretation Tuning",