Skip to content

Commit bfa0aa4

Browse files
[SD3-5 dreambooth lora] update model cards (#9749)
* improve readme * style --------- Co-authored-by: Sayak Paul <[email protected]>
1 parent ab1b7b2 commit bfa0aa4

File tree

2 files changed

+26
-9
lines changed

2 files changed

+26
-9
lines changed

examples/dreambooth/train_dreambooth_lora_sd3.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,15 @@ def save_model_card(
8686
validation_prompt=None,
8787
repo_folder=None,
8888
):
89+
if "large" in base_model:
90+
model_variant = "SD3.5-Large"
91+
license_url = "https://huggingface.co/stabilityai/stable-diffusion-3.5-large/blob/main/LICENSE.md"
92+
variant_tags = ["sd3.5-large", "sd3.5", "sd3.5-diffusers"]
93+
else:
94+
model_variant = "SD3"
95+
license_url = "https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE.md"
96+
variant_tags = ["sd3", "sd3-diffusers"]
97+
8998
widget_dict = []
9099
if images is not None:
91100
for i, image in enumerate(images):
@@ -95,7 +104,7 @@ def save_model_card(
95104
)
96105

97106
model_description = f"""
98-
# SD3 DreamBooth LoRA - {repo_id}
107+
# {model_variant} DreamBooth LoRA - {repo_id}
99108
100109
<Gallery />
101110
@@ -120,7 +129,7 @@ def save_model_card(
120129
```py
121130
from diffusers import AutoPipelineForText2Image
122131
import torch
123-
pipeline = AutoPipelineForText2Image.from_pretrained('stabilityai/stable-diffusion-3-medium-diffusers', torch_dtype=torch.float16).to('cuda')
132+
pipeline = AutoPipelineForText2Image.from_pretrained({base_model}, torch_dtype=torch.float16).to('cuda')
124133
pipeline.load_lora_weights('{repo_id}', weight_name='pytorch_lora_weights.safetensors')
125134
image = pipeline('{validation_prompt if validation_prompt else instance_prompt}').images[0]
126135
```
@@ -135,7 +144,7 @@ def save_model_card(
135144
136145
## License
137146
138-
Please adhere to the licensing terms as described [here](https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE).
147+
Please adhere to the licensing terms as described [here]({license_url}).
139148
"""
140149
model_card = load_or_create_model_card(
141150
repo_id_or_path=repo_id,
@@ -151,11 +160,11 @@ def save_model_card(
151160
"diffusers-training",
152161
"diffusers",
153162
"lora",
154-
"sd3",
155-
"sd3-diffusers",
156163
"template:sd-lora",
157164
]
158165

166+
tags += variant_tags
167+
159168
model_card = populate_model_card(model_card, tags=tags)
160169
model_card.save(os.path.join(repo_folder, "README.md"))
161170

examples/dreambooth/train_dreambooth_sd3.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,15 @@ def save_model_card(
7777
validation_prompt=None,
7878
repo_folder=None,
7979
):
80+
if "large" in base_model:
81+
model_variant = "SD3.5-Large"
82+
license_url = "https://huggingface.co/stabilityai/stable-diffusion-3.5-large/blob/main/LICENSE.md"
83+
variant_tags = ["sd3.5-large", "sd3.5", "sd3.5-diffusers"]
84+
else:
85+
model_variant = "SD3"
86+
license_url = "https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE.md"
87+
variant_tags = ["sd3", "sd3-diffusers"]
88+
8089
widget_dict = []
8190
if images is not None:
8291
for i, image in enumerate(images):
@@ -86,7 +95,7 @@ def save_model_card(
8695
)
8796

8897
model_description = f"""
89-
# SD3 DreamBooth - {repo_id}
98+
# {model_variant} DreamBooth - {repo_id}
9099
91100
<Gallery />
92101
@@ -113,7 +122,7 @@ def save_model_card(
113122
114123
## License
115124
116-
Please adhere to the licensing terms as described `[here](https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE)`.
125+
Please adhere to the licensing terms as described `[here]({license_url})`.
117126
"""
118127
model_card = load_or_create_model_card(
119128
repo_id_or_path=repo_id,
@@ -128,10 +137,9 @@ def save_model_card(
128137
"text-to-image",
129138
"diffusers-training",
130139
"diffusers",
131-
"sd3",
132-
"sd3-diffusers",
133140
"template:sd-lora",
134141
]
142+
tags += variant_tags
135143

136144
model_card = populate_model_card(model_card, tags=tags)
137145
model_card.save(os.path.join(repo_folder, "README.md"))

0 commit comments

Comments
 (0)