Commit a0422ed

[From Single File] Allow vae to be loaded (#4242)
* Allow vae to be loaded
* up
1 parent 3dd3393 commit a0422ed

2 files changed, +11 -2 lines changed


src/diffusers/loaders.py

Lines changed: 5 additions & 0 deletions
@@ -1410,6 +1410,9 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
                 An instance of `CLIPTextModel` to use, specifically the
                 [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. If this
                 parameter is `None`, the function loads a new instance of `CLIPTextModel` by itself if needed.
+            vae (`AutoencoderKL`, *optional*, defaults to `None`):
+                Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. If
+                this parameter is `None`, the function will load a new instance of `AutoencoderKL` by itself, if needed.
             tokenizer ([`~transformers.CLIPTokenizer`], *optional*, defaults to `None`):
                 An instance of `CLIPTokenizer` to use. If this parameter is `None`, the function loads a new instance
                 of `CLIPTokenizer` by itself if needed.

@@ -1458,6 +1461,7 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
         load_safety_checker = kwargs.pop("load_safety_checker", True)
         prediction_type = kwargs.pop("prediction_type", None)
         text_encoder = kwargs.pop("text_encoder", None)
+        vae = kwargs.pop("vae", None)
         controlnet = kwargs.pop("controlnet", None)
         tokenizer = kwargs.pop("tokenizer", None)

@@ -1548,6 +1552,7 @@ def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
             load_safety_checker=load_safety_checker,
             prediction_type=prediction_type,
             text_encoder=text_encoder,
+            vae=vae,
             tokenizer=tokenizer,
         )
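In practice this means an already-instantiated VAE can be handed to `from_single_file` instead of having the checkpoint's own VAE weights converted. A minimal usage sketch, assuming a local `.safetensors` checkpoint path and the `stabilityai/sd-vae-ft-mse` VAE as stand-ins:

from diffusers import AutoencoderKL, StableDiffusionPipeline

# Load (or reuse) a VAE independently of the single-file checkpoint.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")

# The new `vae` kwarg is popped in loaders.py and forwarded to the converter,
# so this instance is used as-is rather than the checkpoint's VAE weights.
pipe = StableDiffusionPipeline.from_single_file(
    "path/to/checkpoint.safetensors",  # hypothetical local path
    vae=vae,
)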

src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py

Lines changed: 6 additions & 2 deletions
@@ -1107,6 +1107,7 @@ def download_from_original_stable_diffusion_ckpt(
     pipeline_class: DiffusionPipeline = None,
     local_files_only=False,
     vae_path=None,
+    vae=None,
     text_encoder=None,
     tokenizer=None,
 ) -> DiffusionPipeline:

@@ -1156,6 +1157,9 @@ def download_from_original_stable_diffusion_ckpt(
             The pipeline class to use. Pass `None` to determine automatically.
         local_files_only (`bool`, *optional*, defaults to `False`):
             Whether or not to only look at local files (i.e., do not try to download the model).
+        vae (`AutoencoderKL`, *optional*, defaults to `None`):
+            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. If
+            this parameter is `None`, the function will load a new instance of `AutoencoderKL` by itself, if needed.
         text_encoder (`CLIPTextModel`, *optional*, defaults to `None`):
             An instance of [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel)
             to use, specifically the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)

@@ -1361,7 +1365,7 @@ def download_from_original_stable_diffusion_ckpt(
     unet.load_state_dict(converted_unet_checkpoint)

     # Convert the VAE model.
-    if vae_path is None:
+    if vae_path is None and vae is None:
         vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
         converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)

@@ -1385,7 +1389,7 @@ def download_from_original_stable_diffusion_ckpt(
                 set_module_tensor_to_device(vae, param_name, "cpu", value=param)
         else:
             vae.load_state_dict(converted_vae_checkpoint)
-    else:
+    elif vae is None:
         vae = AutoencoderKL.from_pretrained(vae_path)

     if model_type == "FrozenOpenCLIPEmbedder":
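The resulting precedence: an explicitly passed `vae` is used as-is, otherwise `vae_path` is loaded via `AutoencoderKL.from_pretrained`, and only when both are `None` are the VAE weights converted out of the original checkpoint. A standalone sketch of that selection order (the `pick_vae` helper is a hypothetical illustration, not part of the module):

from typing import Optional

from diffusers import AutoencoderKL

def pick_vae(vae: Optional[AutoencoderKL], vae_path: Optional[str]) -> Optional[AutoencoderKL]:
    # A user-supplied VAE instance always wins over every other source.
    if vae is not None:
        return vae
    # A repo id / local path is loaded exactly as before this change.
    if vae_path is not None:
        return AutoencoderKL.from_pretrained(vae_path)
    # Returning None stands in for the original fallback: converting the VAE
    # weights embedded in the single-file checkpoint itself.
    return None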
