 # limitations under the License.
 
 import inspect
+import re
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import numpy as np
 import torch
-from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, T5EncoderModel, ByT5Tokenizer
+from transformers import ByT5Tokenizer, Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer, T5EncoderModel
 
 from ...image_processor import VaeImageProcessor
 from ...models import AutoencoderKLHunyuanImage, HunyuanImageTransformer2DModel
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import HunyuanImagePipelineOutput
-import re
 
 
 if is_torch_xla_available():
         >>> import torch
         >>> from diffusers import HunyuanImagePipeline
 
-        >>> pipe = HunyuanImagePipeline.from_pretrained("hunyuanvideo-community/HunyuanVideo", torch_dtype=torch.bfloat16)
+        >>> pipe = HunyuanImagePipeline.from_pretrained(
+        ...     "hunyuanvideo-community/HunyuanVideo", torch_dtype=torch.bfloat16
+        ... )
         >>> pipe.to("cuda")
         >>> prompt = "A cat holding a sign that says hello world"
         >>> # Depending on the variant being used, the pipeline call will slightly vary.
 def extract_glyph_text(prompt: str):
     """
     Extract text enclosed in quotes for glyph rendering.
-
-    Finds text in single quotes, double quotes, and Chinese quotes,
-    then formats it for byT5 processing.
-
+
+    Finds text in single quotes, double quotes, and Chinese quotes, then formats it for byT5 processing.
+
     Args:
         prompt: Input text prompt
-
+
     Returns:
         Formatted glyph text string or None if no quoted text found
     """
     text_prompt_texts = []
-    pattern_quote_single = r'\'(.*?)\''
-    pattern_quote_double = r'\"(.*?)\"'
-    pattern_quote_chinese_single = r'‘(.*?)’'
-    pattern_quote_chinese_double = r'“(.*?)”'
+    pattern_quote_single = r"\'(.*?)\'"
+    pattern_quote_double = r"\"(.*?)\""
+    pattern_quote_chinese_single = r"‘(.*?)’"
+    pattern_quote_chinese_double = r"“(.*?)”"
 
     matches_quote_single = re.findall(pattern_quote_single, prompt)
     matches_quote_double = re.findall(pattern_quote_double, prompt)
@@ -86,14 +87,13 @@ def extract_glyph_text(prompt: str):
     text_prompt_texts.extend(matches_quote_chinese_double)
 
     if text_prompt_texts:
-        glyph_text_formatted = '. '.join([f'Text "{text}"' for text in text_prompt_texts]) + '. '
+        glyph_text_formatted = ". ".join([f'Text "{text}"' for text in text_prompt_texts]) + ". "
     else:
         glyph_text_formatted = None
 
     return glyph_text_formatted
 
 
-
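For a quick sense of what `extract_glyph_text` returns, here is a hypothetical usage sketch (the prompts are made up; the outputs follow directly from the `". ".join(...)` formatting above):

    # Quoted segments are pulled out and become byT5 glyph-conditioning text.
    extract_glyph_text('A storefront sign that says "OPEN", next to a poster that says "SALE"')
    # -> 'Text "OPEN". Text "SALE". '
    extract_glyph_text("A cat holding a sign")
    # -> None (no quoted text)
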
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
     scheduler,
@@ -170,8 +170,9 @@ class HunyuanImagePipeline(DiffusionPipeline):
             [Qwen2.5-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2.5-VL-7B-Instruct) variant.
         tokenizer (`Qwen2Tokenizer`): Tokenizer of class [Qwen2Tokenizer].
         text_encoder_2 ([`T5EncoderModel`]):
-            [T5EncoderModel](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel) variant.
-        tokenizer_2 (`ByT5Tokenizer`): Tokenizer of class [ByT5Tokenizer]
+            [T5EncoderModel](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel)
+            variant.
+        tokenizer_2 (`ByT5Tokenizer`): Tokenizer of class [ByT5Tokenizer]
     """
 
     model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
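The offload sequence above is the order `enable_model_cpu_offload()` follows when shuttling components between CPU and GPU; a usage sketch with the standard `DiffusionPipeline` API, reusing the checkpoint id from the doc example (not part of this change):

    import torch
    from diffusers import HunyuanImagePipeline

    pipe = HunyuanImagePipeline.from_pretrained(
        "hunyuanvideo-community/HunyuanVideo", torch_dtype=torch.bfloat16
    )
    # text_encoder -> text_encoder_2 -> transformer -> vae are each moved to
    # the GPU only while they run, then offloaded back to CPU.
    pipe.enable_model_cpu_offload()
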
@@ -244,7 +245,6 @@ def _get_qwen_prompt_embeds(
 
         return prompt_embeds, encoder_attention_mask
 
-
     def _get_byt5_prompt_embeds(
         self,
         tokenizer: ByT5Tokenizer,
@@ -254,8 +254,6 @@ def _get_byt5_prompt_embeds(
         dtype: Optional[torch.dtype] = None,
         tokenizer_max_length: int = 128,
     ):
-
-
         device = device or self._execution_device
         dtype = dtype or text_encoder.dtype
 
@@ -278,7 +276,6 @@ def _get_byt5_prompt_embeds(
             attention_mask=txt_tokens.attention_mask.float(),
         )[0]
 
-
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
         encoder_attention_mask = txt_tokens.attention_mask.to(device=device)
 
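For context, `_get_byt5_prompt_embeds` is plain byte-level T5 encoding. A minimal standalone sketch with transformers, assuming `google/byt5-small` purely for illustration (its hidden size, 1472, matches the zero-embedding width used in `encode_prompt` below; the pipeline uses whatever `text_encoder_2`/`tokenizer_2` it was loaded with):

    import torch
    from transformers import ByT5Tokenizer, T5EncoderModel

    tokenizer = ByT5Tokenizer.from_pretrained("google/byt5-small")
    text_encoder = T5EncoderModel.from_pretrained("google/byt5-small")

    # Tokenize to a fixed 128-token window, mirroring tokenizer_max_length above.
    tokens = tokenizer(
        'Text "OPEN". ',
        padding="max_length",
        max_length=128,
        truncation=True,
        return_tensors="pt",
    )
    with torch.no_grad():
        # [0] is the encoder's last_hidden_state: (1, 128, 1472) for byt5-small.
        embeds = text_encoder(
            input_ids=tokens.input_ids,
            attention_mask=tokens.attention_mask.float(),
        )[0]
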
@@ -304,13 +301,16 @@ def encode_prompt(
             num_images_per_prompt (`int`):
                 number of images that should be generated per prompt
             prompt_embeds (`torch.Tensor`, *optional*):
-                Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt` input argument.
+                Pre-generated text embeddings. If not provided, text embeddings will be generated from `prompt` input
+                argument.
             prompt_embeds_mask (`torch.Tensor`, *optional*):
                 Pre-generated text mask. If not provided, text mask will be generated from `prompt` input argument.
             prompt_embeds_2 (`torch.Tensor`, *optional*):
-                Pre-generated glyph text embeddings from ByT5. If not provided, will be generated from `prompt` input argument using self.tokenizer_2 and self.text_encoder_2.
+                Pre-generated glyph text embeddings from ByT5. If not provided, will be generated from `prompt` input
+                argument using self.tokenizer_2 and self.text_encoder_2.
             prompt_embeds_mask_2 (`torch.Tensor`, *optional*):
-                Pre-generated glyph text mask from ByT5. If not provided, will be generated from `prompt` input argument using self.tokenizer_2 and self.text_encoder_2.
+                Pre-generated glyph text mask from ByT5. If not provided, will be generated from `prompt` input
+                argument using self.tokenizer_2 and self.text_encoder_2.
         """
         device = device or self._execution_device
 
@@ -319,15 +319,15 @@ def encode_prompt(
 
         if prompt_embeds is None:
             prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(
-                tokenizer=self.tokenizer,
-                text_encoder=self.text_encoder,
-                prompt=prompt,
+                tokenizer=self.tokenizer,
+                text_encoder=self.text_encoder,
+                prompt=prompt,
                 device=device,
                 tokenizer_max_length=self.tokenizer_max_length,
                 template=self.prompt_template_encode,
                 drop_idx=self.prompt_template_encode_start_idx,
             )
-
+
         if prompt_embeds_2 is None:
             prompt_embeds_2_list = []
             prompt_embeds_mask_2_list = []
@@ -336,7 +336,9 @@ def encode_prompt(
             for glyph_text in glyph_texts:
                 if glyph_text is None:
                     glyph_text_embeds = torch.zeros((1, self.tokenizer_2_max_length, 1472), device=device)
-                    glyph_text_embeds_mask = torch.zeros((1, self.tokenizer_2_max_length), device=device, dtype=torch.int64)
+                    glyph_text_embeds_mask = torch.zeros(
+                        (1, self.tokenizer_2_max_length), device=device, dtype=torch.int64
+                    )
                 else:
                     glyph_text_embeds, glyph_text_embeds_mask = self._get_byt5_prompt_embeds(
                         tokenizer=self.tokenizer_2,
@@ -345,10 +347,10 @@ def encode_prompt(
                         device=device,
                         tokenizer_max_length=self.tokenizer_2_max_length,
                     )
-
+
                 prompt_embeds_2_list.append(glyph_text_embeds)
                 prompt_embeds_mask_2_list.append(glyph_text_embeds_mask)
-
+
             prompt_embeds_2 = torch.cat(prompt_embeds_2_list, dim=0)
             prompt_embeds_mask_2 = torch.cat(prompt_embeds_mask_2_list, dim=0)
 
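Taken together, `encode_prompt` now yields two embedding streams: Qwen2.5-VL embeddings for the full prompt, plus per-prompt ByT5 glyph embeddings that are zero-filled whenever a prompt contains no quoted text. A hedged sketch of calling it directly (hypothetical; `__call__` normally does this internally, and the keyword defaults may differ):

    prompt_embeds, prompt_embeds_mask, prompt_embeds_2, prompt_embeds_mask_2 = pipe.encode_prompt(
        prompt='A poster that says "HELLO"',
        device=torch.device("cuda"),
    )
    # prompt_embeds_2 rows are all zeros for prompts without quoted text.
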
@@ -425,7 +427,7 @@ def check_inputs(
             raise ValueError(
                 "Provide either `prompt` or `prompt_embeds_2`. Cannot leave both `prompt` and `prompt_embeds_2` undefined."
             )
-
+
         if prompt_embeds_2 is not None and prompt_embeds_mask_2 is None:
             raise ValueError(
                 "If `prompt_embeds_2` are provided, `prompt_embeds_mask_2` also have to be passed. Make sure to generate `prompt_embeds_mask_2` from the same text encoder that was used to generate `prompt_embeds_2`."
@@ -435,7 +437,6 @@ def check_inputs(
                 "If `negative_prompt_embeds_2` are provided, `negative_prompt_embeds_mask_2` also have to be passed. Make sure to generate `negative_prompt_embeds_mask_2` from the same text encoder that was used to generate `negative_prompt_embeds_2`."
             )
 
-
     def prepare_latents(
         self,
         batch_size,
@@ -660,7 +661,12 @@ def __call__(
         prompt_embeds_2 = prompt_embeds_2.to(self.transformer.dtype)
 
         if do_true_cfg:
-            negative_prompt_embeds, negative_prompt_embeds_mask, negative_prompt_embeds_2, negative_prompt_embeds_mask_2 = self.encode_prompt(
+            (
+                negative_prompt_embeds,
+                negative_prompt_embeds_mask,
+                negative_prompt_embeds_2,
+                negative_prompt_embeds_mask_2,
+            ) = self.encode_prompt(
                 prompt=negative_prompt,
                 prompt_embeds=negative_prompt_embeds,
                 prompt_embeds_mask=negative_prompt_embeds_mask,
@@ -697,7 +703,9 @@ def __call__(
         if self.transformer.config.guidance_embeds and guidance_scale is None:
             raise ValueError("guidance_scale is required for guidance-distilled model.")
         elif self.transformer.config.guidance_embeds:
-            guidance = torch.tensor([guidance_scale] * latents.shape[0], dtype=self.transformer.dtype, device=device) * 1000.0
+            guidance = (
+                torch.tensor([guidance_scale] * latents.shape[0], dtype=self.transformer.dtype, device=device) * 1000.0
+            )
         elif not self.transformer.config.guidance_embeds and guidance_scale is not None:
             logger.warning(
                 f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled."
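For intuition on the guidance-distilled branch: the scale is embedded into the transformer rather than applied as classifier-free guidance, after scaling by 1000.0. A worked example of the tensor built above, with `guidance_scale=6.0` and two latents (dtype/device omitted for brevity):

    guidance = torch.tensor([6.0] * 2) * 1000.0
    # tensor([6000., 6000.])
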
@@ -709,7 +717,6 @@ def __call__(
         if self.attention_kwargs is None:
             self._attention_kwargs = {}
 
-
         # 6. Denoising loop
         self.scheduler.set_begin_index(0)
         with self.progress_bar(total=num_inference_steps) as progress_bar: