Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
convert_unet_state_dict_to_peft,
is_wandb_available,
)
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module

Expand Down Expand Up @@ -100,30 +101,27 @@ def determine_scheduler_type(pretrained_model_name_or_path, revision):
def save_model_card(
repo_id: str,
use_dora: bool,
images=None,
base_model=str,
train_text_encoder=False,
train_text_encoder_ti=False,
images: list = None,
base_model: str = None,
train_text_encoder: bool = False,
train_text_encoder_ti: bool = False,
token_abstraction_dict=None,
instance_prompt=str,
validation_prompt=str,
instance_prompt: str = None,
validation_prompt: str = None,
repo_folder=None,
vae_path=None,
):
img_str = "widget:\n"
lora = "lora" if not use_dora else "dora"
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f"""
- text: '{validation_prompt if validation_prompt else ' ' }'
output:
url:
"image_{i}.png"
"""
if not images:
img_str += f"""
- text: '{instance_prompt}'
"""
widget_dict = []
if images is not None:
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
widget_dict.append(
{"text": validation_prompt if validation_prompt else " ", "output": {"url": f"image_{i}.png"}}
)
else:
widget_dict.append(
{"text": instance_prompt}
)
embeddings_filename = f"{repo_folder}_emb"
instance_prompt_webui = re.sub(r"<s\d+>", "", re.sub(r"<s\d+>", embeddings_filename, instance_prompt, count=1))
ti_keys = ", ".join(f'"{match}"' for match in re.findall(r"<s\d+>", instance_prompt))
Expand Down Expand Up @@ -160,23 +158,7 @@ def save_model_card(
to trigger concept `{key}` → use `{tokens}` in your prompt \n
"""

yaml = f"""---
tags:
- stable-diffusion-xl
- stable-diffusion-xl-diffusers
- diffusers-training
- text-to-image
- diffusers
- {lora}
- template:sd-lora
{img_str}
base_model: {base_model}
instance_prompt: {instance_prompt}
license: openrail++
---
"""

model_card = f"""
model_description = f"""
# SDXL LoRA DreamBooth - {repo_id}

<Gallery />
Expand Down Expand Up @@ -224,8 +206,25 @@ def save_model_card(
Special VAE used for training: {vae_path}.

"""
with open(os.path.join(repo_folder, "README.md"), "w") as f:
f.write(yaml + model_card)
model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
from_training=True,
license="openrail++",
base_model=base_model,
prompt=instance_prompt,
model_description=model_description,
widget=widget_dict,
)
tags = [
"text-to-image",
"stable-diffusion-xl",
"stable-diffusion-xl-diffusers",
"text-to-image",
"diffusers",
"lora",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

last thing 😊- we need to switch "lora" -> lora in the tags here (to use the var defined in line 114)

"template:sd-lora",
]
model_card = populate_model_card(model_card, tags=tags)


def import_model_class_from_model_name_or_path(
Expand Down