@@ -86,6 +86,15 @@ def save_model_card(
8686 validation_prompt = None ,
8787 repo_folder = None ,
8888):
89+ if "large" in base_model :
90+ model_variant = "SD3.5-Large"
91+ license_url = "https://huggingface.co/stabilityai/stable-diffusion-3.5-large/blob/main/LICENSE.md"
92+ variant_tags = ["sd3.5-large" , "sd3.5" , "sd3.5-diffusers" ]
93+ else :
94+ model_variant = "SD3"
95+ license_url = "https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE.md"
96+ variant_tags = ["sd3" , "sd3-diffusers" ]
97+
8998 widget_dict = []
9099 if images is not None :
91100 for i , image in enumerate (images ):
@@ -95,7 +104,7 @@ def save_model_card(
95104 )
96105
97106 model_description = f"""
98- # SD3 DreamBooth LoRA - { repo_id }
107+ # { model_variant } DreamBooth LoRA - { repo_id }
99108
100109<Gallery />
101110
@@ -120,7 +129,7 @@ def save_model_card(
120129```py
121130from diffusers import AutoPipelineForText2Image
122131import torch
123- pipeline = AutoPipelineForText2Image.from_pretrained('stabilityai/stable-diffusion-3-medium-diffusers' , torch_dtype=torch.float16).to('cuda')
132+ pipeline = AutoPipelineForText2Image.from_pretrained({ base_model } , torch_dtype=torch.float16).to('cuda')
124133pipeline.load_lora_weights('{ repo_id } ', weight_name='pytorch_lora_weights.safetensors')
125134image = pipeline('{ validation_prompt if validation_prompt else instance_prompt } ').images[0]
126135```
@@ -135,7 +144,7 @@ def save_model_card(
135144
136145## License
137146
138- Please adhere to the licensing terms as described [here](https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE ).
147+ Please adhere to the licensing terms as described [here]({ license_url } ).
139148"""
140149 model_card = load_or_create_model_card (
141150 repo_id_or_path = repo_id ,
@@ -151,11 +160,11 @@ def save_model_card(
151160 "diffusers-training" ,
152161 "diffusers" ,
153162 "lora" ,
154- "sd3" ,
155- "sd3-diffusers" ,
156163 "template:sd-lora" ,
157164 ]
158165
166+ tags += variant_tags
167+
159168 model_card = populate_model_card (model_card , tags = tags )
160169 model_card .save (os .path .join (repo_folder , "README.md" ))
161170
0 commit comments