@@ -1256,25 +1256,37 @@ def download_from_original_stable_diffusion_ckpt(
1256
1256
key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
1257
1257
key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias"
1258
1258
key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias"
1259
+ config_url = None
1259
1260
1260
1261
# model_type = "v1"
1261
- config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
1262
+ if config_files is not None and "v1" in config_files :
1263
+ original_config_file = config_files ["v1" ]
1264
+ else :
1265
+ config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
1262
1266
1263
1267
if key_name_v2_1 in checkpoint and checkpoint [key_name_v2_1 ].shape [- 1 ] == 1024 :
1264
1268
# model_type = "v2"
1265
- config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
1266
-
1269
+ if config_files is not None and "v2" in config_files :
1270
+ original_config_file = config_files ["v2" ]
1271
+ else :
1272
+ config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
1267
1273
if global_step == 110000 :
1268
1274
# v2.1 needs to upcast attention
1269
1275
upcast_attention = True
1270
1276
elif key_name_sd_xl_base in checkpoint :
1271
1277
# only base xl has two text embedders
1272
- config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
1278
+ if config_files is not None and "xl" in config_files :
1279
+ original_config_file = config_files ["xl" ]
1280
+ else :
1281
+ config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
1273
1282
elif key_name_sd_xl_refiner in checkpoint :
1274
1283
# only refiner xl has embedder and one text embedders
1275
- config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml"
1276
-
1277
- original_config_file = BytesIO (requests .get (config_url ).content )
1284
+ if config_files is not None and "xl_refiner" in config_files :
1285
+ original_config_file = config_files ["xl_refiner" ]
1286
+ else :
1287
+ config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml"
1288
+ if config_url is not None :
1289
+ original_config_file = BytesIO (requests .get (config_url ).content )
1278
1290
1279
1291
original_config = OmegaConf .load (original_config_file )
1280
1292
0 commit comments