Commit 78f292e

propagate changes for QwenImageEdit Plus.
1 parent d684d46

3 files changed (+65, -170 lines)


src/diffusers/pipelines/qwenimage/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -27,8 +27,8 @@
     _import_structure["pipeline_qwenimage_controlnet"] = ["QwenImageControlNetPipeline"]
     _import_structure["pipeline_qwenimage_controlnet_inpaint"] = ["QwenImageControlNetInpaintPipeline"]
     _import_structure["pipeline_qwenimage_edit"] = ["QwenImageEditPipeline"]
-    _import_structure["pipeline_qwenimage_edit_plus"] = ["QwenImageEditPlusPipeline"]
     _import_structure["pipeline_qwenimage_edit_inpaint"] = ["QwenImageEditInpaintPipeline"]
+    _import_structure["pipeline_qwenimage_edit_plus"] = ["QwenImageEditPlusPipeline"]
     _import_structure["pipeline_qwenimage_img2img"] = ["QwenImageImg2ImgPipeline"]
     _import_structure["pipeline_qwenimage_inpaint"] = ["QwenImageInpaintPipeline"]

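This entry only moves within `_import_structure`; the lazy-import export and the public import path are unchanged. A minimal sanity check, assuming an install that includes this commit:

```python
from diffusers import QwenImageEditPlusPipeline

print(QwenImageEditPlusPipeline.__module__)
# expected: diffusers.pipelines.qwenimage.pipeline_qwenimage_edit_plus
```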
src/diffusers/pipelines/qwenimage/pipeline_qwen_utils.py

Lines changed: 59 additions & 0 deletions
@@ -283,6 +283,65 @@ def encode_prompt(
         return prompt_embeds, prompt_embeds_mask
 
 
+class QwenImageEditPlusPipelineMixin(QwenImageEditPipelineMixin):
+    def _get_qwen_prompt_embeds(
+        self,
+        prompt: Union[str, List[str]] = None,
+        image: Optional[torch.Tensor] = None,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        device = device or self._execution_device
+        dtype = dtype or self.text_encoder.dtype
+
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+        img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"
+        if isinstance(image, list):
+            base_img_prompt = ""
+            for i, img in enumerate(image):
+                base_img_prompt += img_prompt_template.format(i + 1)
+        elif image is not None:
+            base_img_prompt = img_prompt_template.format(1)
+        else:
+            base_img_prompt = ""
+
+        template = self.prompt_template_encode
+
+        drop_idx = self.prompt_template_encode_start_idx
+        txt = [template.format(base_img_prompt + e) for e in prompt]
+
+        model_inputs = self.processor(
+            text=txt,
+            images=image,
+            padding=True,
+            return_tensors="pt",
+        ).to(device)
+
+        outputs = self.text_encoder(
+            input_ids=model_inputs.input_ids,
+            attention_mask=model_inputs.attention_mask,
+            pixel_values=model_inputs.pixel_values,
+            image_grid_thw=model_inputs.image_grid_thw,
+            output_hidden_states=True,
+        )
+
+        hidden_states = outputs.hidden_states[-1]
+        split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
+        split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
+        attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
+        max_seq_len = max([e.size(0) for e in split_hidden_states])
+        prompt_embeds = torch.stack(
+            [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
+        )
+        encoder_attention_mask = torch.stack(
+            [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
+        )
+
+        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+        return prompt_embeds, encoder_attention_mask
+
+
 def calculate_dimensions(target_area, ratio):
     width = math.sqrt(target_area * ratio)
     height = width / ratio
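The new mixin's `_get_qwen_prompt_embeds` first builds a per-image prefix ("Picture 1:", "Picture 2:", ...) and only then applies the pipeline's `prompt_template_encode`. A standalone sketch of that prefix logic, using the exact template string from the diff (the placeholder image list is illustrative):

```python
# Mirrors the prefix construction in QwenImageEditPlusPipelineMixin._get_qwen_prompt_embeds.
img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"


def build_image_prefix(image):
    # One numbered placeholder per conditioning image; empty when no image is passed.
    if isinstance(image, list):
        return "".join(img_prompt_template.format(i + 1) for i in range(len(image)))
    return img_prompt_template.format(1) if image is not None else ""


print(build_image_prefix(["person.png", "scene.png"]))
# Picture 1: <|vision_start|><|image_pad|><|vision_end|>Picture 2: <|vision_start|><|image_pad|><|vision_end|>
```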

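After encoding, the per-prompt hidden states have different lengths (the first `drop_idx` template tokens are dropped), so the mixin right-pads them with zeros before stacking. A minimal sketch of that padding step with made-up shapes:

```python
import torch

# Stand-ins for split_hidden_states after the template tokens have been dropped.
split_hidden_states = [torch.randn(5, 8), torch.randn(3, 8)]
max_seq_len = max(e.size(0) for e in split_hidden_states)

# Right-pad each sequence with zeros and stack, as the mixin does for prompt_embeds.
prompt_embeds = torch.stack(
    [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
)
# Matching attention mask: 1 for real tokens, 0 for padding.
encoder_attention_mask = torch.stack(
    [
        torch.cat([torch.ones(u.size(0), dtype=torch.long), torch.zeros(max_seq_len - u.size(0), dtype=torch.long)])
        for u in split_hidden_states
    ]
)
print(prompt_embeds.shape, encoder_attention_mask.shape)  # torch.Size([2, 5, 8]) torch.Size([2, 5])
```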
src/diffusers/pipelines/qwenimage/pipeline_qwenimage_edit_plus.py

Lines changed: 5 additions & 169 deletions
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import inspect
-import math
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import numpy as np
@@ -28,6 +27,7 @@
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import QwenImagePipelineOutput
+from .pipeline_qwen_utils import QwenImageEditPlusPipelineMixin, calculate_dimensions
 
 
 if is_torch_xla_available():
@@ -155,17 +155,7 @@ def retrieve_latents(
         raise AttributeError("Could not access latents of provided encoder_output")
 
 
-def calculate_dimensions(target_area, ratio):
-    width = math.sqrt(target_area * ratio)
-    height = width / ratio
-
-    width = round(width / 32) * 32
-    height = round(height / 32) * 32
-
-    return width, height
-
-
-class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
+class QwenImageEditPlusPipeline(DiffusionPipeline, QwenImageEditPlusPipelineMixin, QwenImageLoraLoaderMixin):
     r"""
     The Qwen-Image-Edit pipeline for image editing.
 
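With `QwenImageEditPlusPipelineMixin` added to the base classes, the prompt-encoding helpers deleted below are inherited from the shared mixin instead of being redefined on the pipeline. A quick check of where the method now resolves from, assuming an environment that includes this commit:

```python
from diffusers.pipelines.qwenimage.pipeline_qwen_utils import QwenImageEditPlusPipelineMixin
from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit_plus import QwenImageEditPlusPipeline

# The pipeline no longer defines _get_qwen_prompt_embeds itself, so attribute lookup
# should fall through to the mixin's implementation.
print(
    QwenImageEditPlusPipeline._get_qwen_prompt_embeds
    is QwenImageEditPlusPipelineMixin._get_qwen_prompt_embeds
)  # expected: True
```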
@@ -217,114 +207,6 @@ def __init__(
         self.prompt_template_encode_start_idx = 64
         self.default_sample_size = 128
 
-    # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._extract_masked_hidden
-    def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
-        bool_mask = mask.bool()
-        valid_lengths = bool_mask.sum(dim=1)
-        selected = hidden_states[bool_mask]
-        split_result = torch.split(selected, valid_lengths.tolist(), dim=0)
-
-        return split_result
-
-    def _get_qwen_prompt_embeds(
-        self,
-        prompt: Union[str, List[str]] = None,
-        image: Optional[torch.Tensor] = None,
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-    ):
-        device = device or self._execution_device
-        dtype = dtype or self.text_encoder.dtype
-
-        prompt = [prompt] if isinstance(prompt, str) else prompt
-        img_prompt_template = "Picture {}: <|vision_start|><|image_pad|><|vision_end|>"
-        if isinstance(image, list):
-            base_img_prompt = ""
-            for i, img in enumerate(image):
-                base_img_prompt += img_prompt_template.format(i + 1)
-        elif image is not None:
-            base_img_prompt = img_prompt_template.format(1)
-        else:
-            base_img_prompt = ""
-
-        template = self.prompt_template_encode
-
-        drop_idx = self.prompt_template_encode_start_idx
-        txt = [template.format(base_img_prompt + e) for e in prompt]
-
-        model_inputs = self.processor(
-            text=txt,
-            images=image,
-            padding=True,
-            return_tensors="pt",
-        ).to(device)
-
-        outputs = self.text_encoder(
-            input_ids=model_inputs.input_ids,
-            attention_mask=model_inputs.attention_mask,
-            pixel_values=model_inputs.pixel_values,
-            image_grid_thw=model_inputs.image_grid_thw,
-            output_hidden_states=True,
-        )
-
-        hidden_states = outputs.hidden_states[-1]
-        split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
-        split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
-        attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
-        max_seq_len = max([e.size(0) for e in split_hidden_states])
-        prompt_embeds = torch.stack(
-            [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
-        )
-        encoder_attention_mask = torch.stack(
-            [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
-        )
-
-        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
-
-        return prompt_embeds, encoder_attention_mask
-
-    # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.encode_prompt
-    def encode_prompt(
-        self,
-        prompt: Union[str, List[str]],
-        image: Optional[torch.Tensor] = None,
-        device: Optional[torch.device] = None,
-        num_images_per_prompt: int = 1,
-        prompt_embeds: Optional[torch.Tensor] = None,
-        prompt_embeds_mask: Optional[torch.Tensor] = None,
-        max_sequence_length: int = 1024,
-    ):
-        r"""
-
-        Args:
-            prompt (`str` or `List[str]`, *optional*):
-                prompt to be encoded
-            image (`torch.Tensor`, *optional*):
-                image to be encoded
-            device: (`torch.device`):
-                torch device
-            num_images_per_prompt (`int`):
-                number of images that should be generated per prompt
-            prompt_embeds (`torch.Tensor`, *optional*):
-                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
-                provided, text embeddings will be generated from `prompt` input argument.
-        """
-        device = device or self._execution_device
-
-        prompt = [prompt] if isinstance(prompt, str) else prompt
-        batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]
-
-        if prompt_embeds is None:
-            prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device)
-
-        _, seq_len, _ = prompt_embeds.shape
-        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-        prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
-
-        return prompt_embeds, prompt_embeds_mask
-
     # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline.check_inputs
     def check_inputs(
         self,
@@ -381,32 +263,6 @@ def check_inputs(
         if max_sequence_length is not None and max_sequence_length > 1024:
             raise ValueError(f"`max_sequence_length` cannot be greater than 1024 but is {max_sequence_length}")
 
-    @staticmethod
-    # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._pack_latents
-    def _pack_latents(latents, batch_size, num_channels_latents, height, width):
-        latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
-        latents = latents.permute(0, 2, 4, 1, 3, 5)
-        latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
-
-        return latents
-
-    @staticmethod
-    # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline._unpack_latents
-    def _unpack_latents(latents, height, width, vae_scale_factor):
-        batch_size, num_patches, channels = latents.shape
-
-        # VAE applies 8x compression on images but we must also account for packing which requires
-        # latent height and width to be divisible by 2.
-        height = 2 * (int(height) // (vae_scale_factor * 2))
-        width = 2 * (int(width) // (vae_scale_factor * 2))
-
-        latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
-        latents = latents.permute(0, 3, 1, 4, 2, 5)
-
-        latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
-
-        return latents
-
     # Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage_edit.QwenImageEditPipeline._encode_vae_image
     def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
         if isinstance(generator, list):
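For reference, the deleted `_pack_latents` turns a `(batch, channels, height, width)` latent into a sequence of 2x2 spatial patches. A standalone sketch of the same reshape, with illustrative shapes:

```python
import torch


def pack_latents(latents, batch_size, num_channels_latents, height, width):
    # (B, C, H, W) -> (B, (H/2)*(W/2), C*4): every 2x2 spatial patch becomes one token.
    latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    return latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)


x = torch.randn(1, 16, 64, 64)  # channel count and spatial size are made up for the example
print(pack_latents(x, 1, 16, 64, 64).shape)  # torch.Size([1, 1024, 64])
```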
@@ -492,26 +348,6 @@ def prepare_latents(
 
         return latents, image_latents
 
-    @property
-    def guidance_scale(self):
-        return self._guidance_scale
-
-    @property
-    def attention_kwargs(self):
-        return self._attention_kwargs
-
-    @property
-    def num_timesteps(self):
-        return self._num_timesteps
-
-    @property
-    def current_timestep(self):
-        return self._current_timestep
-
-    @property
-    def interrupt(self):
-        return self._interrupt
-
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -628,7 +464,7 @@ def __call__(
                 returning a tuple, the first element is a list with the generated images.
         """
         image_size = image[-1].size if isinstance(image, list) else image.size
-        calculated_width, calculated_height = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1])
+        calculated_width, calculated_height, _ = calculate_dimensions(1024 * 1024, image_size[0] / image_size[1])
         height = height or calculated_height
         width = width or calculated_width
 
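The call sites in this file now unpack a third value from the shared `calculate_dimensions` and discard it; that third value is not shown in this diff. The width/height math matches the helper removed earlier in the file, so a worked example of the two-value form (the 3:2 input is illustrative):

```python
import math


# Same math as the calculate_dimensions removed from this file; the shared helper
# in pipeline_qwen_utils additionally returns a third value that the pipeline ignores.
def calculate_dimensions(target_area, ratio):
    width = math.sqrt(target_area * ratio)
    height = width / ratio
    return round(width / 32) * 32, round(height / 32) * 32


print(calculate_dimensions(1024 * 1024, 1536 / 1024))  # (1248, 832): roughly 1 MP, multiples of 32
```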
@@ -674,10 +510,10 @@
         vae_images = []
         for img in image:
             image_width, image_height = img.size
-            condition_width, condition_height = calculate_dimensions(
+            condition_width, condition_height, _ = calculate_dimensions(
                 CONDITION_IMAGE_SIZE, image_width / image_height
             )
-            vae_width, vae_height = calculate_dimensions(VAE_IMAGE_SIZE, image_width / image_height)
+            vae_width, vae_height, _ = calculate_dimensions(VAE_IMAGE_SIZE, image_width / image_height)
             condition_image_sizes.append((condition_width, condition_height))
             vae_image_sizes.append((vae_width, vae_height))
             condition_images.append(self.image_processor.resize(img, condition_height, condition_width))

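Putting the pieces together, a minimal usage sketch for the refactored pipeline. The checkpoint ID, image files, and step count below are placeholders, not taken from this commit:

```python
import torch

from diffusers import QwenImageEditPlusPipeline
from diffusers.utils import load_image

# Placeholder checkpoint ID; substitute whatever QwenImage-Edit-Plus weights you use.
pipe = QwenImageEditPlusPipeline.from_pretrained("Qwen/Qwen-Image-Edit-2509", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Each image in the list becomes a "Picture N" block in the encoded prompt.
images = [load_image("person.png"), load_image("scene.png")]
result = pipe(
    image=images,
    prompt="Place the person from Picture 1 into the scene from Picture 2.",
    num_inference_steps=40,
).images[0]
result.save("edited.png")
```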