
Commit 1b4af6b

update
1 parent ea77fdc commit 1b4af6b

3 files changed: 12 additions & 27 deletions

src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py

Lines changed: 0 additions & 5 deletions
@@ -1697,11 +1697,6 @@ def inputs(self) -> List[Tuple[str, Any]]:
             InputParam("controlnet_conditioning_scale", default=1.0),
             InputParam("guess_mode", default=False),
             InputParam("num_images_per_prompt", default=1),
-        ]
-
-    @property
-    def intermediate_inputs(self) -> List[InputParam]:
-        return [
             InputParam(
                 "latents",
                 required=True,
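
This deletion folds the block's separate intermediate_inputs property into its single inputs list, so latents (and the parameters that follow it) are declared alongside the user-facing inputs. A minimal sketch of the resulting shape, using a stand-in InputParam and a hypothetical class name, since the real block in before_denoise.py declares many more parameters:

from dataclasses import dataclass
from typing import Any, List

@dataclass
class InputParam:
    # Stand-in for diffusers' InputParam, reduced to the fields this diff uses.
    name: str
    default: Any = None
    required: bool = False

class ControlNetInputsBlock:  # hypothetical name, for illustration only
    @property
    def inputs(self) -> List[InputParam]:
        # Former intermediate inputs now sit in the same list as user inputs.
        return [
            InputParam("controlnet_conditioning_scale", default=1.0),
            InputParam("guess_mode", default=False),
            InputParam("num_images_per_prompt", default=1),
            InputParam("latents", required=True),  # was under intermediate_inputs
        ]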

src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py

Lines changed: 0 additions & 5 deletions
@@ -179,11 +179,6 @@ def inputs(self) -> List[Tuple[str, Any]]:
             InputParam("image"),
             InputParam("mask_image"),
             InputParam("padding_mask_crop"),
-        ]
-
-    @property
-    def intermediate_inputs(self) -> List[str]:
-        return [
             InputParam(
                 "images",
                 type_hint=Union[List[PIL.Image.Image], List[torch.Tensor], List[np.array]],
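
decoders.py receives the identical consolidation: the old intermediate_inputs property (annotated List[str] even though it returned InputParam entries) is merged into inputs, so images is declared next to image, mask_image, and padding_mask_crop. A self-contained check of what a consumer now sees, again with a stand-in InputParam and a hypothetical block name:

from dataclasses import dataclass
from typing import Any, List

@dataclass
class InputParam:
    # Stand-in for diffusers' InputParam (illustrative, not the real class).
    name: str
    default: Any = None
    required: bool = False

class DecodeBlock:  # hypothetical, mirrors the post-commit shape of decoders.py
    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("image"),
            InputParam("mask_image"),
            InputParam("padding_mask_crop"),
            InputParam("images"),  # formerly declared in intermediate_inputs
        ]

# User-facing and intermediate parameters are now discoverable in one place.
assert [p.name for p in DecodeBlock().inputs][-1] == "images"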

src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py

Lines changed: 12 additions & 17 deletions
@@ -663,12 +663,11 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline
         block_state.device = components._execution_device
         block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype

-        block_state.image = components.image_processor.preprocess(
+        image = components.image_processor.preprocess(
             block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs
         )
-        block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
-
-        block_state.batch_size = block_state.image.shape[0]
+        image = image.to(device=block_state.device, dtype=block_state.dtype)
+        block_state.batch_size = image.shape[0]

         # if generator is a list, make sure the length of it matches the length of images (both should be batch_size)
         if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size:
@@ -677,9 +676,7 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline
                 f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators."
             )

-        block_state.image_latents = self._encode_vae_image(
-            components, image=block_state.image, generator=block_state.generator
-        )
+        block_state.image_latents = self._encode_vae_image(components, image=image, generator=block_state.generator)

         self.set_block_state(state, block_state)

@@ -850,34 +847,32 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline
         block_state.crops_coords = None
         block_state.resize_mode = "default"

-        block_state.image = components.image_processor.preprocess(
+        image = components.image_processor.preprocess(
             block_state.image,
             height=block_state.height,
             width=block_state.width,
             crops_coords=block_state.crops_coords,
             resize_mode=block_state.resize_mode,
         )
-        block_state.image = block_state.image.to(dtype=torch.float32)
+        image = image.to(dtype=torch.float32)

-        block_state.mask = components.mask_processor.preprocess(
+        mask = components.mask_processor.preprocess(
             block_state.mask_image,
             height=block_state.height,
             width=block_state.width,
             resize_mode=block_state.resize_mode,
             crops_coords=block_state.crops_coords,
         )
-        block_state.masked_image = block_state.image * (block_state.mask < 0.5)
+        block_state.masked_image = image * (mask < 0.5)

-        block_state.batch_size = block_state.image.shape[0]
-        block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype)
-        block_state.image_latents = self._encode_vae_image(
-            components, image=block_state.image, generator=block_state.generator
-        )
+        block_state.batch_size = image.shape[0]
+        image = image.to(device=block_state.device, dtype=block_state.dtype)
+        block_state.image_latents = self._encode_vae_image(components, image=image, generator=block_state.generator)

         # 7. Prepare mask latent variables
         block_state.mask, block_state.masked_image_latents = self.prepare_mask_latents(
             components,
-            block_state.mask,
+            mask,
             block_state.masked_image,
             block_state.batch_size,
             block_state.height,
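
In encoders.py the preprocessed tensors stop being round-tripped through block_state: image and mask become plain locals, and only derived outputs (batch_size, masked_image, image_latents, and the prepared mask latents) are written to the state that set_block_state persists. A minimal sketch of the pattern with stand-in processors, since the real block wires components.image_processor, components.mask_processor, and self._encode_vae_image:

import torch

class BlockState:
    # Stand-in for the pipeline's mutable block state (attributes set ad hoc).
    pass

def encode_inpaint(block_state, preprocess_image, preprocess_mask, encode_vae_image):
    # Scratch tensors stay in locals instead of being stored on block_state ...
    image = preprocess_image(block_state.image).to(dtype=torch.float32)
    mask = preprocess_mask(block_state.mask_image)

    # ... so only genuine outputs are published on the state.
    block_state.masked_image = image * (mask < 0.5)
    block_state.batch_size = image.shape[0]
    block_state.image_latents = encode_vae_image(image)
    return block_state

# Toy usage with identity processors and a fake VAE encoder.
state = BlockState()
state.image = torch.rand(1, 3, 64, 64)
state.mask_image = torch.rand(1, 1, 64, 64)
encode_inpaint(state, lambda x: x, lambda x: x, lambda x: x * 0.13025)
assert state.batch_size == 1 and not hasattr(state, "mask")

Beyond shaving lines, this presumably keeps transient tensors out of the persisted state, so they can be freed as soon as the block returns and downstream blocks only see intentional outputs.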
