Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions packages/tasks/src/model-data.ts
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ export interface ModelData {
parameters?: Record<string, unknown>;
};
base_model?: string | string[];
instance_prompt?: string;
};
/**
* Library name
Expand Down
11 changes: 9 additions & 2 deletions packages/tasks/src/model-libraries-snippets.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import type { ModelData } from "./model-data";
import type { WidgetExampleTextInput } from "./widget-example";
import { LIBRARY_TASK_MAPPING } from "./library-to-tasks";

const TAG_CUSTOM_CODE = "custom_code";
Expand Down Expand Up @@ -70,6 +71,10 @@ function get_base_diffusers_model(model: ModelData): string {
return model.cardData?.base_model?.toString() ?? "fill-in-base-model";
}

function get_prompt_from_diffusers_model(model: ModelData): string | undefined {
return (model.widgetData?.[0] as WidgetExampleTextInput).text ?? model.cardData?.instance_prompt ?? undefined;
}

export const bertopic = (model: ModelData): string[] => [
`from bertopic import BERTopic

Expand Down Expand Up @@ -129,12 +134,14 @@ depth = model.infer_image(raw_img) # HxW raw depth map in numpy
];
};

const diffusersDefaultPrompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k";

const diffusers_default = (model: ModelData) => [
`from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("${model.id}")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
prompt = "${get_prompt_from_diffusers_model(model) ?? diffusersDefaultPrompt}"
image = pipe(prompt).images[0]`,
];

Expand All @@ -153,7 +160,7 @@ const diffusers_lora = (model: ModelData) => [
pipe = DiffusionPipeline.from_pretrained("${get_base_diffusers_model(model)}")
pipe.load_lora_weights("${model.id}")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
prompt = "${get_prompt_from_diffusers_model(model) ?? diffusersDefaultPrompt}"
image = pipe(prompt).images[0]`,
];

Expand Down
Loading