Commit 435a8c0 ("up")

1 parent d7ef6a0 commit 435a8c0

File tree

1 file changed

+68 -85 lines changed

src/diffusers/pipelines/qwenimage/pipeline_qwen_utils.py

Lines changed: 68 additions & 85 deletions
@@ -102,6 +102,74 @@ def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor
 
         return split_result
 
+    def _get_qwen_prompt_embeds(
+        self,
+        prompt: Union[str, List[str]],
+        image: Optional[torch.Tensor] = None,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        device = device or self._execution_device
+        dtype = dtype or self.text_encoder.dtype
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+
+        template = self.prompt_template_encode
+        drop_idx = self.prompt_template_encode_start_idx
+        txt = [template.format(e) for e in prompt]
+        use_multimodal = image is not None and hasattr(self, "processor")
+
+        if use_multimodal:
+            # --- Multimodal (text+image) ---
+            model_inputs = self.processor(
+                text=txt,
+                images=image,
+                padding=True,
+                return_tensors="pt",
+            ).to(device)
+
+            outputs = self.text_encoder(
+                input_ids=model_inputs.input_ids,
+                attention_mask=model_inputs.attention_mask,
+                pixel_values=model_inputs.pixel_values,
+                image_grid_thw=model_inputs.image_grid_thw,
+                output_hidden_states=True,
+            )
+            hidden_states = outputs.hidden_states[-1]
+            attn_mask = model_inputs.attention_mask
+        else:
+            # --- Text-only ---
+            txt_tokens = self.tokenizer(
+                txt,
+                max_length=self.tokenizer_max_length + drop_idx,
+                padding=True,
+                truncation=True,
+                return_tensors="pt",
+            ).to(device)
+
+            outputs = self.text_encoder(
+                input_ids=txt_tokens.input_ids,
+                attention_mask=txt_tokens.attention_mask,
+                output_hidden_states=True,
+            )
+            hidden_states = outputs.hidden_states[-1]
+            attn_mask = txt_tokens.attention_mask
+
+        split_hidden_states = self._extract_masked_hidden(hidden_states, attn_mask)
+        split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
+
+        attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
+        max_seq_len = max(e.size(0) for e in split_hidden_states)
+
+        prompt_embeds = torch.stack(
+            [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
+        )
+        encoder_attention_mask = torch.stack(
+            [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
+        )
+
+        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+        return prompt_embeds, encoder_attention_mask
+
     @staticmethod
     def _pack_latents(latents, batch_size, num_channels_latents, height, width):
         latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
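Note on the hunk above: the consolidation hinges on the new use_multimodal guard, which takes the processor-based multimodal branch only when an image is passed and the pipeline actually exposes a processor. A minimal, self-contained sketch of that dispatch; the class and attributes below are illustrative stand-ins, not diffusers API:

# Toy stand-in for the dispatch in the unified _get_qwen_prompt_embeds.
class FakePipeline:
    def __init__(self, has_processor: bool):
        if has_processor:
            self.processor = object()  # placeholder for a real multimodal processor

    def branch(self, image=None) -> str:
        # Same guard as in the diff: both an image and a processor are required.
        use_multimodal = image is not None and hasattr(self, "processor")
        return "multimodal" if use_multimodal else "text-only"

print(FakePipeline(has_processor=True).branch(image="img"))   # multimodal
print(FakePipeline(has_processor=True).branch())              # text-only
print(FakePipeline(has_processor=False).branch(image="img"))  # text-only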
@@ -171,44 +239,6 @@ def encode_prompt(
 
         return prompt_embeds, prompt_embeds_mask
 
-    def _get_qwen_prompt_embeds(
-        self,
-        prompt: Union[str, List[str]] = None,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-    ):
-        device = device or self._execution_device
-        dtype = dtype or self.text_encoder.dtype
-
-        prompt = [prompt] if isinstance(prompt, str) else prompt
-
-        template = self.prompt_template_encode
-        drop_idx = self.prompt_template_encode_start_idx
-        txt = [template.format(e) for e in prompt]
-        txt_tokens = self.tokenizer(
-            txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
-        ).to(device)
-        encoder_hidden_states = self.text_encoder(
-            input_ids=txt_tokens.input_ids,
-            attention_mask=txt_tokens.attention_mask,
-            output_hidden_states=True,
-        )
-        hidden_states = encoder_hidden_states.hidden_states[-1]
-        split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
-        split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
-        attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
-        max_seq_len = max([e.size(0) for e in split_hidden_states])
-        prompt_embeds = torch.stack(
-            [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
-        )
-        encoder_attention_mask = torch.stack(
-            [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
-        )
-
-        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
-
-        return prompt_embeds, encoder_attention_mask
-
 
 class QwenImageEditPipelineMixin(QwenImageMixin):
     def encode_prompt(
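The padding/stacking tail, identical in the removed and added versions, is easier to follow outside the diff. A self-contained sketch with toy shapes (two prompts, hidden size 8, lengths chosen for illustration):

import torch

# Two prompts with different post-template lengths (5 and 3 tokens), hidden size 8.
split_hidden_states = [torch.randn(5, 8), torch.randn(3, 8)]
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long) for e in split_hidden_states]
max_seq_len = max(e.size(0) for e in split_hidden_states)

# Right-pad every sequence with zeros up to the batch maximum, then stack.
prompt_embeds = torch.stack(
    [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
)
# Pad the per-sequence masks the same way so zeros mark padding positions.
encoder_attention_mask = torch.stack(
    [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
)

print(prompt_embeds.shape)     # torch.Size([2, 5, 8])
print(encoder_attention_mask)  # tensor([[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]])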
@@ -252,53 +282,6 @@ def encode_prompt(
 
         return prompt_embeds, prompt_embeds_mask
 
-    def _get_qwen_prompt_embeds(
-        self,
-        prompt: Union[str, List[str]] = None,
-        image: Optional[torch.Tensor] = None,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-    ):
-        device = device or self._execution_device
-        dtype = dtype or self.text_encoder.dtype
-
-        prompt = [prompt] if isinstance(prompt, str) else prompt
-
-        template = self.prompt_template_encode
-        drop_idx = self.prompt_template_encode_start_idx
-        txt = [template.format(e) for e in prompt]
-
-        model_inputs = self.processor(
-            text=txt,
-            images=image,
-            padding=True,
-            return_tensors="pt",
-        ).to(device)
-
-        outputs = self.text_encoder(
-            input_ids=model_inputs.input_ids,
-            attention_mask=model_inputs.attention_mask,
-            pixel_values=model_inputs.pixel_values,
-            image_grid_thw=model_inputs.image_grid_thw,
-            output_hidden_states=True,
-        )
-
-        hidden_states = outputs.hidden_states[-1]
-        split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
-        split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
-        attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
-        max_seq_len = max([e.size(0) for e in split_hidden_states])
-        prompt_embeds = torch.stack(
-            [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
-        )
-        encoder_attention_mask = torch.stack(
-            [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
-        )
-
-        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
-
-        return prompt_embeds, encoder_attention_mask
-
 
 def calculate_dimensions(target_area, ratio):
     width = math.sqrt(target_area * ratio)
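The trailing context shows only the first line of calculate_dimensions, but the math it implies is determined: for a target pixel area A and aspect ratio r = width / height, solving width * height = A and width / height = r gives width = sqrt(A * r) and height = width / r. A sketch under those assumptions; the snapping to a multiple of 32 is a hypothetical choice for illustration, not shown in this commit:

import math

def calculate_dimensions_sketch(target_area: float, ratio: float, multiple: int = 32):
    # width * height == target_area and width / height == ratio
    # => width = sqrt(target_area * ratio), height = width / ratio
    width = math.sqrt(target_area * ratio)
    height = width / ratio
    # Rounding to a model-friendly multiple is an assumption, not taken from this diff.
    width = round(width / multiple) * multiple
    height = round(height / multiple) * multiple
    return width, height

print(calculate_dimensions_sketch(1024 * 1024, 16 / 9))  # (1376, 768)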
