Commit 4b367e8

up

1 parent ed881a1 commit 4b367e8
4 files changed: +31 −21 lines

src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py

Lines changed: 10 additions & 4 deletions
@@ -607,8 +607,8 @@ def intermediate_outputs(self) -> List[str]:
             ),
         ]
 
+    @staticmethod
     def prepare_latents(
-        self,
         image_latents,
         scheduler,
         dtype,
@@ -760,6 +760,7 @@ def check_inputs(self, batch_size, image_latents):
         if not (image_latents.shape[0] == 1 or image_latents.shape[0] == batch_size):
             raise ValueError(f"image_latents should have have batch size 1 or {batch_size}, but got {image_latents.shape[0]}")
 
+    @staticmethod
     def prepare_latents(image_latents, scheduler, timestep, dtype, device, generator=None):
         if isinstance(generator, list) and len(generator) != image_latents.shape[0]:
             raise ValueError(
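Note: both `prepare_latents` hunks drop the unused `self` and mark the method static, so other blocks can call it without an instance. A minimal sketch of the pattern, assuming a diffusers-style scheduler; the class name `LatentsStep` is hypothetical, not the pipeline's actual block:

import torch

class LatentsStep:
    @staticmethod
    def prepare_latents(image_latents, scheduler, timestep, dtype, device, generator=None):
        # pure function of its arguments; no `self` needed
        latents = image_latents.to(device=device, dtype=dtype)
        noise = torch.randn(latents.shape, generator=generator, device=device, dtype=dtype)
        return scheduler.add_noise(latents, noise, timestep)

# usable without an instance:
# latents = LatentsStep.prepare_latents(image_latents, scheduler, t, torch.float16, "cuda")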
@@ -975,6 +976,7 @@ def intermediate_inputs(self) -> List[InputParam]:
                 type_hint=int,
                 description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
             ),
+            InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs, can be generated in input step."),
         ]
 
     @property
@@ -992,7 +994,6 @@ def intermediate_outputs(self) -> List[OutputParam]:
                 kwargs_type="guider_input_fields",
                 description="The negative time ids to condition the denoising process",
             ),
-            OutputParam("timestep_cond", type_hint=torch.Tensor, description="The timestep cond to use for LCM"),
         ]
 
     @staticmethod
@@ -1136,6 +1137,11 @@ def intermediate_inputs(self) -> List[InputParam]:
                 type_hint=int,
                 description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
             ),
+            InputParam(
+                "dtype",
+                type_hint=torch.dtype,
+                description="The dtype of the model inputs. Can be generated in input step.",
+            ),
         ]
 
     @property
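Note: both hunks above declare `dtype` as an explicit intermediate input instead of an implicit assumption. A hedged sketch of the fallback a consuming block might apply; the chain below mirrors how encoders.py falls back to the VAE dtype and is illustrative, not the pipeline's actual resolution logic:

def resolve_dtype(block_state, components):
    # prefer the dtype produced by the input step, if it was set
    if getattr(block_state, "dtype", None) is not None:
        return block_state.dtype
    # otherwise fall back to a component's parameter dtype (assumption)
    return components.vae.dtype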
@@ -1187,8 +1193,8 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
         _, _, height_latents, width_latents = block_state.latents.shape
         height = height_latents * components.vae_scale_factor
         width = width_latents * components.vae_scale_factor
-        original_size = block_state.original_size or (block_state.height, block_state.width)
-        target_size = block_state.target_size or (block_state.height, block_state.width)
+        original_size = block_state.original_size or (height, width)
+        target_size = block_state.target_size or (height, width)
 
 
         block_state.add_time_ids = self._get_add_time_ids(
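Note: when the caller supplies latents directly, `block_state.height`/`block_state.width` can be `None`, which would leak `(None, None)` into the SDXL micro-conditioning time IDs; the size derived from the latent shape is always defined. A standalone sketch of the derivation (a vae_scale_factor of 8, the usual SDXL value, is assumed):

import torch

def size_from_latents(latents: torch.Tensor, vae_scale_factor: int = 8):
    # latents are (batch, channels, h/8, w/8) for the standard SDXL VAE
    _, _, height_latents, width_latents = latents.shape
    return height_latents * vae_scale_factor, width_latents * vae_scale_factor

latents = torch.zeros(1, 4, 128, 128)
height, width = size_from_latents(latents)
original_size = None or (height, width)  # falls back to (1024, 1024)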

src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py

Lines changed: 1 addition & 1 deletion
@@ -139,7 +139,7 @@ def __call__(self, components, state: PipelineState) -> PipelineState:
             block_state.images = components.vae.decode(latents, return_dict=False)[0]
 
             # cast back to fp16 if needed
-            if block_state.needs_upcasting:
+            if needs_upcasting:
                 components.vae.to(dtype=torch.float16)
         else:
             block_state.images = block_state.latents
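Note: `needs_upcasting` is a local computed earlier in `__call__`, not a field on `block_state`, so the old attribute lookup was wrong. A condensed sketch of the surrounding upcast-decode-downcast pattern; `force_upcast` and `scaling_factor` are real AutoencoderKL config fields, the rest is condensed:

import torch

def decode_latents(vae, latents):
    # SDXL's fp16 VAE can overflow during decode, so temporarily upcast to fp32
    needs_upcasting = vae.dtype == torch.float16 and vae.config.force_upcast
    if needs_upcasting:
        vae.to(dtype=torch.float32)
        latents = latents.to(torch.float32)
    images = vae.decode(latents / vae.config.scaling_factor, return_dict=False)[0]
    if needs_upcasting:
        vae.to(dtype=torch.float16)  # cast back to fp16 to save memory
    return images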

src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py

Lines changed: 15 additions & 15 deletions
@@ -39,6 +39,8 @@
 from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
 from .modular_pipeline import StableDiffusionXLModularPipeline
 
+from PIL import Image
+
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -504,11 +506,11 @@ def encode_prompt(
             negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
             negative_pooled_prompt_embeds = torch.concat(negative_pooled_prompt_embeds_list, dim=0)
 
-        prompt_embeds = prompt_embeds.to(dtype, device=device)
-        pooled_prompt_embeds = pooled_prompt_embeds.to(dtype, device=device)
+        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+        pooled_prompt_embeds = pooled_prompt_embeds.to(dtype=dtype, device=device)
         if requires_unconditional_embeds:
-            negative_prompt_embeds = negative_prompt_embeds.to(dtype, device=device)
-            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
+            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.to(dtype=dtype, device=device)
 
         for text_encoder in text_encoders:
             if isinstance(components, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
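Note: `Tensor.to` is overloaded, and a positional dtype combined with a `device` keyword matches none of the overloads, so the old calls raise a `TypeError` (at least in the PyTorch versions I have checked); passing both as keywords selects the `to(device=..., dtype=...)` overload unambiguously. A minimal repro:

import torch

t = torch.zeros(2, 4)
t = t.to(dtype=torch.float16, device="cpu")  # unambiguous overload match

try:
    t.to(torch.float16, device="cpu")  # positional dtype + device keyword
except TypeError as err:
    print("no matching overload:", err)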
@@ -687,12 +689,12 @@ def intermediate_outputs(self) -> List[OutputParam]:
 
     def check_inputs(self, image, mask_image, padding_mask_crop):
 
-        if padding_mask_crop is not None and not isinstance(image, PIL.Image.Image):
+        if padding_mask_crop is not None and not isinstance(image, Image.Image):
             raise ValueError(
                 f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
             )
 
-        if padding_mask_crop is not None and not isinstance(mask_image, PIL.Image.Image):
+        if padding_mask_crop is not None and not isinstance(mask_image, Image.Image):
             raise ValueError(
                 f"The mask image should be a PIL image when inpainting mask crop, but is of type"
                 f" {type(mask_image)}."
@@ -707,10 +709,8 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
         dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
         device = components._execution_device
 
-        if block_state.height is None:
-            height = components.default_height
-        if block_state.width is None:
-            width = components.default_width
+        height = block_state.height if block_state.height is not None else components.default_height
+        width = block_state.width if block_state.width is not None else components.default_width
 
         if block_state.padding_mask_crop is not None:
             block_state.crops_coords = components.mask_processor.get_crop_region(
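Note: the removed form assigned the local `height`/`width` only in the `None` branch, so passing an explicit height left the local unbound at its first use. The conditional expression binds the name on every path; a standalone sketch (`_State` and the defaults are hypothetical stand-ins for `block_state` and `components`):

class _State:  # hypothetical stand-in for block_state
    height = 768
    width = None

DEFAULT_HEIGHT = DEFAULT_WIDTH = 1024  # stand-ins for components.default_height/width

def resolve_size(state):
    # buggy form: `height` stays unbound when state.height is not None
    #   if state.height is None:
    #       height = DEFAULT_HEIGHT
    # fixed form: every path binds the local
    height = state.height if state.height is not None else DEFAULT_HEIGHT
    width = state.width if state.width is not None else DEFAULT_WIDTH
    return height, width

print(resolve_size(_State()))  # (768, 1024)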
@@ -725,21 +725,21 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
             block_state.image,
             height=height,
             width=width,
-            crops_coords=crops_coords,
+            crops_coords=block_state.crops_coords,
             resize_mode=resize_mode,
         )
 
         image = image.to(dtype=torch.float32)
 
-        mask = components.mask_processor.preprocess(
+        mask_image = components.mask_processor.preprocess(
             block_state.mask_image,
             height=height,
             width=width,
             resize_mode=resize_mode,
-            crops_coords=crops_coords,
+            crops_coords=block_state.crops_coords,
         )
 
-        masked_image = image * (block_state.mask_latents < 0.5)
+        masked_image = image * (mask_image < 0.5)
 
         # Prepare image latent variables
         block_state.image_latents = encode_vae_image(
@@ -762,7 +762,7 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
         # resize mask to match the image latents
         _, _, height_latents, width_latents = block_state.image_latents.shape
         block_state.mask = torch.nn.functional.interpolate(
-            mask,
+            mask_image,
             size=(height_latents, width_latents),
        )
         block_state.mask = block_state.mask.to(dtype=dtype, device=device)
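Note: the preprocessed mask now lives in a single local (`mask_image`), used both for `masked_image` and the later `interpolate` call; the old code mixed an undefined bare `crops_coords`, a `mask` local, and a `block_state.mask_latents` that does not exist at this point. A condensed sketch of the inpainting mask flow, with the preprocessing elided:

import torch

def make_mask_inputs(image, mask_image, latent_size):
    # mask convention: values near 1 mark regions to repaint, near 0 regions to keep
    masked_image = image * (mask_image < 0.5)  # zero out the repainted regions
    # resize the mask to the latent resolution so it can gate the denoising loop
    mask = torch.nn.functional.interpolate(mask_image, size=latent_size)
    return mask, masked_image

image = torch.rand(1, 3, 64, 64)
mask_image = (torch.rand(1, 1, 64, 64) > 0.5).float()
mask, masked_image = make_mask_inputs(image, mask_image, (8, 8))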

src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py

Lines changed: 5 additions & 1 deletion
@@ -95,11 +95,15 @@ def requires_unconditional_embeds(self):
         # by default, always prepare unconditional embeddings
         requires_unconditional_embeds = True
 
-        if hasattr(self, "unet") and self.unet is not None and self.unet.config.time_cond_proj_dim is None:
+        if hasattr(self, "unet") and self.unet is not None and self.unet.config.time_cond_proj_dim is not None:
+            # LCM
             requires_unconditional_embeds = False
 
         elif hasattr(self, "guider") and self.guider is not None:
             requires_unconditional_embeds = self.guider.num_conditions > 1
+
+        elif not hasattr(self, "guider") or self.guider is None:
+            requires_unconditional_embeds = False
 
         return requires_unconditional_embeds
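Note: the comparison was inverted: a UNet with `time_cond_proj_dim` *set* is an LCM-style UNet with guidance distilled in, which needs no unconditional (negative) embeddings, and without any guider nothing would consume them either. A standalone restatement of the fixed decision logic (attribute names as in the diff; `unet`/`guider` passed as plain arguments for illustration):

def requires_unconditional_embeds(unet=None, guider=None) -> bool:
    # LCM-style UNets (time_cond_proj_dim set) have guidance baked in
    if unet is not None and unet.config.time_cond_proj_dim is not None:
        return False
    # a guider needs negatives only when it combines multiple conditions (e.g. CFG)
    if guider is not None:
        return guider.num_conditions > 1
    # no guider: nothing consumes unconditional embeddings
    return False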
