Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions packages/tasks/src/model-data.ts
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,13 @@ export interface ModelData {
parameters?: Record<string, unknown>;
};
base_model?: string | string[];
widget?: Array<{
text: string;
output?: {
url: string;
};
}>;
instance_prompt?: string;
};
/**
* Library name
Expand Down
12 changes: 10 additions & 2 deletions packages/tasks/src/model-libraries-snippets.ts
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,14 @@ function get_base_diffusers_model(model: ModelData): string {
return model.cardData?.base_model?.toString() ?? "fill-in-base-model";
}

function get_prompt_from_diffusers_model(model: ModelData): string {
return (
model.cardData?.widget?.[0]?.text?.toString() ??
model.cardData?.instance_prompt?.toString() ??
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
);
}

export const bertopic = (model: ModelData): string[] => [
`from bertopic import BERTopic

Expand Down Expand Up @@ -134,7 +142,7 @@ const diffusers_default = (model: ModelData) => [

pipe = DiffusionPipeline.from_pretrained("${model.id}")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
prompt = "${get_prompt_from_diffusers_model(model)}"
image = pipe(prompt).images[0]`,
];

Expand All @@ -153,7 +161,7 @@ const diffusers_lora = (model: ModelData) => [
pipe = DiffusionPipeline.from_pretrained("${get_base_diffusers_model(model)}")
pipe.load_lora_weights("${model.id}")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
prompt = "${get_prompt_from_diffusers_model(model)}"
image = pipe(prompt).images[0]`,
];

Expand Down
Loading