# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import html
from typing import List, Optional, Union

import regex as re
import torch
from transformers import AutoTokenizer, UMT5EncoderModel

from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
from ...utils import is_ftfy_available, logging
from ..modular_pipeline import PipelineBlock, PipelineState
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from .modular_pipeline import WanModularPipeline


if is_ftfy_available():
    import ftfy


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r"\s+", " ", text)
    text = text.strip()
    return text


def prompt_clean(text):
    text = whitespace_clean(basic_clean(text))
    return text
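
# Illustrative example (not from the source): prompt_clean normalizes encoding
# artifacts via ftfy, unescapes HTML entities, and collapses whitespace runs, e.g.
# prompt_clean("A   cat\t&amp;amp; dog ") -> "A cat & dog".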


class WanTextEncoderStep(PipelineBlock):
    model_name = "wan"

    @property
    def description(self) -> str:
        return "Text Encoder step that generates text_embeddings to guide the video generation"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("text_encoder", UMT5EncoderModel),
            ComponentSpec("tokenizer", AutoTokenizer),
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 5.0}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def expected_configs(self) -> List[ConfigSpec]:
        return []

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("prompt"),
            InputParam("negative_prompt"),
            InputParam("attention_kwargs"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "prompt_embeds",
                type_hint=torch.Tensor,
                kwargs_type="guider_input_fields",
                description="text embeddings used to guide the video generation",
            ),
            OutputParam(
                "negative_prompt_embeds",
                type_hint=torch.Tensor,
                kwargs_type="guider_input_fields",
                description="negative text embeddings used to guide the video generation",
            ),
        ]

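    # Note (an assumption about the surrounding modular-pipeline plumbing, not
    # stated in this file): fields tagged with kwargs_type="guider_input_fields"
    # are the ones the configured guider reads when assembling its conditional
    # and unconditional batches.
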
    @staticmethod
    def check_inputs(block_state):
        if block_state.prompt is not None and not isinstance(block_state.prompt, (str, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}")

    @staticmethod
    def _get_t5_prompt_embeds(
        components,
        prompt: Union[str, List[str]],
        max_sequence_length: int,
        device: torch.device,
    ):
        dtype = components.text_encoder.dtype
        prompt = [prompt] if isinstance(prompt, str) else prompt
        prompt = [prompt_clean(u) for u in prompt]

        text_inputs = components.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_attention_mask=True,
            return_tensors="pt",
        )
        text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
        seq_lens = mask.gt(0).sum(dim=1).long()
        prompt_embeds = components.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
        # Trim each sequence to its true (unpadded) length, then zero-pad back to
        # `max_sequence_length` so padding positions carry zeros instead of the
        # encoder's pad-token activations.
        prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
        prompt_embeds = torch.stack(
            [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
        )

        return prompt_embeds
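
    # Shape note (a reading of the code above): the returned embeddings are always
    # (batch_size, max_sequence_length, hidden_dim); positions past each prompt's
    # true token length are zeroed out.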

    @staticmethod
    def encode_prompt(
        components,
        prompt: str,
        device: Optional[torch.device] = None,
        num_videos_per_prompt: int = 1,
        prepare_unconditional_embeds: bool = True,
        negative_prompt: Optional[str] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        max_sequence_length: int = 512,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            device (`torch.device`, *optional*):
                torch device
            num_videos_per_prompt (`int`):
                number of videos that should be generated per prompt
            prepare_unconditional_embeds (`bool`):
                whether to prepare unconditional embeddings or not
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the video generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.Tensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            max_sequence_length (`int`, defaults to `512`):
                The maximum number of text tokens to be used for the generation process.
        """
        device = device or components._execution_device
        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0]

        if prompt_embeds is None:
            prompt_embeds = WanTextEncoderStep._get_t5_prompt_embeds(components, prompt, max_sequence_length, device)

        if prepare_unconditional_embeds and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            negative_prompt_embeds = WanTextEncoderStep._get_t5_prompt_embeds(
                components, negative_prompt, max_sequence_length, device
            )

        # Duplicate embeddings along the sequence dim, then reshape so each of the
        # `num_videos_per_prompt` copies becomes its own batch entry.
        bs_embed, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)

        if prepare_unconditional_embeds:
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

        return prompt_embeds, negative_prompt_embeds
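
    # Minimal usage sketch (hedged; `components` is any container exposing
    # `text_encoder`, `tokenizer`, and `_execution_device`, e.g. a loaded
    # WanModularPipeline):
    #
    #   prompt_embeds, negative_prompt_embeds = WanTextEncoderStep.encode_prompt(
    #       components,
    #       prompt="A cat surfing a wave",
    #       negative_prompt="blurry, low quality",
    #       num_videos_per_prompt=1,
    #   )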

    @torch.no_grad()
    def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
        # Get inputs and intermediates
        block_state = self.get_block_state(state)
        self.check_inputs(block_state)

        # Prepare unconditional embeddings only when the guider actually runs more
        # than one condition (e.g. classifier-free guidance with guidance_scale > 1).
        block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
        block_state.device = components._execution_device

        # Encode input prompt
        (
            block_state.prompt_embeds,
            block_state.negative_prompt_embeds,
        ) = self.encode_prompt(
            components,
            block_state.prompt,
            block_state.device,
            1,
            block_state.prepare_unconditional_embeds,
            block_state.negative_prompt,
            prompt_embeds=None,
            negative_prompt_embeds=None,
        )

        # Add outputs
        self.set_block_state(state, block_state)
        return components, state
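
# End-to-end sketch (hedged; assumes the standard modular-pipeline flow in which
# a block is called with a component container and a PipelineState that already
# holds the user inputs declared in `inputs`):
#
#   block = WanTextEncoderStep()
#   components, state = block(pipeline_components, pipeline_state)
#   # `pipeline_state` now carries prompt_embeds / negative_prompt_embeds.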