Commit 7512fc4

vladmandic authored and patrickvonplaten committed
allow loading of sd models from safetensors without online lookups using local config files (#5019)
finish config_files implementation
1 parent 0c2f1cc commit 7512fc4
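
What this change enables, shown as a minimal usage sketch (the checkpoint and YAML paths below are placeholders, and the call itself is an assumption about typical use, not part of this diff): from_single_file can be pointed at local copies of the original inference configs via the new config_files argument, so no network lookup of the YAML is needed. The recognized keys ("v1", "v2", "xl", "xl_refiner") are the ones handled in convert_from_ckpt.py below.

from diffusers import StableDiffusionPipeline

# Placeholder local paths; the config_files keys mirror those checked in
# convert_from_ckpt.py ("v1", "v2", "xl", "xl_refiner").
pipe = StableDiffusionPipeline.from_single_file(
    "/models/checkpoints/model.safetensors",
    config_files={
        "v1": "/models/configs/v1-inference.yaml",
        "v2": "/models/configs/v2-inference-v.yaml",
        "xl": "/models/configs/sd_xl_base.yaml",
        "xl_refiner": "/models/configs/sd_xl_refiner.yaml",
    },
    local_files_only=True,  # avoid any online lookup of the config
)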

File tree

3 files changed: +22 −7 lines changed

scripts/convert_original_stable_diffusion_to_diffusers.py

Lines changed: 1 addition & 0 deletions
@@ -154,6 +154,7 @@
    pipe = download_from_original_stable_diffusion_ckpt(
        checkpoint_path_or_dict=args.checkpoint_path,
        original_config_file=args.original_config_file,
+        config_files=args.config_files,
        image_size=args.image_size,
        prediction_type=args.prediction_type,
        model_type=args.pipeline_type,

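For context, this hunk only forwards args.config_files; the matching command-line flag is not part of this diff. A hypothetical sketch of how such a flag could be declared (the flag name, JSON format, and help text here are assumptions, not the repository's actual definition):

import argparse
import json

parser = argparse.ArgumentParser()
# Hypothetical flag, e.g. --config_files '{"v1": "configs/v1-inference.yaml"}';
# argparse applies json.loads to the string to build the dict expected above.
parser.add_argument(
    "--config_files",
    default=None,
    type=json.loads,
    help="Optional JSON mapping of model type (v1/v2/xl/xl_refiner) to a local config YAML path.",
)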
src/diffusers/loaders.py

Lines changed: 2 additions & 0 deletions
@@ -2098,6 +2098,7 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
        from .pipelines.stable_diffusion.convert_from_ckpt import download_from_original_stable_diffusion_ckpt

        original_config_file = kwargs.pop("original_config_file", None)
+        config_files = kwargs.pop("config_files", None)
        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
        resume_download = kwargs.pop("resume_download", False)
        force_download = kwargs.pop("force_download", False)
@@ -2215,6 +2216,7 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
            vae=vae,
            tokenizer=tokenizer,
            original_config_file=original_config_file,
+            config_files=config_files,
        )

        if torch_dtype is not None:

src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py

Lines changed: 19 additions & 7 deletions
@@ -1256,25 +1256,37 @@ def download_from_original_stable_diffusion_ckpt(
    key_name_v2_1 = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"
    key_name_sd_xl_base = "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias"
    key_name_sd_xl_refiner = "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias"
+    config_url = None

    # model_type = "v1"
-    config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
+    if config_files is not None and "v1" in config_files:
+        original_config_file = config_files["v1"]
+    else:
+        config_url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"

    if key_name_v2_1 in checkpoint and checkpoint[key_name_v2_1].shape[-1] == 1024:
        # model_type = "v2"
-        config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
-
+        if config_files is not None and "v2" in config_files:
+            original_config_file = config_files["v2"]
+        else:
+            config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
        if global_step == 110000:
            # v2.1 needs to upcast attention
            upcast_attention = True
    elif key_name_sd_xl_base in checkpoint:
        # only base xl has two text embedders
-        config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
+        if config_files is not None and "xl" in config_files:
+            original_config_file = config_files["xl"]
+        else:
+            config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_base.yaml"
    elif key_name_sd_xl_refiner in checkpoint:
        # only refiner xl has embedder and one text embedders
-        config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml"
-
-    original_config_file = BytesIO(requests.get(config_url).content)
+        if config_files is not None and "xl_refiner" in config_files:
+            original_config_file = config_files["xl_refiner"]
+        else:
+            config_url = "https://raw.githubusercontent.com/Stability-AI/generative-models/main/configs/inference/sd_xl_refiner.yaml"
+    if config_url is not None:
+        original_config_file = BytesIO(requests.get(config_url).content)

    original_config = OmegaConf.load(original_config_file)
