diff --git a/packages/tasks/src/model-data.ts b/packages/tasks/src/model-data.ts
index 53d66bfe15..7284533355 100644
--- a/packages/tasks/src/model-data.ts
+++ b/packages/tasks/src/model-data.ts
@@ -107,6 +107,7 @@ export interface ModelData {
 			parameters?: Record;
 		};
 		base_model?: string | string[];
+		instance_prompt?: string;
 	};
 	/**
 	 * Library name
diff --git a/packages/tasks/src/model-libraries-snippets.ts b/packages/tasks/src/model-libraries-snippets.ts
index 876cc2a93a..653fb01757 100644
--- a/packages/tasks/src/model-libraries-snippets.ts
+++ b/packages/tasks/src/model-libraries-snippets.ts
@@ -1,4 +1,5 @@
 import type { ModelData } from "./model-data";
+import type { WidgetExampleTextInput } from "./widget-example";
 import { LIBRARY_TASK_MAPPING } from "./library-to-tasks";
 
 const TAG_CUSTOM_CODE = "custom_code";
@@ -8,6 +9,12 @@ function nameWithoutNamespace(modelId: string): string {
 	return splitted.length === 1 ? splitted[0] : splitted[1];
 }
 
+// Escape a user-provided prompt so it can be embedded between double quotes in
+// a generated code snippet. JSON.stringify escapes quotes/backslashes/newlines
+// but returns the string WITH its delimiting quotes, so slice them off and let
+// the snippet template supply its own surrounding quotes.
+const escapeStringForJson = (str: string): string => JSON.stringify(str).slice(1, -1);
+
 //#region snippets
 
 export const adapters = (model: ModelData): string[] => [
@@ -70,6 +77,17 @@ function get_base_diffusers_model(model: ModelData): string {
 	return model.cardData?.base_model?.toString() ?? "fill-in-base-model";
 }
 
+// Example prompt for diffusers snippets: prefer the first widget example's
+// text, fall back to the card's instance_prompt; undefined when neither is set.
+function get_prompt_from_diffusers_model(model: ModelData): string | undefined {
+	// widgetData may be absent or empty — keep the chain optional so we never
+	// read `.text` off undefined.
+	const prompt =
+		(model.widgetData?.[0] as WidgetExampleTextInput | undefined)?.text ?? model.cardData?.instance_prompt;
+	if (prompt) {
+		return escapeStringForJson(prompt);
+	}
+}
+
 export const bertopic = (model: ModelData): string[] => [
 	`from bertopic import BERTopic
@@ -129,12 +147,14 @@ depth = model.infer_image(raw_img) # HxW raw depth map in numpy
 ];
 };
 
+const diffusersDefaultPrompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k";
+
 const diffusers_default = (model: ModelData) => [
 	`from diffusers import DiffusionPipeline
 
 pipe = DiffusionPipeline.from_pretrained("${model.id}")
 
-prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+prompt = "${get_prompt_from_diffusers_model(model) ?? diffusersDefaultPrompt}"
 image = pipe(prompt).images[0]`,
 ];
 
@@ -153,7 +173,7 @@ const diffusers_lora = (model: ModelData) => [
 pipe = DiffusionPipeline.from_pretrained("${get_base_diffusers_model(model)}")
 pipe.load_lora_weights("${model.id}")
 
-prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
+prompt = "${get_prompt_from_diffusers_model(model) ?? diffusersDefaultPrompt}"
 image = pipe(prompt).images[0]`,
 ];
 