@@ -86,6 +86,15 @@ def save_model_card(
86
86
validation_prompt = None ,
87
87
repo_folder = None ,
88
88
):
89
+ if "large" in base_model :
90
+ model_variant = "SD3.5-Large"
91
+ license_url = "https://huggingface.co/stabilityai/stable-diffusion-3.5-large/blob/main/LICENSE.md"
92
+ variant_tags = ["sd3.5-large" , "sd3.5" , "sd3.5-diffusers" ]
93
+ else :
94
+ model_variant = "SD3"
95
+ license_url = "https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE.md"
96
+ variant_tags = ["sd3" , "sd3-diffusers" ]
97
+
89
98
widget_dict = []
90
99
if images is not None :
91
100
for i , image in enumerate (images ):
@@ -95,7 +104,7 @@ def save_model_card(
95
104
)
96
105
97
106
model_description = f"""
98
- # SD3 DreamBooth LoRA - { repo_id }
107
+ # { model_variant } DreamBooth LoRA - { repo_id }
99
108
100
109
<Gallery />
101
110
@@ -120,7 +129,7 @@ def save_model_card(
120
129
```py
121
130
from diffusers import AutoPipelineForText2Image
122
131
import torch
123
- pipeline = AutoPipelineForText2Image.from_pretrained('stabilityai/stable-diffusion-3-medium-diffusers' , torch_dtype=torch.float16).to('cuda')
132
+ pipeline = AutoPipelineForText2Image.from_pretrained({ base_model } , torch_dtype=torch.float16).to('cuda')
124
133
pipeline.load_lora_weights('{ repo_id } ', weight_name='pytorch_lora_weights.safetensors')
125
134
image = pipeline('{ validation_prompt if validation_prompt else instance_prompt } ').images[0]
126
135
```
@@ -135,7 +144,7 @@ def save_model_card(
135
144
136
145
## License
137
146
138
- Please adhere to the licensing terms as described [here](https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE ).
147
+ Please adhere to the licensing terms as described [here]({ license_url } ).
139
148
"""
140
149
model_card = load_or_create_model_card (
141
150
repo_id_or_path = repo_id ,
@@ -151,11 +160,11 @@ def save_model_card(
151
160
"diffusers-training" ,
152
161
"diffusers" ,
153
162
"lora" ,
154
- "sd3" ,
155
- "sd3-diffusers" ,
156
163
"template:sd-lora" ,
157
164
]
158
165
166
+ tags += variant_tags
167
+
159
168
model_card = populate_model_card (model_card , tags = tags )
160
169
model_card .save (os .path .join (repo_folder , "README.md" ))
161
170
0 commit comments