@@ -98,16 +98,17 @@ def __call__(self, components, state: PipelineState) -> PipelineState:
9898 block_state = self .get_block_state (state )
9999
100100 if not block_state .output_type == "latent" :
101+ latents = block_state .latents
101102 # make sure the VAE is in float32 mode, as it overflows in float16
102103 block_state .needs_upcasting = components .vae .dtype == torch .float16 and components .vae .config .force_upcast
103104
104105 if block_state .needs_upcasting :
105106 self .upcast_vae (components )
106- block_state . latents = block_state . latents .to (next (iter (components .vae .post_quant_conv .parameters ())).dtype )
107- elif block_state . latents .dtype != components .vae .dtype :
107+ latents = latents .to (next (iter (components .vae .post_quant_conv .parameters ())).dtype )
108+ elif latents .dtype != components .vae .dtype :
108109 if torch .backends .mps .is_available ():
109110 # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
110- components .vae = components .vae .to (block_state . latents .dtype )
111+ components .vae = components .vae .to (latents .dtype )
111112
112113 # unscale/denormalize the latents
113114 # denormalize with the mean and std if available and not None
@@ -119,16 +120,16 @@ def __call__(self, components, state: PipelineState) -> PipelineState:
119120 )
120121 if block_state .has_latents_mean and block_state .has_latents_std :
121122 block_state .latents_mean = (
122- torch .tensor (components .vae .config .latents_mean ).view (1 , 4 , 1 , 1 ).to (block_state . latents .device , block_state . latents .dtype )
123+ torch .tensor (components .vae .config .latents_mean ).view (1 , 4 , 1 , 1 ).to (latents .device , latents .dtype )
123124 )
124125 block_state .latents_std = (
125- torch .tensor (components .vae .config .latents_std ).view (1 , 4 , 1 , 1 ).to (block_state . latents .device , block_state . latents .dtype )
126+ torch .tensor (components .vae .config .latents_std ).view (1 , 4 , 1 , 1 ).to (latents .device , latents .dtype )
126127 )
127- block_state . latents = block_state . latents * block_state .latents_std / components .vae .config .scaling_factor + block_state .latents_mean
128+ latents = latents * block_state .latents_std / components .vae .config .scaling_factor + block_state .latents_mean
128129 else :
129- block_state . latents = block_state . latents / components .vae .config .scaling_factor
130+ latents = latents / components .vae .config .scaling_factor
130131
131- block_state .images = components .vae .decode (block_state . latents , return_dict = False )[0 ]
132+ block_state .images = components .vae .decode (latents , return_dict = False )[0 ]
132133
133134 # cast back to fp16 if needed
134135 if block_state .needs_upcasting :
@@ -186,6 +187,7 @@ def __call__(self, components, state: PipelineState) -> PipelineState:
186187 return components , state
187188
188189
190+ # YiYi TODO: remove this, we don't need this in modular
189191class StableDiffusionXLOutputStep (PipelineBlock ):
190192 model_name = "stable-diffusion-xl"
191193
0 commit comments