
Commit 556c6e9

refactor clip
1 parent 845f303 commit 556c6e9

File tree

1 file changed: +63 −17 lines

src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py

Lines changed: 63 additions & 17 deletions
@@ -18,6 +18,7 @@
 
 import numpy as np
 import torch
+from transformers import CLIPTextModel, CLIPTokenizer
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import VaeImageProcessor
@@ -140,7 +141,8 @@ def __init__(
         text_encoder: TextEncoder,
         transformer: HunyuanVideoTransformer3DModel,
         scheduler: KarrasDiffusionSchedulers,
-        text_encoder_2: Optional[TextEncoder] = None,
+        text_encoder_2: Optional[CLIPTextModel] = None,
+        tokenizer_2: Optional[CLIPTokenizer] = None,
     ):
         super().__init__()
 
@@ -150,6 +152,7 @@ def __init__(
             transformer=transformer,
             scheduler=scheduler,
             text_encoder_2=text_encoder_2,
+            tokenizer_2=tokenizer_2,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
@@ -246,6 +249,45 @@ def encode_prompt(
             attention_mask,
         )
 
+    def _get_clip_prompt_embeds(
+        self,
+        prompt: Union[str, List[str]],
+        num_videos_per_prompt: int = 1,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+        max_sequence_length: int = 77,
+    ):
+        device = device or self._execution_device
+        dtype = dtype or self.text_encoder_2.dtype
+
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+        batch_size = len(prompt)
+
+        text_inputs = self.tokenizer_2(
+            prompt,
+            padding="max_length",
+            max_length=max_sequence_length,
+            truncation=True,
+            return_tensors="pt",
+        )
+
+        text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
+            logger.warning(
+                "The following part of your input was truncated because CLIP can only handle sequences up to"
+                f" {max_sequence_length} tokens: {removed_text}"
+            )
+
+        prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False).pooler_output
+
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt)
+        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, -1)
+
+        return prompt_embeds
+
     def check_inputs(
         self,
         prompt,
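For context on the hunk above, here is a minimal standalone sketch of the pooled-CLIP-embedding path it introduces, written against transformers directly. The checkpoint name and the num_videos_per_prompt value are illustrative assumptions, not anything pinned by the diff.

import torch
from transformers import CLIPTextModel, CLIPTokenizer

# Illustrative checkpoint; the commit itself does not fix a particular CLIP model.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

prompt = ["a panda dancing in the snow"]
inputs = tokenizer(prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt")

with torch.no_grad():
    # pooler_output is the pooled per-prompt embedding, shape [batch_size, hidden_size]
    pooled = text_encoder(inputs.input_ids, output_hidden_states=False).pooler_output

# Duplicate per prompt the same way _get_clip_prompt_embeds does, then flatten.
num_videos_per_prompt = 2
pooled = pooled.repeat(1, num_videos_per_prompt).view(len(prompt) * num_videos_per_prompt, -1)
print(pooled.shape)  # torch.Size([2, 768]) for clip-vit-large-patch14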
@@ -577,7 +619,6 @@ def __call__(
         prompt_attention_mask: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        clip_skip: Optional[int] = None,
         callback_on_step_end: Optional[
             Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
         ] = None,
@@ -674,7 +715,6 @@ def __call__(
         )
 
         self._guidance_scale = guidance_scale
-        self._clip_skip = clip_skip
         self._interrupt = False
 
         # 2. Define call parameters
@@ -698,27 +738,33 @@ def __call__(
             num_videos_per_prompt,
             prompt_embeds=prompt_embeds,
             attention_mask=prompt_attention_mask,
-            clip_skip=self.clip_skip,
             data_type=data_type,
         )
 
+        # if self.text_encoder_2 is not None:
+        #     (
+        #         prompt_embeds_2,
+        #         prompt_mask_2,
+        #     ) = self.encode_prompt(
+        #         prompt,
+        #         device,
+        #         num_videos_per_prompt,
+        #         prompt_embeds=prompt_embeds_2,
+        #         attention_mask=None,
+        #         clip_skip=self.clip_skip,
+        #         text_encoder=self.text_encoder_2,
+        #         data_type=data_type,
+        #     )
+        # else:
+        #     prompt_embeds_2 = None
+        #     prompt_mask_2 = None
+
         if self.text_encoder_2 is not None:
-            (
-                prompt_embeds_2,
-                prompt_mask_2,
-            ) = self.encode_prompt(
+            prompt_embeds_2 = self._get_clip_prompt_embeds(
                 prompt,
-                device,
                 num_videos_per_prompt,
-                prompt_embeds=prompt_embeds_2,
-                attention_mask=None,
-                clip_skip=self.clip_skip,
-                text_encoder=self.text_encoder_2,
-                data_type=data_type,
+                device=device,
             )
-        else:
-            prompt_embeds_2 = None
-            prompt_mask_2 = None
 
         # 4. Prepare timesteps
         timesteps, num_inference_steps = retrieve_timesteps(
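Taken together, the changes wire an optional CLIP text encoder and tokenizer into the pipeline. A rough construction sketch follows; it assumes the class in this file is exposed as HunyuanVideoPipeline and that the remaining components (vae, text_encoder, transformer, scheduler) are already instantiated. Everything not shown in the diff is a placeholder.

from transformers import CLIPTextModel, CLIPTokenizer

# Placeholder checkpoint; any CLIP text model/tokenizer pair would do here.
clip_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

# vae, text_encoder, transformer, and scheduler are assumed to exist already.
# text_encoder_2/tokenizer_2 remain optional, so omitting them skips the CLIP
# branch in __call__ entirely.
pipe = HunyuanVideoPipeline(
    vae=vae,
    text_encoder=text_encoder,
    transformer=transformer,
    scheduler=scheduler,
    text_encoder_2=clip_model,
    tokenizer_2=clip_tokenizer,
)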
