Skip to content

Commit 59c307f

Browse files
bamps53sayakpaul
andauthored
Standardize model card for Controlnet (#6910)
* controlnet * style --------- Co-authored-by: Sayak Paul <[email protected]>
1 parent 159885a commit 59c307f

File tree

1 file changed

+21
-16
lines changed

1 file changed

+21
-16
lines changed

examples/controlnet/train_controlnet.py

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
)
5050
from diffusers.optimization import get_scheduler
5151
from diffusers.utils import check_min_version, is_wandb_available
52+
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
5253
from diffusers.utils.import_utils import is_xformers_available
5354
from diffusers.utils.torch_utils import is_compiled_module
5455

@@ -207,27 +208,31 @@ def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=N
207208
image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png"))
208209
img_str += f"![images_{i})](./images_{i}.png)\n"
209210

210-
yaml = f"""
211-
---
212-
license: creativeml-openrail-m
213-
base_model: {base_model}
214-
tags:
215-
- stable-diffusion
216-
- stable-diffusion-diffusers
217-
- text-to-image
218-
- diffusers
219-
- controlnet
220-
inference: true
221-
---
222-
"""
223-
model_card = f"""
211+
model_description = f"""
224212
# controlnet-{repo_id}
225213
226214
These are controlnet weights trained on {base_model} with new type of conditioning.
227215
{img_str}
228216
"""
229-
with open(os.path.join(repo_folder, "README.md"), "w") as f:
230-
f.write(yaml + model_card)
217+
model_card = load_or_create_model_card(
218+
repo_id_or_path=repo_id,
219+
from_training=True,
220+
license="creativeml-openrail-m",
221+
base_model=base_model,
222+
model_description=model_description,
223+
inference=True,
224+
)
225+
226+
tags = [
227+
"stable-diffusion",
228+
"stable-diffusion-diffusers",
229+
"text-to-image",
230+
"diffusers",
231+
"controlnet",
232+
]
233+
model_card = populate_model_card(model_card, tags=tags)
234+
235+
model_card.save(os.path.join(repo_folder, "README.md"))
231236

232237

233238
def parse_args(input_args=None):

0 commit comments

Comments
 (0)