From 66535f5955ded06948cbeca030102882b2cbcb34 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?apolin=C3=A1rio?= Date: Wed, 18 Sep 2024 03:47:57 -0500 Subject: [PATCH 01/12] Improve prompting for diffusers default snippets Addressing comments left on #907, specially for LoRAs --- packages/tasks/src/model-libraries-snippets.ts | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/packages/tasks/src/model-libraries-snippets.ts b/packages/tasks/src/model-libraries-snippets.ts index 876cc2a93a..9901d1e7ea 100644 --- a/packages/tasks/src/model-libraries-snippets.ts +++ b/packages/tasks/src/model-libraries-snippets.ts @@ -70,6 +70,12 @@ function get_base_diffusers_model(model: ModelData): string { return model.cardData?.base_model?.toString() ?? "fill-in-base-model"; } +function get_prompt_from_diffusers_model(model: ModelData): string { + return model.cardData?.widget?.[0]?.text?.toString() + ?? model.cardData?.instance_prompt?.toString() + ?? "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"; +} + export const bertopic = (model: ModelData): string[] => [ `from bertopic import BERTopic @@ -134,7 +140,7 @@ const diffusers_default = (model: ModelData) => [ pipe = DiffusionPipeline.from_pretrained("${model.id}") -prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" +prompt = "${get_prompt_from_diffusers_model(model)}" image = pipe(prompt).images[0]`, ]; @@ -153,7 +159,7 @@ const diffusers_lora = (model: ModelData) => [ pipe = DiffusionPipeline.from_pretrained("${get_base_diffusers_model(model)}") pipe.load_lora_weights("${model.id}") -prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" +prompt = "${get_prompt_from_diffusers_model(model)}" image = pipe(prompt).images[0]`, ]; From 324e720c631ab871a920c371178e3d5dfdd7613b Mon Sep 17 00:00:00 2001 From: multimodalart Date: Wed, 18 Sep 2024 11:10:17 +0200 Subject: [PATCH 02/12] add props to `ModelData` and lint --- packages/tasks/src/model-data.ts | 7 +++++++ packages/tasks/src/model-libraries-snippets.ts | 8 +++++--- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/packages/tasks/src/model-data.ts b/packages/tasks/src/model-data.ts index 53d66bfe15..008d8953aa 100644 --- a/packages/tasks/src/model-data.ts +++ b/packages/tasks/src/model-data.ts @@ -107,6 +107,13 @@ export interface ModelData { parameters?: Record; }; base_model?: string | string[]; + widget?: Array<{ + text: string; + output?: { + url: string; + }; + }>; + instance_prompt?: string; }; /** * Library name diff --git a/packages/tasks/src/model-libraries-snippets.ts b/packages/tasks/src/model-libraries-snippets.ts index 9901d1e7ea..1805cc1a07 100644 --- a/packages/tasks/src/model-libraries-snippets.ts +++ b/packages/tasks/src/model-libraries-snippets.ts @@ -71,9 +71,11 @@ function get_base_diffusers_model(model: ModelData): string { } function get_prompt_from_diffusers_model(model: ModelData): string { - return model.cardData?.widget?.[0]?.text?.toString() - ?? model.cardData?.instance_prompt?.toString() - ?? "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"; + return ( + model.cardData?.widget?.[0]?.text?.toString() ?? + model.cardData?.instance_prompt?.toString() ?? + "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" + ); } export const bertopic = (model: ModelData): string[] => [ From 2e5caafe752744d371ad5f982d67032c3d065acd Mon Sep 17 00:00:00 2001 From: multimodalart Date: Wed, 18 Sep 2024 11:27:40 +0200 Subject: [PATCH 03/12] Apply review comments --- packages/tasks/src/model-data.ts | 6 ------ packages/tasks/src/model-libraries-snippets.ts | 15 +++++++-------- 2 files changed, 7 insertions(+), 14 deletions(-) diff --git a/packages/tasks/src/model-data.ts b/packages/tasks/src/model-data.ts index 008d8953aa..7284533355 100644 --- a/packages/tasks/src/model-data.ts +++ b/packages/tasks/src/model-data.ts @@ -107,12 +107,6 @@ export interface ModelData { parameters?: Record; }; base_model?: string | string[]; - widget?: Array<{ - text: string; - output?: { - url: string; - }; - }>; instance_prompt?: string; }; /** diff --git a/packages/tasks/src/model-libraries-snippets.ts b/packages/tasks/src/model-libraries-snippets.ts index 1805cc1a07..4090bb54c8 100644 --- a/packages/tasks/src/model-libraries-snippets.ts +++ b/packages/tasks/src/model-libraries-snippets.ts @@ -1,4 +1,5 @@ import type { ModelData } from "./model-data"; +import type { WidgetExampleTextInput } from "./widget-example"; import { LIBRARY_TASK_MAPPING } from "./library-to-tasks"; const TAG_CUSTOM_CODE = "custom_code"; @@ -70,12 +71,8 @@ function get_base_diffusers_model(model: ModelData): string { return model.cardData?.base_model?.toString() ?? "fill-in-base-model"; } -function get_prompt_from_diffusers_model(model: ModelData): string { - return ( - model.cardData?.widget?.[0]?.text?.toString() ?? - model.cardData?.instance_prompt?.toString() ?? - "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k" - ); +function get_prompt_from_diffusers_model(model: ModelData): string | undefined { + return (model.widgetData?.[0] as WidgetExampleTextInput).text ?? model.cardData?.instance_prompt ?? undefined; } export const bertopic = (model: ModelData): string[] => [ @@ -137,12 +134,14 @@ depth = model.infer_image(raw_img) # HxW raw depth map in numpy ]; }; +const diffusers_default_prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"; + const diffusers_default = (model: ModelData) => [ `from diffusers import DiffusionPipeline pipe = DiffusionPipeline.from_pretrained("${model.id}") -prompt = "${get_prompt_from_diffusers_model(model)}" +prompt = "${get_prompt_from_diffusers_model(model) ?? diffusers_default_prompt}" image = pipe(prompt).images[0]`, ]; @@ -161,7 +160,7 @@ const diffusers_lora = (model: ModelData) => [ pipe = DiffusionPipeline.from_pretrained("${get_base_diffusers_model(model)}") pipe.load_lora_weights("${model.id}") -prompt = "${get_prompt_from_diffusers_model(model)}" +prompt = "${get_prompt_from_diffusers_model(model) ?? diffusers_default_prompt}" image = pipe(prompt).images[0]`, ]; From 00f6607620cc0dc9d70ec2eddf22544dd1997031 Mon Sep 17 00:00:00 2001 From: multimodalart Date: Wed, 18 Sep 2024 11:29:13 +0200 Subject: [PATCH 04/12] follow internal variable naming convention --- packages/tasks/src/model-libraries-snippets.ts | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/packages/tasks/src/model-libraries-snippets.ts b/packages/tasks/src/model-libraries-snippets.ts index 4090bb54c8..0a7d760033 100644 --- a/packages/tasks/src/model-libraries-snippets.ts +++ b/packages/tasks/src/model-libraries-snippets.ts @@ -134,14 +134,14 @@ depth = model.infer_image(raw_img) # HxW raw depth map in numpy ]; }; -const diffusers_default_prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"; +const diffusersDefaultPrompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"; const diffusers_default = (model: ModelData) => [ `from diffusers import DiffusionPipeline pipe = DiffusionPipeline.from_pretrained("${model.id}") -prompt = "${get_prompt_from_diffusers_model(model) ?? diffusers_default_prompt}" +prompt = "${get_prompt_from_diffusers_model(model) ?? diffusersDefaultPrompt}" image = pipe(prompt).images[0]`, ]; @@ -160,7 +160,7 @@ const diffusers_lora = (model: ModelData) => [ pipe = DiffusionPipeline.from_pretrained("${get_base_diffusers_model(model)}") pipe.load_lora_weights("${model.id}") -prompt = "${get_prompt_from_diffusers_model(model) ?? diffusers_default_prompt}" +prompt = "${get_prompt_from_diffusers_model(model) ?? diffusersDefaultPrompt}" image = pipe(prompt).images[0]`, ]; From 6d5adcebc171ac0cc40cb137865b9d1e4505aebf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?apolin=C3=A1rio?= Date: Wed, 18 Sep 2024 04:33:33 -0500 Subject: [PATCH 05/12] Update packages/tasks/src/model-libraries-snippets.ts Co-authored-by: Mishig --- packages/tasks/src/model-libraries-snippets.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/tasks/src/model-libraries-snippets.ts b/packages/tasks/src/model-libraries-snippets.ts index 0a7d760033..f95b57cf4b 100644 --- a/packages/tasks/src/model-libraries-snippets.ts +++ b/packages/tasks/src/model-libraries-snippets.ts @@ -72,7 +72,7 @@ function get_base_diffusers_model(model: ModelData): string { } function get_prompt_from_diffusers_model(model: ModelData): string | undefined { - return (model.widgetData?.[0] as WidgetExampleTextInput).text ?? model.cardData?.instance_prompt ?? undefined; + return (model.widgetData?.[0] as WidgetExampleTextInput).text ?? model.cardData?.instance_prompt; } export const bertopic = (model: ModelData): string[] => [ From 9bf0cf3308842b98889c775980b29748f376fa22 Mon Sep 17 00:00:00 2001 From: multimodalart Date: Wed, 18 Sep 2024 12:08:46 +0200 Subject: [PATCH 06/12] escape quotes --- packages/tasks/src/model-libraries-snippets.ts | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/packages/tasks/src/model-libraries-snippets.ts b/packages/tasks/src/model-libraries-snippets.ts index 0a7d760033..d8f10831ea 100644 --- a/packages/tasks/src/model-libraries-snippets.ts +++ b/packages/tasks/src/model-libraries-snippets.ts @@ -9,6 +9,8 @@ function nameWithoutNamespace(modelId: string): string { return splitted.length === 1 ? splitted[0] : splitted[1]; } +const escapeQuotes = (str: string | undefined): string | undefined => str?.replace(/"/g, '\\"'); + //#region snippets export const adapters = (model: ModelData): string[] => [ @@ -141,7 +143,7 @@ const diffusers_default = (model: ModelData) => [ pipe = DiffusionPipeline.from_pretrained("${model.id}") -prompt = "${get_prompt_from_diffusers_model(model) ?? diffusersDefaultPrompt}" +prompt = "${escapeQuotes(get_prompt_from_diffusers_model(model)) ?? diffusersDefaultPrompt}" image = pipe(prompt).images[0]`, ]; @@ -160,7 +162,7 @@ const diffusers_lora = (model: ModelData) => [ pipe = DiffusionPipeline.from_pretrained("${get_base_diffusers_model(model)}") pipe.load_lora_weights("${model.id}") -prompt = "${get_prompt_from_diffusers_model(model) ?? diffusersDefaultPrompt}" +prompt = "${escapeQuotes(get_prompt_from_diffusers_model(model)) ?? diffusersDefaultPrompt}" image = pipe(prompt).images[0]`, ]; From 96250ca6e33f2a9b219f7703e4faed446a533e9b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?apolin=C3=A1rio?= Date: Thu, 19 Sep 2024 05:11:02 -0500 Subject: [PATCH 07/12] Update packages/tasks/src/model-libraries-snippets.ts Co-authored-by: Mishig --- packages/tasks/src/model-libraries-snippets.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/tasks/src/model-libraries-snippets.ts b/packages/tasks/src/model-libraries-snippets.ts index fbec742123..158c442c70 100644 --- a/packages/tasks/src/model-libraries-snippets.ts +++ b/packages/tasks/src/model-libraries-snippets.ts @@ -9,7 +9,7 @@ function nameWithoutNamespace(modelId: string): string { return splitted.length === 1 ? splitted[0] : splitted[1]; } -const escapeQuotes = (str: string | undefined): string | undefined => str?.replace(/"/g, '\\"'); +const escapeQuotes = (str: string): string => JSON.stringify(str); //#region snippets From 15983c516a127cbbbfc3344df29eba2abaa71523 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?apolin=C3=A1rio?= Date: Thu, 19 Sep 2024 05:11:10 -0500 Subject: [PATCH 08/12] Update packages/tasks/src/model-libraries-snippets.ts Co-authored-by: Mishig --- packages/tasks/src/model-libraries-snippets.ts | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/packages/tasks/src/model-libraries-snippets.ts b/packages/tasks/src/model-libraries-snippets.ts index 158c442c70..706648f17e 100644 --- a/packages/tasks/src/model-libraries-snippets.ts +++ b/packages/tasks/src/model-libraries-snippets.ts @@ -74,7 +74,10 @@ function get_base_diffusers_model(model: ModelData): string { } function get_prompt_from_diffusers_model(model: ModelData): string | undefined { - return (model.widgetData?.[0] as WidgetExampleTextInput).text ?? model.cardData?.instance_prompt; + const prompt = (model.widgetData?.[0] as WidgetExampleTextInput).text ?? model.cardData?.instance_prompt; + if(prompt){ + return escapeQuotes(prompt); + } } export const bertopic = (model: ModelData): string[] => [ From f720313fec4db696035415c5c03cb38b8ec35acd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?apolin=C3=A1rio?= Date: Thu, 19 Sep 2024 05:11:16 -0500 Subject: [PATCH 09/12] Update packages/tasks/src/model-libraries-snippets.ts Co-authored-by: Mishig --- packages/tasks/src/model-libraries-snippets.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/tasks/src/model-libraries-snippets.ts b/packages/tasks/src/model-libraries-snippets.ts index 706648f17e..1d5ae10f64 100644 --- a/packages/tasks/src/model-libraries-snippets.ts +++ b/packages/tasks/src/model-libraries-snippets.ts @@ -146,7 +146,7 @@ const diffusers_default = (model: ModelData) => [ pipe = DiffusionPipeline.from_pretrained("${model.id}") -prompt = "${escapeQuotes(get_prompt_from_diffusers_model(model)) ?? diffusersDefaultPrompt}" +prompt = "${get_prompt_from_diffusers_model(model) ?? diffusersDefaultPrompt}" image = pipe(prompt).images[0]`, ]; From 28d79d3e31227b5bcaa60928d75ba13d78ca6dab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?apolin=C3=A1rio?= Date: Thu, 19 Sep 2024 05:11:23 -0500 Subject: [PATCH 10/12] Update packages/tasks/src/model-libraries-snippets.ts Co-authored-by: Mishig --- packages/tasks/src/model-libraries-snippets.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/tasks/src/model-libraries-snippets.ts b/packages/tasks/src/model-libraries-snippets.ts index 1d5ae10f64..89ae5298a8 100644 --- a/packages/tasks/src/model-libraries-snippets.ts +++ b/packages/tasks/src/model-libraries-snippets.ts @@ -165,7 +165,7 @@ const diffusers_lora = (model: ModelData) => [ pipe = DiffusionPipeline.from_pretrained("${get_base_diffusers_model(model)}") pipe.load_lora_weights("${model.id}") -prompt = "${escapeQuotes(get_prompt_from_diffusers_model(model)) ?? diffusersDefaultPrompt}" +prompt = "${get_prompt_from_diffusers_model(model) ?? diffusersDefaultPrompt}" image = pipe(prompt).images[0]`, ]; From 04a86690ab19fd978fad61fb5b188d1a87118312 Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Thu, 19 Sep 2024 13:53:48 +0200 Subject: [PATCH 11/12] format --- packages/tasks/src/model-libraries-snippets.ts | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/packages/tasks/src/model-libraries-snippets.ts b/packages/tasks/src/model-libraries-snippets.ts index 89ae5298a8..66ecbffdac 100644 --- a/packages/tasks/src/model-libraries-snippets.ts +++ b/packages/tasks/src/model-libraries-snippets.ts @@ -9,7 +9,7 @@ function nameWithoutNamespace(modelId: string): string { return splitted.length === 1 ? splitted[0] : splitted[1]; } -const escapeQuotes = (str: string): string => JSON.stringify(str); +const escapeQuotes = (str: string): string => JSON.stringify(str); //#region snippets @@ -75,9 +75,9 @@ function get_base_diffusers_model(model: ModelData): string { function get_prompt_from_diffusers_model(model: ModelData): string | undefined { const prompt = (model.widgetData?.[0] as WidgetExampleTextInput).text ?? model.cardData?.instance_prompt; - if(prompt){ - return escapeQuotes(prompt); - } + if (prompt) { + return escapeQuotes(prompt); + } } export const bertopic = (model: ModelData): string[] => [ From b7269c5ffab7e5a5495250a062efea30c984baf5 Mon Sep 17 00:00:00 2001 From: Mishig Davaadorj Date: Thu, 19 Sep 2024 13:55:44 +0200 Subject: [PATCH 12/12] rn to `escapeStringForJson` --- packages/tasks/src/model-libraries-snippets.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/tasks/src/model-libraries-snippets.ts b/packages/tasks/src/model-libraries-snippets.ts index 66ecbffdac..653fb01757 100644 --- a/packages/tasks/src/model-libraries-snippets.ts +++ b/packages/tasks/src/model-libraries-snippets.ts @@ -9,7 +9,7 @@ function nameWithoutNamespace(modelId: string): string { return splitted.length === 1 ? splitted[0] : splitted[1]; } -const escapeQuotes = (str: string): string => JSON.stringify(str); +const escapeStringForJson = (str: string): string => JSON.stringify(str); //#region snippets @@ -76,7 +76,7 @@ function get_base_diffusers_model(model: ModelData): string { function get_prompt_from_diffusers_model(model: ModelData): string | undefined { const prompt = (model.widgetData?.[0] as WidgetExampleTextInput).text ?? model.cardData?.instance_prompt; if (prompt) { - return escapeQuotes(prompt); + return escapeStringForJson(prompt); } }