1313# limitations under the License.
1414
1515import inspect
16- import re
1716from typing import Any , Callable , Dict , List , Optional , Union
1817
1918import numpy as np
2019import torch
21- from transformers import ByT5Tokenizer , Qwen2_5_VLForConditionalGeneration , Qwen2Tokenizer , T5EncoderModel
20+ from transformers import Qwen2_5_VLForConditionalGeneration , Qwen2Tokenizer
2221
23- from ...image_processor import VaeImageProcessor , PipelineImageInput
22+ from ...image_processor import PipelineImageInput , VaeImageProcessor
2423from ...models import AutoencoderKLHunyuanImageRefiner , HunyuanImageTransformer2DModel
2524from ...schedulers import FlowMatchEulerDiscreteScheduler
2625from ...utils import is_torch_xla_available , logging , replace_example_docstring
5756 ```
5857"""
5958
59+
6060# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
6161def retrieve_timesteps (
6262 scheduler ,
@@ -128,7 +128,7 @@ def retrieve_latents(
128128 elif hasattr (encoder_output , "latents" ):
129129 return encoder_output .latents
130130 else :
131- raise AttributeError ("Could not access latents of provided encoder_output" )
131+ raise AttributeError ("Could not access latents of provided encoder_output" )
132132
133133
134134class HunyuanImageRefinerPipeline (DiffusionPipeline ):
@@ -358,8 +358,7 @@ def prepare_latents(
358358
359359 latents = strength * noise + (1 - strength ) * image_latents
360360
361- return noise ,latents
362-
361+ return noise , latents
363362
364363 def _encode_vae_image (self , image : torch .Tensor , generator : torch .Generator ):
365364 if isinstance (generator , list ):
@@ -370,9 +369,10 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
370369 image_latents = torch .cat (image_latents , dim = 0 )
371370 else :
372371 image_latents = retrieve_latents (self .vae .encode (image ), generator = generator , sample_mode = "sample" )
373-
372+
374373 # rearrange tokens
375- from einops import rearrange # YiYi TODO: remove this dependency
374+ from einops import rearrange # YiYi TODO: remove this dependency
375+
376376 image_latents = torch .cat ((image_latents [:, :, :1 ], image_latents ), dim = 2 )
377377 image_latents = rearrange (image_latents , "b c f h w -> b f c h w" )
378378 image_latents = rearrange (image_latents , "b (f n) c h w -> b f (n c) h w" , n = 2 )
@@ -556,7 +556,6 @@ def __call__(
556556
557557 image_latents = self ._encode_vae_image (image = image , generator = generator )
558558
559-
560559 has_neg_prompt = negative_prompt is not None or (
561560 negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
562561 )
@@ -708,7 +707,8 @@ def __call__(
708707 else :
709708 latents = latents .to (self .vae .dtype ) / self .vae .config .scaling_factor
710709
711- from einops import rearrange # YiYi TODO: remove this dependency
710+ from einops import rearrange # YiYi TODO: remove this dependency
711+
712712 latents = rearrange (latents , "b c f h w -> b f c h w" )
713713 latents = rearrange (latents , "b f (n c) h w -> b (f n) c h w" , n = 2 )
714714 latents = rearrange (latents , "b f c h w -> b c f h w" )
0 commit comments