Skip to content

Commit c751449

Browse files
authored
fix: use retrieve_latents (#6337)
1 parent c1e8bdf commit c751449

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

examples/community/stable_diffusion_tensorrt_img2img.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
StableDiffusionPipelineOutput,
5151
StableDiffusionSafetyChecker,
5252
)
53+
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import retrieve_latents
5354
from diffusers.schedulers import DDIMScheduler
5455
from diffusers.utils import logging
5556

@@ -608,7 +609,7 @@ def __init__(self, model):
608609
self.vae_encoder = model
609610

610611
def forward(self, x):
611-
return self.vae_encoder.encode(x).latent_dist.sample()
612+
return retrieve_latents(self.vae_encoder.encode(x))
612613

613614

614615
class VAEEncoder(BaseModel):

0 commit comments

Comments
 (0)