Skip to content

Commit d0763b2

Browse files
authored
supports sequence parallel and use custom image size for Qwen Image (#186)
* supports qwen image sequence parallel * use custom image size
1 parent ece6fec commit d0763b2

File tree

2 files changed

+25
-12
lines changed

2 files changed

+25
-12
lines changed

diffsynth_engine/models/qwen_image/qwen_image_dit.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,12 @@
99
from diffsynth_engine.models.basic.transformer_helper import AdaLayerNorm, ApproximateGELU, RMSNorm
1010
from diffsynth_engine.utils.gguf import gguf_inference
1111
from diffsynth_engine.utils.fp8_linear import fp8_inference
12-
from diffsynth_engine.utils.parallel import cfg_parallel, cfg_parallel_unshard
12+
from diffsynth_engine.utils.parallel import (
13+
cfg_parallel,
14+
cfg_parallel_unshard,
15+
sequence_parallel,
16+
sequence_parallel_unshard,
17+
)
1318

1419

1520
class QwenImageDiTStateDictConverter(StateDictConverter):
@@ -498,14 +503,18 @@ def forward(
498503
image.dtype,
499504
)
500505

501-
for block in self.transformer_blocks:
502-
text, image = block(
503-
image=image, text=text, temb=conditioning, rotary_emb=rotary_emb, attn_mask=attn_mask
504-
)
505-
image = self.norm_out(image, conditioning)
506-
image = self.proj_out(image)
506+
# warning: Eligen does not work with sequence parallel because long context attention does not support attention masks
507+
img_freqs, txt_freqs = rotary_emb
508+
with sequence_parallel((image, text, img_freqs, txt_freqs), seq_dims=(1, 1, 0, 0)):
509+
rotary_emb = (img_freqs, txt_freqs)
510+
for block in self.transformer_blocks:
511+
text, image = block(
512+
image=image, text=text, temb=conditioning, rotary_emb=rotary_emb, attn_mask=attn_mask
513+
)
514+
image = self.norm_out(image, conditioning)
515+
image = self.proj_out(image)
516+
(image,) = sequence_parallel_unshard((image,), seq_dims=(1,), seq_lens=(image_seq_len,))
507517
image = image[:, :image_seq_len]
508-
509518
image = self.unpatchify(image, h, w)
510519

511520
(image,) = cfg_parallel_unshard((image,), use_cfg=use_cfg)

diffsynth_engine/pipelines/qwen_image.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -561,8 +561,8 @@ def __call__(
561561
# single image for edit, list for edit plus(QwenImageEdit2509)
562562
input_image: List[Image.Image] | Image.Image | None = None,
563563
cfg_scale: float = 4.0, # true cfg
564-
height: int = 1328,
565-
width: int = 1328,
564+
height: Optional[int] = None,
565+
width: Optional[int] = None,
566566
num_inference_steps: int = 50,
567567
seed: int | None = None,
568568
controlnet_params: List[QwenImageControlNetParams] | QwenImageControlNetParams = [],
@@ -571,7 +571,9 @@ def __call__(
571571
entity_prompts: Optional[List[str]] = None,
572572
entity_masks: Optional[List[Image.Image]] = None,
573573
):
574+
assert (height is None) == (width is None), "height and width should be set together"
574575
is_edit_plus = isinstance(input_image, list)
576+
575577
if input_image is not None:
576578
if not isinstance(input_image, list):
577579
input_image = [input_image]
@@ -583,9 +585,11 @@ def __call__(
583585
vae_width, vae_height = self.calculate_dimensions(1024 * 1024, img_width / img_height)
584586
condition_images.append(img.resize((condition_width, condition_height), Image.LANCZOS))
585587
vae_images.append(img.resize((vae_width, vae_height), Image.LANCZOS))
588+
if width is None and height is None:
589+
width, height = vae_images[-1].size
586590

587-
width, height = vae_images[-1].size
588-
591+
if width is None and height is None:
592+
width, height = 1328, 1328
589593
self.validate_image_size(height, width, minimum=64, multiple_of=16)
590594

591595
if not isinstance(controlnet_params, list):

0 commit comments

Comments
 (0)