Commit 3287f4b ("style")
1 parent a9def70
File tree: 2 files changed, +43 −36 lines

src/diffusers/models/autoencoders/autoencoder_kl_hunyuanimage.py (1 addition, 1 deletion)

@@ -413,7 +413,7 @@ def __init__(
         latent_channels: int,
         block_out_channels: Tuple[int, ...],
         layers_per_block: int,
-        ffactor_spatial: int, # YiYi Notes: rename this config to scale_factor_spatial be consistent with wan
+        ffactor_spatial: int,  # YiYi Notes: rename this config to scale_factor_spatial be consistent with wan
         sample_size: int,
         scaling_factor: float = None,
         downsample_match_channel: bool = True,

src/diffusers/pipelines/hunyuan_image/pipeline_hunyuanimage.py (42 additions, 35 deletions)
@@ -13,11 +13,12 @@
 # limitations under the License.
 
 import inspect
+import re
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import numpy as np
 import torch
-from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, T5EncoderModel, ByT5Tokenizer
+from transformers import ByT5Tokenizer, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, T5EncoderModel
 
 from ...image_processor import VaeImageProcessor
 from ...models import AutoencoderKLHunyuanImage, HunyuanImageTransformer2DModel
@@ -26,7 +27,6 @@
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import HunyuanImagePipelineOutput
-import re
 
 
 if is_torch_xla_available():
@@ -45,7 +45,9 @@
 >>> import torch
 >>> from diffusers import HunyuanImagePipeline
 
->>> pipe = HunyuanImagePipeline.from_pretrained("hunyuanvideo-community/HunyuanVideo", torch_dtype=torch.bfloat16)
+>>> pipe = HunyuanImagePipeline.from_pretrained(
+...     "hunyuanvideo-community/HunyuanVideo", torch_dtype=torch.bfloat16
+... )
 >>> pipe.to("cuda")
 >>> prompt = "A cat holding a sign that says hello world"
 >>> # Depending on the variant being used, the pipeline call will slightly vary.
@@ -59,21 +61,20 @@
 def extract_glyph_text(prompt: str):
     """
     Extract text enclosed in quotes for glyph rendering.
-
-    Finds text in single quotes, double quotes, and Chinese quotes,
-    then formats it for byT5 processing.
-
+
+    Finds text in single quotes, double quotes, and Chinese quotes, then formats it for byT5 processing.
+
     Args:
         prompt: Input text prompt
-
+
     Returns:
         Formatted glyph text string or None if no quoted text found
     """
     text_prompt_texts = []
-    pattern_quote_single = r'\'(.*?)\''
-    pattern_quote_double = r'\"(.*?)\"'
-    pattern_quote_chinese_single = r'‘(.*?)’'
-    pattern_quote_chinese_double = r'“(.*?)”'
+    pattern_quote_single = r"\'(.*?)\'"
+    pattern_quote_double = r"\"(.*?)\""
+    pattern_quote_chinese_single = r"‘(.*?)’"
+    pattern_quote_chinese_double = r"“(.*?)”"
 
     matches_quote_single = re.findall(pattern_quote_single, prompt)
     matches_quote_double = re.findall(pattern_quote_double, prompt)
@@ -86,14 +87,13 @@ def extract_glyph_text(prompt: str):
     text_prompt_texts.extend(matches_quote_chinese_double)
 
     if text_prompt_texts:
-        glyph_text_formatted ='. '.join([f'Text "{text}"' for text in text_prompt_texts]) + '. '
+        glyph_text_formatted = ". ".join([f'Text "{text}"' for text in text_prompt_texts]) + ". "
     else:
         glyph_text_formatted = None
 
     return glyph_text_formatted
 
 
-
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
     scheduler,
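
For context, a minimal standalone sketch of what extract_glyph_text does after the re-style (behavior is unchanged by this commit; the quote patterns and output format come from the hunk above, while the function name and example prompt are illustrative):

    import re

    def extract_glyph_text_sketch(prompt: str):
        # Same four patterns as above: ASCII single/double and Chinese quotes.
        patterns = [r"\'(.*?)\'", r"\"(.*?)\"", r"‘(.*?)’", r"“(.*?)”"]
        found = []
        for pattern in patterns:
            found.extend(re.findall(pattern, prompt))
        # None when nothing is quoted, otherwise the byT5-ready string.
        return ". ".join(f'Text "{text}"' for text in found) + ". " if found else None

    print(repr(extract_glyph_text_sketch('A cat holding a sign that says "hello world"')))
    # 'Text "hello world". '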
@@ -170,8 +170,9 @@ class HunyuanImagePipeline(DiffusionPipeline):
             [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant.
         tokenizer (`Qwen2Tokenizer`): Tokenizer of class [Qwen2Tokenizer].
         text_encoder_2 ([`T5EncoderModel`]):
-            [T5EncoderModel](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel) variant.
-        tokenizer_2 (`ByT5Tokenizer`): Tokenizer of class [ByT5Tokenizer]
+            [T5EncoderModel](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel)
+            variant.
+        tokenizer_2 (`ByT5Tokenizer`): Tokenizer of class [ByT5Tokenizer]
     """
 
     model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
@@ -244,7 +245,6 @@ def _get_qwen_prompt_embeds(
 
         return prompt_embeds, encoder_attention_mask
 
-
     def _get_byt5_prompt_embeds(
         self,
         tokenizer: ByT5Tokenizer,
@@ -254,8 +254,6 @@ def _get_byt5_prompt_embeds(
         dtype: Optional[torch.dtype] = None,
         tokenizer_max_length: int = 128,
     ):
-
-
         device = device or self._execution_device
         dtype = dtype or text_encoder.dtype
 
@@ -278,7 +276,6 @@
             attention_mask=txt_tokens.attention_mask.float(),
         )[0]
 
-
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
         encoder_attention_mask = txt_tokens.attention_mask.to(device=device)
 
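A hedged sketch of how _get_byt5_prompt_embeds appears to be invoked: the tokenizer, device, and tokenizer_max_length arguments come from the visible hunks, while the text_encoder and prompt parameters are inferred from the `text_encoder.dtype` context line above and the encode_prompt hunks below; treat the exact signature as an assumption.

    # Sketch only: `pipe` is an instantiated HunyuanImagePipeline and
    # `glyph_text` a string produced by extract_glyph_text.
    glyph_text_embeds, glyph_text_embeds_mask = pipe._get_byt5_prompt_embeds(
        tokenizer=pipe.tokenizer_2,
        text_encoder=pipe.text_encoder_2,  # assumed parameter, implied by `text_encoder.dtype`
        prompt=glyph_text,                 # assumed parameter name
        device=pipe._execution_device,
        tokenizer_max_length=pipe.tokenizer_2_max_length,
    )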
@@ -304,13 +301,16 @@ def encode_prompt(
             num_images_per_prompt (`int`):
                 number of images that should be generated per prompt
             prompt_embeds (`torch.Tensor`, *optional*):
-                Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt` input argument.
+                Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt` input
+                argument.
             prompt_embeds_mask (`torch.Tensor`, *optional*):
                 Pre-generated text mask. If not provided, text mask will be generated from `prompt` input argument.
             prompt_embeds_2 (`torch.Tensor`, *optional*):
-                Pre-generated glyph text embeddings from ByT5. If not provided, will be generated from `prompt` input argument using self.tokenizer_2 and self.text_encoder_2.
+                Pre-generated glyph text embeddings from ByT5. If not provided, will be generated from `prompt` input
+                argument using self.tokenizer_2 and self.text_encoder_2.
             prompt_embeds_mask_2 (`torch.Tensor`, *optional*):
-                Pre-generated glyph text mask from ByT5. If not provided, will be generated from `prompt` input argument using self.tokenizer_2 and self.text_encoder_2.
+                Pre-generated glyph text mask from ByT5. If not provided, will be generated from `prompt` input
+                argument using self.tokenizer_2 and self.text_encoder_2.
         """
         device = device or self._execution_device
 
@@ -319,15 +319,15 @@ def encode_prompt(
 
         if prompt_embeds is None:
             prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(
-                    tokenizer=self.tokenizer,
-                    text_encoder=self.text_encoder,
-                    prompt=prompt,
+                tokenizer=self.tokenizer,
+                text_encoder=self.text_encoder,
+                prompt=prompt,
                 device=device,
                 tokenizer_max_length=self.tokenizer_max_length,
                 template=self.prompt_template_encode,
                 drop_idx=self.prompt_template_encode_start_idx,
             )
-
+
         if prompt_embeds_2 is None:
             prompt_embeds_2_list = []
             prompt_embeds_mask_2_list = []
@@ -336,7 +336,9 @@ def encode_prompt(
             for glyph_text in glyph_texts:
                 if glyph_text is None:
                     glyph_text_embeds = torch.zeros((1, self.tokenizer_2_max_length, 1472), device=device)
-                    glyph_text_embeds_mask = torch.zeros((1, self.tokenizer_2_max_length), device=device, dtype=torch.int64)
+                    glyph_text_embeds_mask = torch.zeros(
+                        (1, self.tokenizer_2_max_length), device=device, dtype=torch.int64
+                    )
                 else:
                     glyph_text_embeds, glyph_text_embeds_mask = self._get_byt5_prompt_embeds(
                         tokenizer=self.tokenizer_2,
@@ -345,10 +347,10 @@ def encode_prompt(
                         device=device,
                         tokenizer_max_length=self.tokenizer_2_max_length,
                     )
-
+
                 prompt_embeds_2_list.append(glyph_text_embeds)
                 prompt_embeds_mask_2_list.append(glyph_text_embeds_mask)
-
+
             prompt_embeds_2 = torch.cat(prompt_embeds_2_list, dim=0)
             prompt_embeds_mask_2 = torch.cat(prompt_embeds_mask_2_list, dim=0)
 
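As the hunks above show, a prompt without quoted text falls back to all-zero ByT5 embeddings and an all-zero mask. A small sketch of those fallback shapes; 1472 matches ByT5-small's hidden size, and 128 is the tokenizer_max_length default from _get_byt5_prompt_embeds above:

    import torch

    tokenizer_2_max_length = 128  # default from _get_byt5_prompt_embeds
    glyph_text_embeds = torch.zeros((1, tokenizer_2_max_length, 1472))  # (batch, seq, hidden)
    glyph_text_embeds_mask = torch.zeros((1, tokenizer_2_max_length), dtype=torch.int64)  # all zeros: fully masked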
@@ -425,7 +427,7 @@ def check_inputs(
             raise ValueError(
                 "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined."
             )
-
+
         if prompt_embeds_2 is not None and prompt_embeds_mask_2 is None:
             raise ValueError(
                 "If `prompt_embeds_2` are provided, `prompt_embeds_mask_2` also have to be passed. Make sure to generate `prompt_embeds_mask_2` from the same text encoder that was used to generate `prompt_embeds_2`."
@@ -435,7 +437,6 @@ def check_inputs(
                 "If `negative_prompt_embeds_2` are provided, `negative_prompt_embeds_mask_2` also have to be passed. Make sure to generate `negative_prompt_embeds_mask_2` from the same text encoder that was used to generate `negative_prompt_embeds_2`."
             )
 
-
     def prepare_latents(
         self,
         batch_size,
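The check_inputs hunks above enforce one invariant for pre-computed inputs: embeddings must always be passed together with their mask. A minimal sketch of that validation pattern (helper name is illustrative):

    def validate_embeds_pair(embeds, mask, embeds_name: str, mask_name: str):
        # Mirrors the ValueError raised in check_inputs when, e.g.,
        # `prompt_embeds_2` arrives without `prompt_embeds_mask_2`.
        if embeds is not None and mask is None:
            raise ValueError(f"If `{embeds_name}` are provided, `{mask_name}` also have to be passed.")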
@@ -660,7 +661,12 @@ def __call__(
         prompt_embeds_2 = prompt_embeds_2.to(self.transformer.dtype)
 
         if do_true_cfg:
-            negative_prompt_embeds, negative_prompt_embeds_mask, negative_prompt_embeds_2, negative_prompt_embeds_mask_2 = self.encode_prompt(
+            (
+                negative_prompt_embeds,
+                negative_prompt_embeds_mask,
+                negative_prompt_embeds_2,
+                negative_prompt_embeds_mask_2,
+            ) = self.encode_prompt(
                 prompt=negative_prompt,
                 prompt_embeds=negative_prompt_embeds,
                 prompt_embeds_mask=negative_prompt_embeds_mask,
@@ -697,7 +703,9 @@ def __call__(
         if self.transformer.config.guidance_embeds and guidance_scale is None:
             raise ValueError("guidance_scale is required for guidance-distilled model.")
         elif self.transformer.config.guidance_embeds:
-            guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=self.transformer.dtype, device=device) * 1000.0
+            guidance = (
+                torch.tensor([guidance_scale] * latents.shape[0], dtype=self.transformer.dtype, device=device) * 1000.0
+            )
         elif not self.transformer.config.guidance_embeds and guidance_scale is not None:
             logger.warning(
                 f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled."
@@ -709,7 +717,6 @@ def __call__(
         if self.attention_kwargs is None:
             self._attention_kwargs = {}
 
-
         # 6. Denoising loop
         self.scheduler.set_begin_index(0)
         with self.progress_bar(total=num_inference_steps) as progress_bar:
