Commit 9e8b94a

up
1 parent aef133d commit 9e8b94a

src/diffusers/pipelines/hunyuan_image/pipeline_hunyuanimage_refiner.py

Lines changed: 10 additions & 10 deletions
@@ -13,14 +13,13 @@
 # limitations under the License.
 
 import inspect
-import re
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import numpy as np
 import torch
-from transformers import ByT5Tokenizer, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, T5EncoderModel
+from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
 
-from ...image_processor import VaeImageProcessor, PipelineImageInput
+from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...models import AutoencoderKLHunyuanImageRefiner, HunyuanImageTransformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import is_torch_xla_available, logging, replace_example_docstring
@@ -57,6 +56,7 @@
 ```
 """
 
+
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
     scheduler,
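
The helper is truncated in this hunk, but per the "Copied from" comment it is the standard retrieve_timesteps utility shared across diffusers pipelines. A hedged usage sketch, assuming the usual signature (scheduler plus optional num_inference_steps, device, timesteps, sigmas; returns the resolved schedule and the possibly recomputed step count):

    # Hedged sketch, not from this file: resolve a 30-step schedule once,
    # letting the scheduler recompute num_inference_steps if custom timesteps are passed.
    from diffusers import FlowMatchEulerDiscreteScheduler
    from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps

    scheduler = FlowMatchEulerDiscreteScheduler()
    timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=30)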
@@ -128,7 +128,7 @@ def retrieve_latents(
     elif hasattr(encoder_output, "latents"):
         return encoder_output.latents
     else:
-        raise AttributeError("Could not access latents of provided encoder_output")
+        raise AttributeError("Could not access latents of provided encoder_output")
 
 
 class HunyuanImageRefinerPipeline(DiffusionPipeline):
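
For context, this hunk is the tail of the usual retrieve_latents fallback chain: sample from encoder_output.latent_dist, else return encoder_output.latents, else raise. A hedged sketch of the behavior, assuming retrieve_latents from this module is in scope (DummyOutput is a stand-in invented for the example):

    import torch

    class DummyOutput:
        latents = torch.zeros(1, 4, 8, 8)  # no .latent_dist, so the .latents branch is taken

    print(retrieve_latents(DummyOutput()).shape)  # torch.Size([1, 4, 8, 8])
    # An object with neither .latent_dist nor .latents raises the AttributeError above.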
@@ -358,8 +358,7 @@ def prepare_latents(
 
         latents = strength * noise + (1 - strength) * image_latents
 
-        return noise,latents
-
+        return noise, latents
 
     def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
         if isinstance(generator, list):
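
The change above only fixes spacing on the return, but the line before it carries the refiner's core img2img step: a linear blend of fresh noise into the encoded image latents. A minimal standalone illustration (shapes and the strength value are made up for the example):

    import torch

    noise = torch.randn(1, 16, 2, 128, 128)          # fresh Gaussian noise
    image_latents = torch.randn(1, 16, 2, 128, 128)  # encoded base-model output
    strength = 0.25  # 0.0 keeps the image latents unchanged, 1.0 is pure noise
    latents = strength * noise + (1 - strength) * image_latents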
@@ -370,9 +369,10 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
             image_latents = torch.cat(image_latents, dim=0)
         else:
             image_latents = retrieve_latents(self.vae.encode(image), generator=generator, sample_mode="sample")
-
+
         # rearrange tokens
-        from einops import rearrange  # YiYi TODO: remove this dependency
+        from einops import rearrange  # YiYi TODO: remove this dependency
+
         image_latents = torch.cat((image_latents[:, :, :1], image_latents), dim=2)
         image_latents = rearrange(image_latents, "b c f h w -> b f c h w")
         image_latents = rearrange(image_latents, "b (f n) c h w -> b f (n c) h w", n=2)
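
Since the inline einops import is flagged for removal (the YiYi TODO), the token packing in this hunk can be expressed with plain torch ops. A hedged sketch, assuming image_latents is laid out (b, c, f, h, w); pack_frame_pairs is a hypothetical name, not a function in this file:

    import torch

    def pack_frame_pairs(image_latents: torch.Tensor) -> torch.Tensor:
        # Duplicate the first frame, as in the hunk (f=1 latents become f=2, one pair).
        image_latents = torch.cat((image_latents[:, :, :1], image_latents), dim=2)
        b, c, f, h, w = image_latents.shape
        x = image_latents.permute(0, 2, 1, 3, 4)  # "b c f h w -> b f c h w"
        # "b (f n) c h w -> b f (n c) h w" with n=2: fold each frame pair into channels.
        return x.reshape(b, f // 2, 2 * c, h, w)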
@@ -556,7 +556,6 @@ def __call__(
 
         image_latents = self._encode_vae_image(image=image, generator=generator)
 
-
         has_neg_prompt = negative_prompt is not None or (
             negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
         )
@@ -708,7 +707,8 @@ def __call__(
         else:
             latents = latents.to(self.vae.dtype) / self.vae.config.scaling_factor
 
-        from einops import rearrange  # YiYi TODO: remove this dependency
+        from einops import rearrange  # YiYi TODO: remove this dependency
+
         latents = rearrange(latents, "b c f h w -> b f c h w")
         latents = rearrange(latents, "b f (n c) h w -> b (f n) c h w", n=2)
         latents = rearrange(latents, "b f c h w -> b c f h w")
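
The decode path could drop the dependency the same way; a hedged torch-only version of the three rearranges, assuming latents arrives as (b, c, f, h, w) with an even channel count (unpack_frame_pairs is a hypothetical name):

    import torch

    def unpack_frame_pairs(latents: torch.Tensor) -> torch.Tensor:
        b, c, f, h, w = latents.shape
        x = latents.permute(0, 2, 1, 3, 4)  # "b c f h w -> b f c h w"
        # "b f (n c) h w -> b (f n) c h w" with n=2: split channels back into frame pairs.
        x = x.reshape(b, f * 2, c // 2, h, w)
        return x.permute(0, 2, 1, 3, 4)     # "b f c h w -> b c f h w"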
