Skip to content

Commit 3c278c0

Browse files
committed
get back the deleted files.
1 parent 22e8cb4 commit 3c278c0

File tree

3 files changed

+476
-0
lines changed

3 files changed

+476
-0
lines changed
Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
# Copyright 2025 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import html
16+
from typing import List, Optional, Union
17+
18+
import regex as re
19+
import torch
20+
from transformers import AutoTokenizer, UMT5EncoderModel
21+
22+
from ...configuration_utils import FrozenDict
23+
from ...guiders import ClassifierFreeGuidance
24+
from ...utils import is_ftfy_available, logging
25+
from ..modular_pipeline import PipelineBlock, PipelineState
26+
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
27+
from .modular_pipeline import WanModularPipeline
28+
29+
30+
if is_ftfy_available():
31+
import ftfy
32+
33+
34+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
35+
36+
37+
def basic_clean(text):
38+
text = ftfy.fix_text(text)
39+
text = html.unescape(html.unescape(text))
40+
return text.strip()
41+
42+
43+
def whitespace_clean(text):
44+
text = re.sub(r"\s+", " ", text)
45+
text = text.strip()
46+
return text
47+
48+
49+
def prompt_clean(text):
50+
text = whitespace_clean(basic_clean(text))
51+
return text
52+
53+
54+
class WanTextEncoderStep(PipelineBlock):
55+
model_name = "wan"
56+
57+
@property
58+
def description(self) -> str:
59+
return "Text Encoder step that generate text_embeddings to guide the video generation"
60+
61+
@property
62+
def expected_components(self) -> List[ComponentSpec]:
63+
return [
64+
ComponentSpec("text_encoder", UMT5EncoderModel),
65+
ComponentSpec("tokenizer", AutoTokenizer),
66+
ComponentSpec(
67+
"guider",
68+
ClassifierFreeGuidance,
69+
config=FrozenDict({"guidance_scale": 5.0}),
70+
default_creation_method="from_config",
71+
),
72+
]
73+
74+
@property
75+
def expected_configs(self) -> List[ConfigSpec]:
76+
return []
77+
78+
@property
79+
def inputs(self) -> List[InputParam]:
80+
return [
81+
InputParam("prompt"),
82+
InputParam("negative_prompt"),
83+
InputParam("attention_kwargs"),
84+
]
85+
86+
@property
87+
def intermediate_outputs(self) -> List[OutputParam]:
88+
return [
89+
OutputParam(
90+
"prompt_embeds",
91+
type_hint=torch.Tensor,
92+
kwargs_type="guider_input_fields",
93+
description="text embeddings used to guide the image generation",
94+
),
95+
OutputParam(
96+
"negative_prompt_embeds",
97+
type_hint=torch.Tensor,
98+
kwargs_type="guider_input_fields",
99+
description="negative text embeddings used to guide the image generation",
100+
),
101+
]
102+
103+
@staticmethod
104+
def check_inputs(block_state):
105+
if block_state.prompt is not None and (
106+
not isinstance(block_state.prompt, str) and not isinstance(block_state.prompt, list)
107+
):
108+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}")
109+
110+
@staticmethod
111+
def _get_t5_prompt_embeds(
112+
components,
113+
prompt: Union[str, List[str]],
114+
max_sequence_length: int,
115+
device: torch.device,
116+
):
117+
dtype = components.text_encoder.dtype
118+
prompt = [prompt] if isinstance(prompt, str) else prompt
119+
prompt = [prompt_clean(u) for u in prompt]
120+
121+
text_inputs = components.tokenizer(
122+
prompt,
123+
padding="max_length",
124+
max_length=max_sequence_length,
125+
truncation=True,
126+
add_special_tokens=True,
127+
return_attention_mask=True,
128+
return_tensors="pt",
129+
)
130+
text_input_ids, mask = text_inputs.input_ids, text_inputs.attention_mask
131+
seq_lens = mask.gt(0).sum(dim=1).long()
132+
prompt_embeds = components.text_encoder(text_input_ids.to(device), mask.to(device)).last_hidden_state
133+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
134+
prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)]
135+
prompt_embeds = torch.stack(
136+
[torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0
137+
)
138+
139+
return prompt_embeds
140+
141+
@staticmethod
142+
def encode_prompt(
143+
components,
144+
prompt: str,
145+
device: Optional[torch.device] = None,
146+
num_videos_per_prompt: int = 1,
147+
prepare_unconditional_embeds: bool = True,
148+
negative_prompt: Optional[str] = None,
149+
prompt_embeds: Optional[torch.Tensor] = None,
150+
negative_prompt_embeds: Optional[torch.Tensor] = None,
151+
max_sequence_length: int = 512,
152+
):
153+
r"""
154+
Encodes the prompt into text encoder hidden states.
155+
156+
Args:
157+
prompt (`str` or `List[str]`, *optional*):
158+
prompt to be encoded
159+
device: (`torch.device`):
160+
torch device
161+
num_videos_per_prompt (`int`):
162+
number of videos that should be generated per prompt
163+
prepare_unconditional_embeds (`bool`):
164+
whether to use prepare unconditional embeddings or not
165+
negative_prompt (`str` or `List[str]`, *optional*):
166+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
167+
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
168+
less than `1`).
169+
prompt_embeds (`torch.Tensor`, *optional*):
170+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
171+
provided, text embeddings will be generated from `prompt` input argument.
172+
negative_prompt_embeds (`torch.Tensor`, *optional*):
173+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
174+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
175+
argument.
176+
max_sequence_length (`int`, defaults to `512`):
177+
The maximum number of text tokens to be used for the generation process.
178+
"""
179+
device = device or components._execution_device
180+
prompt = [prompt] if isinstance(prompt, str) else prompt
181+
batch_size = len(prompt) if prompt is not None else prompt_embeds.shape[0]
182+
183+
if prompt_embeds is None:
184+
prompt_embeds = WanTextEncoderStep._get_t5_prompt_embeds(components, prompt, max_sequence_length, device)
185+
186+
if prepare_unconditional_embeds and negative_prompt_embeds is None:
187+
negative_prompt = negative_prompt or ""
188+
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
189+
190+
if prompt is not None and type(prompt) is not type(negative_prompt):
191+
raise TypeError(
192+
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
193+
f" {type(prompt)}."
194+
)
195+
elif batch_size != len(negative_prompt):
196+
raise ValueError(
197+
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
198+
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
199+
" the batch size of `prompt`."
200+
)
201+
202+
negative_prompt_embeds = WanTextEncoderStep._get_t5_prompt_embeds(
203+
components, negative_prompt, max_sequence_length, device
204+
)
205+
206+
bs_embed, seq_len, _ = prompt_embeds.shape
207+
prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
208+
prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1)
209+
210+
if prepare_unconditional_embeds:
211+
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
212+
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)
213+
214+
return prompt_embeds, negative_prompt_embeds
215+
216+
@torch.no_grad()
217+
def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
218+
# Get inputs and intermediates
219+
block_state = self.get_block_state(state)
220+
self.check_inputs(block_state)
221+
222+
block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
223+
block_state.device = components._execution_device
224+
225+
# Encode input prompt
226+
(
227+
block_state.prompt_embeds,
228+
block_state.negative_prompt_embeds,
229+
) = self.encode_prompt(
230+
components,
231+
block_state.prompt,
232+
block_state.device,
233+
1,
234+
block_state.prepare_unconditional_embeds,
235+
block_state.negative_prompt,
236+
prompt_embeds=None,
237+
negative_prompt_embeds=None,
238+
)
239+
240+
# Add outputs
241+
self.set_block_state(state, block_state)
242+
return components, state
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
# Copyright 2025 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from ...utils import logging
16+
from ..modular_pipeline import AutoPipelineBlocks, SequentialPipelineBlocks
17+
from ..modular_pipeline_utils import InsertableDict
18+
from .before_denoise import (
19+
WanInputStep,
20+
WanPrepareLatentsStep,
21+
WanSetTimestepsStep,
22+
)
23+
from .decoders import WanDecodeStep
24+
from .denoise import WanDenoiseStep
25+
from .encoders import WanTextEncoderStep
26+
27+
28+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
29+
30+
31+
# before_denoise: text2vid
32+
class WanBeforeDenoiseStep(SequentialPipelineBlocks):
33+
block_classes = [
34+
WanInputStep,
35+
WanSetTimestepsStep,
36+
WanPrepareLatentsStep,
37+
]
38+
block_names = ["input", "set_timesteps", "prepare_latents"]
39+
40+
@property
41+
def description(self):
42+
return (
43+
"Before denoise step that prepare the inputs for the denoise step.\n"
44+
+ "This is a sequential pipeline blocks:\n"
45+
+ " - `WanInputStep` is used to adjust the batch size of the model inputs\n"
46+
+ " - `WanSetTimestepsStep` is used to set the timesteps\n"
47+
+ " - `WanPrepareLatentsStep` is used to prepare the latents\n"
48+
)
49+
50+
51+
# before_denoise: all task (text2vid,)
52+
class WanAutoBeforeDenoiseStep(AutoPipelineBlocks):
53+
block_classes = [
54+
WanBeforeDenoiseStep,
55+
]
56+
block_names = ["text2vid"]
57+
block_trigger_inputs = [None]
58+
59+
@property
60+
def description(self):
61+
return (
62+
"Before denoise step that prepare the inputs for the denoise step.\n"
63+
+ "This is an auto pipeline block that works for text2vid.\n"
64+
+ " - `WanBeforeDenoiseStep` (text2vid) is used.\n"
65+
)
66+
67+
68+
# denoise: text2vid
69+
class WanAutoDenoiseStep(AutoPipelineBlocks):
70+
block_classes = [
71+
WanDenoiseStep,
72+
]
73+
block_names = ["denoise"]
74+
block_trigger_inputs = [None]
75+
76+
@property
77+
def description(self) -> str:
78+
return (
79+
"Denoise step that iteratively denoise the latents. "
80+
"This is a auto pipeline block that works for text2vid tasks.."
81+
" - `WanDenoiseStep` (denoise) for text2vid tasks."
82+
)
83+
84+
85+
# decode: all task (text2img, img2img, inpainting)
86+
class WanAutoDecodeStep(AutoPipelineBlocks):
87+
block_classes = [WanDecodeStep]
88+
block_names = ["non-inpaint"]
89+
block_trigger_inputs = [None]
90+
91+
@property
92+
def description(self):
93+
return "Decode step that decode the denoised latents into videos outputs.\n - `WanDecodeStep`"
94+
95+
96+
# text2vid
97+
class WanAutoBlocks(SequentialPipelineBlocks):
98+
block_classes = [
99+
WanTextEncoderStep,
100+
WanAutoBeforeDenoiseStep,
101+
WanAutoDenoiseStep,
102+
WanAutoDecodeStep,
103+
]
104+
block_names = [
105+
"text_encoder",
106+
"before_denoise",
107+
"denoise",
108+
"decoder",
109+
]
110+
111+
@property
112+
def description(self):
113+
return (
114+
"Auto Modular pipeline for text-to-video using Wan.\n"
115+
+ "- for text-to-video generation, all you need to provide is `prompt`"
116+
)
117+
118+
119+
TEXT2VIDEO_BLOCKS = InsertableDict(
120+
[
121+
("text_encoder", WanTextEncoderStep),
122+
("input", WanInputStep),
123+
("set_timesteps", WanSetTimestepsStep),
124+
("prepare_latents", WanPrepareLatentsStep),
125+
("denoise", WanDenoiseStep),
126+
("decode", WanDecodeStep),
127+
]
128+
)
129+
130+
131+
AUTO_BLOCKS = InsertableDict(
132+
[
133+
("text_encoder", WanTextEncoderStep),
134+
("before_denoise", WanAutoBeforeDenoiseStep),
135+
("denoise", WanAutoDenoiseStep),
136+
("decode", WanAutoDecodeStep),
137+
]
138+
)
139+
140+
141+
ALL_BLOCKS = {
142+
"text2video": TEXT2VIDEO_BLOCKS,
143+
"auto": AUTO_BLOCKS,
144+
}

0 commit comments

Comments
 (0)