Skip to content

Commit 58358c2

Browse files
committed
decode block: if decoding is skipped, the latent does not need to be updated
1 parent 5cde77f commit 58358c2

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

src/diffusers/modular_pipelines/stable_diffusion_xl/after_denoise.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -98,16 +98,17 @@ def __call__(self, components, state: PipelineState) -> PipelineState:
9898
block_state = self.get_block_state(state)
9999

100100
if not block_state.output_type == "latent":
101+
latents = block_state.latents
101102
# make sure the VAE is in float32 mode, as it overflows in float16
102103
block_state.needs_upcasting = components.vae.dtype == torch.float16 and components.vae.config.force_upcast
103104

104105
if block_state.needs_upcasting:
105106
self.upcast_vae(components)
106-
block_state.latents = block_state.latents.to(next(iter(components.vae.post_quant_conv.parameters())).dtype)
107-
elif block_state.latents.dtype != components.vae.dtype:
107+
latents = latents.to(next(iter(components.vae.post_quant_conv.parameters())).dtype)
108+
elif latents.dtype != components.vae.dtype:
108109
if torch.backends.mps.is_available():
109110
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
110-
components.vae = components.vae.to(block_state.latents.dtype)
111+
components.vae = components.vae.to(latents.dtype)
111112

112113
# unscale/denormalize the latents
113114
# denormalize with the mean and std if available and not None
@@ -119,16 +120,16 @@ def __call__(self, components, state: PipelineState) -> PipelineState:
119120
)
120121
if block_state.has_latents_mean and block_state.has_latents_std:
121122
block_state.latents_mean = (
122-
torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1).to(block_state.latents.device, block_state.latents.dtype)
123+
torch.tensor(components.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
123124
)
124125
block_state.latents_std = (
125-
torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1).to(block_state.latents.device, block_state.latents.dtype)
126+
torch.tensor(components.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
126127
)
127-
block_state.latents = block_state.latents * block_state.latents_std / components.vae.config.scaling_factor + block_state.latents_mean
128+
latents = latents * block_state.latents_std / components.vae.config.scaling_factor + block_state.latents_mean
128129
else:
129-
block_state.latents = block_state.latents / components.vae.config.scaling_factor
130+
latents = latents / components.vae.config.scaling_factor
130131

131-
block_state.images = components.vae.decode(block_state.latents, return_dict=False)[0]
132+
block_state.images = components.vae.decode(latents, return_dict=False)[0]
132133

133134
# cast back to fp16 if needed
134135
if block_state.needs_upcasting:
@@ -186,6 +187,7 @@ def __call__(self, components, state: PipelineState) -> PipelineState:
186187
return components, state
187188

188189

190+
# YiYi TODO: remove this, we don't need this in modular
189191
class StableDiffusionXLOutputStep(PipelineBlock):
190192
model_name = "stable-diffusion-xl"
191193

0 commit comments

Comments (0)