
Commit 0b90051

add vae encoder node

1 parent b305c77

File tree: 6 files changed, +128 -38 lines changed


src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py

Lines changed: 5 additions & 6 deletions
@@ -899,12 +899,6 @@ def prepare_latents(
                 f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
             )
 
-        latents_mean = latents_std = None
-        if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
-            latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
-        if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
-            latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
-
         # Offload text encoder if `enable_model_cpu_offload` was enabled
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.text_encoder_2.to("cpu")
@@ -918,6 +912,11 @@ def prepare_latents(
             init_latents = image
 
         else:
+            latents_mean = latents_std = None
+            if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
+                latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
+            if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
+                latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
             # make sure the VAE is in float32 mode, as it overflows in float16
             if self.vae.config.force_upcast:
                 image = image.float()
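The same relocation is applied to each img2img pipeline below: `latents_mean` and `latents_std` are only consumed on the encode path, so computing them up front was wasted work whenever `image` already contained latents. A minimal standalone sketch of the normalization they feed, with `vae_config` as an illustrative stand-in for `self.vae.config` (the formula itself matches the one used by the new VAE encoder block later in this commit):

import torch

# Sketch of the normalization `latents_mean`/`latents_std` feed once the image
# has been encoded. `vae_config` is an illustrative stand-in for `self.vae.config`.
def normalize_init_latents(init_latents, vae_config, device, dtype):
    latents_mean = latents_std = None
    if getattr(vae_config, "latents_mean", None) is not None:
        latents_mean = torch.tensor(vae_config.latents_mean).view(1, 4, 1, 1)
    if getattr(vae_config, "latents_std", None) is not None:
        latents_std = torch.tensor(vae_config.latents_std).view(1, 4, 1, 1)

    if latents_mean is not None and latents_std is not None:
        # Per-channel shift-and-scale for VAEs that publish latent statistics.
        latents_mean = latents_mean.to(device=device, dtype=dtype)
        latents_std = latents_std.to(device=device, dtype=dtype)
        return (init_latents - latents_mean) * vae_config.scaling_factor / latents_std
    # Default path: plain multiply by the VAE scaling factor.
    return vae_config.scaling_factor * init_latents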

src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py

Lines changed: 5 additions & 6 deletions
@@ -607,12 +607,6 @@ def prepare_latents(
                 f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
             )
 
-        latents_mean = latents_std = None
-        if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
-            latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
-        if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
-            latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
-
         # Offload text encoder if `enable_model_cpu_offload` was enabled
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.text_encoder_2.to("cpu")
@@ -626,6 +620,11 @@ def prepare_latents(
             init_latents = image
 
         else:
+            latents_mean = latents_std = None
+            if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
+                latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
+            if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
+                latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
             # make sure the VAE is in float32 mode, as it overflows in float16
             if self.vae.config.force_upcast:
                 image = image.float()

src/diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py

Lines changed: 5 additions & 6 deletions
@@ -905,12 +905,6 @@ def prepare_latents(
                 f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
             )
 
-        latents_mean = latents_std = None
-        if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
-            latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
-        if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
-            latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
-
         # Offload text encoder if `enable_model_cpu_offload` was enabled
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.text_encoder_2.to("cpu")
@@ -924,6 +918,11 @@ def prepare_latents(
             init_latents = image
 
         else:
+            latents_mean = latents_std = None
+            if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
+                latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
+            if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
+                latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
             # make sure the VAE is in float32 mode, as it overflows in float16
             if self.vae.config.force_upcast:
                 image = image.float()

src/diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py

Lines changed: 5 additions & 6 deletions
@@ -691,12 +691,6 @@ def prepare_latents(
                 f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
             )
 
-        latents_mean = latents_std = None
-        if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
-            latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
-        if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
-            latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
-
         # Offload text encoder if `enable_model_cpu_offload` was enabled
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.text_encoder_2.to("cpu")
@@ -710,6 +704,11 @@ def prepare_latents(
             init_latents = image
 
         else:
+            latents_mean = latents_std = None
+            if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
+                latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
+            if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
+                latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
             # make sure the VAE is in float32 mode, as it overflows in float16
             if self.vae.config.force_upcast:
                 image = image.float()

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py

Lines changed: 5 additions & 6 deletions
@@ -682,12 +682,6 @@ def prepare_latents(
                 f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
             )
 
-        latents_mean = latents_std = None
-        if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
-            latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
-        if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
-            latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
-
         # Offload text encoder if `enable_model_cpu_offload` was enabled
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.text_encoder_2.to("cpu")
@@ -701,6 +695,11 @@ def prepare_latents(
             init_latents = image
 
         else:
+            latents_mean = latents_std = None
+            if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
+                latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
+            if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
+                latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
             # make sure the VAE is in float32 mode, as it overflows in float16
             if self.vae.config.force_upcast:
                 image = image.float()

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py

Lines changed: 103 additions & 8 deletions
@@ -325,6 +325,102 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
         return pipeline, state
 
 
+class StableDiffusionXLVAEEncoderStep(PipelineBlock):
+    expected_components = ["vae"]
+    expected_auxiliaries = ["image_processor"]
+
+    @property
+    def inputs(self) -> List[Tuple[str, Any]]:
+        return [
+            ("image", None),
+            ("generator", None),
+            ("height", None),
+            ("width", None),
+            ("device", None),
+            ("dtype", None),
+        ]
+
+    @property
+    def intermediates_inputs(self) -> List[str]:
+        return ["batch_size"]
+
+    @property
+    def intermediates_outputs(self) -> List[str]:
+        return ["image_latents"]
+
+    def __init__(self, vae=None):
+        super().__init__(vae=vae)
+        self.image_processor = VaeImageProcessor()
+        self.auxiliaries["image_processor"] = self.image_processor
+
+    @torch.no_grad()
+    def __call__(self, pipeline, state: PipelineState) -> PipelineState:
+        image = state.get_input("image")
+        generator = state.get_input("generator")
+        height = state.get_input("height")
+        width = state.get_input("width")
+        device = state.get_input("device")
+        dtype = state.get_input("dtype")
+
+        batch_size = state.get_intermediate("batch_size")
+
+        if device is None:
+            device = pipeline._execution_device
+        if dtype is None:
+            dtype = pipeline.vae.dtype
+
+        image = pipeline.image_processor.preprocess(image, height=height, width=width)
+        image = image.to(device=device, dtype=dtype)
+
+        latents_mean = latents_std = None
+        if hasattr(pipeline.vae.config, "latents_mean") and pipeline.vae.config.latents_mean is not None:
+            latents_mean = torch.tensor(pipeline.vae.config.latents_mean).view(1, 4, 1, 1)
+        if hasattr(pipeline.vae.config, "latents_std") and pipeline.vae.config.latents_std is not None:
+            latents_std = torch.tensor(pipeline.vae.config.latents_std).view(1, 4, 1, 1)
+
+        # make sure the VAE is in float32 mode, as it overflows in float16
+        if pipeline.vae.config.force_upcast:
+            image = image.float()
+            pipeline.vae.to(dtype=torch.float32)
+
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
+        elif isinstance(generator, list):
+            if image.shape[0] < batch_size and batch_size % image.shape[0] == 0:
+                image = torch.cat([image] * (batch_size // image.shape[0]), dim=0)
+            elif image.shape[0] < batch_size and batch_size % image.shape[0] != 0:
+                raise ValueError(
+                    f"Cannot duplicate `image` of batch size {image.shape[0]} to effective batch_size {batch_size} "
+                )
+
+            init_latents = [
+                retrieve_latents(pipeline.vae.encode(image[i : i + 1]), generator=generator[i])
+                for i in range(batch_size)
+            ]
+            init_latents = torch.cat(init_latents, dim=0)
+        else:
+            init_latents = retrieve_latents(pipeline.vae.encode(image), generator=generator)
+
+        if pipeline.vae.config.force_upcast:
+            pipeline.vae.to(dtype)
+
+        init_latents = init_latents.to(dtype)
+        if latents_mean is not None and latents_std is not None:
+            latents_mean = latents_mean.to(device=device, dtype=dtype)
+            latents_std = latents_std.to(device=device, dtype=dtype)
+            init_latents = (init_latents - latents_mean) * pipeline.vae.config.scaling_factor / latents_std
+        else:
+            init_latents = pipeline.vae.config.scaling_factor * init_latents
+
+        state.add_intermediate("image_latents", init_latents)
+
+        return pipeline, state
+
+
 class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock):
     expected_components = ["scheduler"]
 
@@ -498,9 +594,9 @@ def __call__(self, pipeline: DiffusionPipeline, state: PipelineState) -> PipelineState:
         denoising_start = state.get_input("denoising_start")
 
         batch_size = state.get_intermediate("batch_size")
-        prompt_embeds = state.get_intermediate("prompt_embeds", None)
+        prompt_embeds = state.get_intermediate("prompt_embeds")
         # image to image only
-        latent_timestep = state.get_intermediate("latent_timestep", None)
+        latent_timestep = state.get_intermediate("latent_timestep")
 
         if dtype is None and prompt_embeds is not None:
             dtype = prompt_embeds.dtype
@@ -1872,12 +1968,6 @@ def prepare_latents_img2img(
                 f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
             )
 
-        latents_mean = latents_std = None
-        if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
-            latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
-        if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
-            latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
-
         # Offload text encoder if `enable_model_cpu_offload` was enabled
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
             self.text_encoder_2.to("cpu")
@@ -1891,6 +1981,11 @@ def prepare_latents_img2img(
             init_latents = image
 
         else:
+            latents_mean = latents_std = None
+            if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
+                latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
+            if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
+                latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
             # make sure the VAE is in float32 mode, as it overflows in float16
             if self.vae.config.force_upcast:
                 image = image.float()
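The new `StableDiffusionXLVAEEncoderStep` turns image encoding into a self-contained node: it reads `image`, `generator`, `height`, `width`, `device`, and `dtype` from the pipeline state, encodes through the VAE (upcasting to float32 first when `force_upcast` is set, then restoring the dtype), applies the same mean/std normalization as above, and stores the result as the `image_latents` intermediate. A hypothetical driving sketch follows; only `StableDiffusionXLVAEEncoderStep`, `PipelineState`, `get_input`, `get_intermediate`, and `add_intermediate` appear in this commit, so the `PipelineState()` constructor, `add_input`, and the minimal pipeline stand-in are assumptions for illustration:

import torch
from PIL import Image
from diffusers import AutoencoderKL
from diffusers.image_processor import VaeImageProcessor
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_modular import (
    PipelineState,
    StableDiffusionXLVAEEncoderStep,
)

class MinimalPipeline:
    # Assumed stand-in exposing only what the block touches:
    # vae, image_processor, and _execution_device.
    def __init__(self, vae):
        self.vae = vae
        self.image_processor = VaeImageProcessor()
        self._execution_device = torch.device("cpu")

vae = AutoencoderKL.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="vae"
)
pipeline = MinimalPipeline(vae)
block = StableDiffusionXLVAEEncoderStep(vae=vae)

# `add_input` is assumed as the write-side counterpart of `get_input`.
state = PipelineState()
state.add_input("image", Image.open("input.png").convert("RGB"))
state.add_input("generator", torch.Generator().manual_seed(0))
state.add_input("height", 1024)
state.add_input("width", 1024)
state.add_input("device", None)  # None falls back to pipeline._execution_device
state.add_input("dtype", None)   # None falls back to pipeline.vae.dtype
state.add_intermediate("batch_size", 1)  # normally produced by an upstream block

pipeline, state = block(pipeline, state)
image_latents = state.get_intermediate("image_latents")  # (1, 4, 128, 128) at 1024x1024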
