Commit 9381dd6

up
1 parent 45465d4 commit 9381dd6

File tree

7 files changed: +512 -477 lines changed

src/diffusers/modular_pipelines/flux/before_denoise.py

Lines changed: 7 additions & 1 deletion
@@ -180,9 +180,15 @@ def intermediate_outputs(self) -> List[str]:
             OutputParam(
                 "prompt_embeds",
                 type_hint=torch.Tensor,
-                kwargs_type="guider_input_fields",  # already in intermediates state but declare here again for guider_input_fields
+                # kwargs_type="guider_input_fields",  # already in intermediates state but declare here again for guider_input_fields
                 description="text embeddings used to guide the image generation",
             ),
+            OutputParam(
+                "pooled_prompt_embeds",
+                type_hint=torch.Tensor,
+                # kwargs_type="guider_input_fields",  # already in intermediates state but declare here again for guider_input_fields
+                description="pooled text embeddings used to guide the image generation",
+            ),
             # TODO: support negative embeddings?
         ]

Lines changed: 308 additions & 0 deletions
@@ -0,0 +1,308 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import html
from typing import List, Optional, Union

import regex as re
import torch
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
from ...utils import USE_PEFT_BACKEND, is_ftfy_available, logging, scale_lora_layers, unscale_lora_layers
from ..modular_pipeline import PipelineBlock, PipelineState
from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
from .modular_pipeline import FluxModularPipeline


if is_ftfy_available():
    import ftfy


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r"\s+", " ", text)
    text = text.strip()
    return text


def prompt_clean(text):
    text = whitespace_clean(basic_clean(text))
    return text


class FluxTextEncoderStep(PipelineBlock):
    model_name = "flux"

    @property
    def description(self) -> str:
return "Text Encoder step that generate text_embeddings to guide the video generation"
59+

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("text_encoder", CLIPTextModel),
            ComponentSpec("tokenizer", CLIPTokenizer),
            ComponentSpec("text_encoder_2", T5EncoderModel),
            ComponentSpec("tokenizer_2", T5TokenizerFast),
        ]

    @property
    def expected_configs(self) -> List[ConfigSpec]:
        return []

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam("prompt"),
            InputParam("prompt_2"),
            InputParam("joint_attention_kwargs"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "prompt_embeds",
                type_hint=torch.Tensor,
                # kwargs_type="guider_input_fields",
                description="text embeddings used to guide the image generation",
            ),
            OutputParam(
                "pooled_prompt_embeds",
                type_hint=torch.Tensor,
                # kwargs_type="guider_input_fields",
                description="pooled text embeddings used to guide the image generation",
            ),
            OutputParam(
                "text_ids",
                type_hint=torch.Tensor,
                # kwargs_type="guider_input_fields",
                description="ids from the text sequence for RoPE",
            ),
        ]

    @staticmethod
    def check_inputs(block_state):
        for prompt in [block_state.prompt, block_state.prompt_2]:
            if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
                raise ValueError(f"`prompt` or `prompt_2` has to be of type `str` or `list` but is {type(prompt)}")

    @staticmethod
    def _get_t5_prompt_embeds(
        components,
        prompt: Union[str, List[str]],
        num_images_per_prompt: int,
        max_sequence_length: int,
        device: torch.device,
    ):
        dtype = components.text_encoder_2.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        if isinstance(components, TextualInversionLoaderMixin):
            prompt = components.maybe_convert_prompt(prompt, components.tokenizer_2)

        text_inputs = components.tokenizer_2(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_length=False,
            return_overflowing_tokens=False,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids

        untruncated_ids = components.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = components.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f" {max_sequence_length} tokens: {removed_text}"
            )

        prompt_embeds = components.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
        _, seq_len, _ = prompt_embeds.shape

        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

        return prompt_embeds

    @staticmethod
    def _get_clip_prompt_embeds(
        components,
        prompt: Union[str, List[str]],
        num_images_per_prompt: int,
        device: torch.device,
    ):
        prompt = [prompt] if isinstance(prompt, str) else prompt
        batch_size = len(prompt)

        if isinstance(components, TextualInversionLoaderMixin):
            prompt = components.maybe_convert_prompt(prompt, components.tokenizer)

        text_inputs = components.tokenizer(
            prompt,
            padding="max_length",
            max_length=components.tokenizer.model_max_length,
            truncation=True,
            return_overflowing_tokens=False,
            return_length=False,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids
        tokenizer_max_length = components.tokenizer.model_max_length
        untruncated_ids = components.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = components.tokenizer.batch_decode(untruncated_ids[:, tokenizer_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {tokenizer_max_length} tokens: {removed_text}"
            )
        prompt_embeds = components.text_encoder(text_input_ids.to(device), output_hidden_states=False)

        # Use pooled output of CLIPTextModel
        prompt_embeds = prompt_embeds.pooler_output
        prompt_embeds = prompt_embeds.to(dtype=components.text_encoder.dtype, device=device)

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)

        return prompt_embeds

    @staticmethod
    def encode_prompt(
        components,
        prompt: Union[str, List[str]],
        prompt_2: Union[str, List[str]],
        device: Optional[torch.device] = None,
        num_images_per_prompt: int = 1,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        max_sequence_length: int = 512,
        lora_scale: Optional[float] = None,
    ):
        r"""
        Encodes the prompt into text encoder hidden states.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                prompt to be encoded
            prompt_2 (`str` or `List[str]`, *optional*):
                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt`
                is used in all text-encoders
            device: (`torch.device`):
                torch device
            num_images_per_prompt (`int`):
                number of images that should be generated per prompt
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
                not provided, text embeddings will be generated from `prompt` input argument.
            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, pooled text embeddings will be generated from `prompt` input argument.
            lora_scale (`float`, *optional*):
                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
        """
        device = device or components._execution_device

        # set lora scale so that monkey patched LoRA
        # function of text encoder can correctly access it
        if lora_scale is not None and isinstance(components, FluxLoraLoaderMixin):
            components._lora_scale = lora_scale

            # dynamically adjust the LoRA scale
            if components.text_encoder is not None and USE_PEFT_BACKEND:
                scale_lora_layers(components.text_encoder, lora_scale)
            if components.text_encoder_2 is not None and USE_PEFT_BACKEND:
                scale_lora_layers(components.text_encoder_2, lora_scale)

        prompt = [prompt] if isinstance(prompt, str) else prompt

        if prompt_embeds is None:
            prompt_2 = prompt_2 or prompt
            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2

            # We only use the pooled prompt output from the CLIPTextModel
            pooled_prompt_embeds = FluxTextEncoderStep._get_clip_prompt_embeds(
                components,
                prompt=prompt,
                device=device,
                num_images_per_prompt=num_images_per_prompt,
            )
            prompt_embeds = FluxTextEncoderStep._get_t5_prompt_embeds(
                components,
                prompt=prompt_2,
                num_images_per_prompt=num_images_per_prompt,
                max_sequence_length=max_sequence_length,
                device=device,
            )

        if components.text_encoder is not None:
            if isinstance(components, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(components.text_encoder, lora_scale)

        if components.text_encoder_2 is not None:
            if isinstance(components, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
                # Retrieve the original scale by scaling back the LoRA layers
                unscale_lora_layers(components.text_encoder_2, lora_scale)

        dtype = components.text_encoder.dtype if components.text_encoder is not None else torch.bfloat16
        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

        return prompt_embeds, pooled_prompt_embeds, text_ids

    @torch.no_grad()
    def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
        # Get inputs and intermediates
        block_state = self.get_block_state(state)
        self.check_inputs(block_state)

        block_state.device = components._execution_device

        # Encode input prompt
        block_state.text_encoder_lora_scale = (
            block_state.joint_attention_kwargs.get("scale", None)
            if block_state.joint_attention_kwargs is not None
            else None
        )
        (block_state.prompt_embeds, block_state.pooled_prompt_embeds, block_state.text_ids) = self.encode_prompt(
            components,
            prompt=block_state.prompt,
            prompt_2=None,
            prompt_embeds=None,
            pooled_prompt_embeds=None,
            device=block_state.device,
            num_images_per_prompt=1,  # hardcoded for now.
            lora_scale=block_state.text_encoder_lora_scale,
        )

        # Add outputs
        self.set_block_state(state, block_state)
        return components, state
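
For orientation, a minimal shape sketch (not part of the commit) of the three intermediates this step writes to the block state. The embedding widths (4096 for the T5 sequence embeddings, 768 for the pooled CLIP embedding) are illustrative assumptions; the text_ids construction mirrors encode_prompt above.

import torch

# Dummy stand-ins for the outputs of FluxTextEncoderStep.encode_prompt.
# The embedding widths below are assumptions for illustration only.
batch_size, seq_len = 1, 512
prompt_embeds = torch.randn(batch_size, seq_len, 4096)   # T5 sequence embeddings
pooled_prompt_embeds = torch.randn(batch_size, 768)      # pooled CLIP embedding

# text_ids mirrors the construction in encode_prompt: one 3-component id per
# text token, all zeros, declared as the RoPE ids for the text sequence.
text_ids = torch.zeros(prompt_embeds.shape[1], 3)

print(prompt_embeds.shape)         # torch.Size([1, 512, 4096])
print(pooled_prompt_embeds.shape)  # torch.Size([1, 768])
print(text_ids.shape)              # torch.Size([512, 3])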
