
Commit f3d1dc4

support edit 2511 (#212)
1 parent d82861c commit f3d1dc4

File tree

diffsynth_engine/configs/pipeline.py
diffsynth_engine/models/basic/attention.py
diffsynth_engine/models/qwen_image/qwen_image_dit.py
diffsynth_engine/pipelines/qwen_image.py

4 files changed: +63 -18 lines


diffsynth_engine/configs/pipeline.py

Lines changed: 10 additions & 5 deletions
@@ -251,11 +251,14 @@ class QwenImagePipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfi
     # override OptimizationConfig
     fbcache_relative_l1_threshold = 0.009
 
-    # svd
-    use_nunchaku: Optional[bool] = field(default=None, init=False)
-    use_nunchaku_awq: Optional[bool] = field(default=None, init=False)
-    use_nunchaku_attn: Optional[bool] = field(default=None, init=False)
-
+    # svd
+    use_nunchaku: Optional[bool] = field(default=None, init=False)
+    use_nunchaku_awq: Optional[bool] = field(default=None, init=False)
+    use_nunchaku_attn: Optional[bool] = field(default=None, init=False)
+
+    # for 2511
+    use_zero_cond_t: bool = False
+
     @classmethod
     def basic_config(
         cls,
@@ -266,6 +269,7 @@ def basic_config(
         parallelism: int = 1,
         offload_mode: Optional[str] = None,
         offload_to_disk: bool = False,
+        use_zero_cond_t: bool = False,
     ) -> "QwenImagePipelineConfig":
         return cls(
             model_path=model_path,
@@ -277,6 +281,7 @@ def basic_config(
             use_fsdp=True if parallelism > 1 else False,
             offload_mode=offload_mode,
             offload_to_disk=offload_to_disk,
+            use_zero_cond_t=use_zero_cond_t,
         )
 
     def __post_init__(self):
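For orientation, a minimal usage sketch of the new flag as it would pass through basic_config, based only on the signature shown in this hunk; the model path below is a placeholder, not a file from this repository.

from diffsynth_engine.configs.pipeline import QwenImagePipelineConfig

# Hypothetical usage; only model_path and use_zero_cond_t come from the diff above.
config = QwenImagePipelineConfig.basic_config(
    model_path="/path/to/qwen-image-edit-2511.safetensors",  # placeholder path
    use_zero_cond_t=True,  # new flag added in this commit for the 2511 edit model
)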

diffsynth_engine/models/basic/attention.py

Lines changed: 3 additions & 2 deletions
@@ -94,6 +94,7 @@ def sparge_attn(
     )
     return out.transpose(1, 2)
 
+
 if AITER_AVAILABLE:
     from aiter import flash_attn_func as aiter_flash_attn
     from aiter import flash_attn_fp8_pertensor_func as aiter_flash_attn_fp8
@@ -203,15 +204,15 @@ def attention(
         )
         if attn_mask is not None:
             raise RuntimeError("aiter_flash_attn does not support attention mask")
-        if attn_impl == "aiter" :
+        if attn_impl == "aiter":
             return aiter_flash_attn(q, k, v, softmax_scale=scale)
         else:
             origin_dtype = q.dtype
             q = q.to(dtype=DTYPE_FP8)
             k = k.to(dtype=DTYPE_FP8)
             v = v.to(dtype=DTYPE_FP8)
             out = aiter_flash_attn_fp8(q, k, v, softmax_scale=scale)
-        return out.to(dtype=origin_dtype)
+            return out.to(dtype=origin_dtype)
     if attn_impl == "fa2":
         return flash_attn2(q, k, v, softmax_scale=scale)
     if attn_impl == "xformers":

diffsynth_engine/models/qwen_image/qwen_image_dit.py

Lines changed: 46 additions & 7 deletions
@@ -2,6 +2,7 @@
 import torch.nn as nn
 from typing import Any, Dict, List, Tuple, Union, Optional
 from einops import rearrange
+from math import prod
 
 from diffsynth_engine.models.base import StateDictConverter, PreTrainedModel
 from diffsynth_engine.models.basic import attention as attention_ops
@@ -243,6 +244,7 @@ def __init__(
         num_attention_heads: int,
         attention_head_dim: int,
         eps: float = 1e-6,
+        zero_cond_t: bool = False,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -275,10 +277,30 @@ def __init__(
         self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps, device=device, dtype=dtype)
         self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps, device=device, dtype=dtype)
         self.txt_mlp = QwenFeedForward(dim=dim, dim_out=dim, device=device, dtype=dtype)
+        self.zero_cond_t = zero_cond_t
 
-    def _modulate(self, x, mod_params):
+    def _modulate(self, x, mod_params, index=None):
         shift, scale, gate = mod_params.chunk(3, dim=-1)
-        return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1), gate.unsqueeze(1)
+        if index is not None:
+            actual_batch = shift.size(0) // 2
+            shift_0, shift_1 = shift[:actual_batch], shift[actual_batch:]
+            scale_0, scale_1 = scale[:actual_batch], scale[actual_batch:]
+            gate_0, gate_1 = gate[:actual_batch], gate[actual_batch:]
+            index_expanded = index.unsqueeze(-1)
+            shift_0_exp = shift_0.unsqueeze(1)
+            shift_1_exp = shift_1.unsqueeze(1)
+            scale_0_exp = scale_0.unsqueeze(1)
+            scale_1_exp = scale_1.unsqueeze(1)
+            gate_0_exp = gate_0.unsqueeze(1)
+            gate_1_exp = gate_1.unsqueeze(1)
+            shift_result = torch.where(index_expanded == 0, shift_0_exp, shift_1_exp)
+            scale_result = torch.where(index_expanded == 0, scale_0_exp, scale_1_exp)
+            gate_result = torch.where(index_expanded == 0, gate_0_exp, gate_1_exp)
+        else:
+            shift_result = shift.unsqueeze(1)
+            scale_result = scale.unsqueeze(1)
+            gate_result = gate.unsqueeze(1)
+        return x * (1 + scale_result) + shift_result, gate_result
 
     def forward(
         self,
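To make the torch.where selection above concrete, here is a small standalone sketch; the shapes and index pattern are illustrative only, not taken from the repo. With zero_cond_t, the modulation parameters carry two sets along the batch dimension, and each token picks one set according to its per-token index.

import torch

B, L, D = 1, 6, 4                               # illustrative sizes
x = torch.randn(B, L, D)                        # token features
shift = torch.randn(2 * B, D)                   # [params for index 0; params for index 1]
index = torch.tensor([[0, 0, 0, 1, 1, 1]])      # per-token selector, shape [B, L]

shift_0 = shift[:B].unsqueeze(1)                # [B, 1, D]
shift_1 = shift[B:].unsqueeze(1)                # [B, 1, D]
per_token_shift = torch.where(index.unsqueeze(-1) == 0, shift_0, shift_1)  # [B, L, D]
shifted = x + per_token_shift                   # each token gets the shift of its group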
@@ -288,12 +310,15 @@ def forward(
         rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         attn_mask: Optional[torch.Tensor] = None,
         attn_kwargs: Optional[Dict[str, Any]] = None,
+        modulate_index: Optional[List[int]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1)  # [B, 3*dim] each
+        if self.zero_cond_t:
+            temb = torch.chunk(temb, 2, dim=0)[0]
         txt_mod_attn, txt_mod_mlp = self.txt_mod(temb).chunk(2, dim=-1)  # [B, 3*dim] each
 
         img_normed = self.img_norm1(image)
-        img_modulated, img_gate = self._modulate(img_normed, img_mod_attn)
+        img_modulated, img_gate = self._modulate(img_normed, img_mod_attn, modulate_index)
 
         txt_normed = self.txt_norm1(text)
         txt_modulated, txt_gate = self._modulate(txt_normed, txt_mod_attn)
@@ -305,12 +330,11 @@ def forward(
             attn_mask=attn_mask,
             attn_kwargs=attn_kwargs,
         )
-
         image = image + img_gate * img_attn_out
         text = text + txt_gate * txt_attn_out
 
         img_normed_2 = self.img_norm2(image)
-        img_modulated_2, img_gate_2 = self._modulate(img_normed_2, img_mod_mlp)
+        img_modulated_2, img_gate_2 = self._modulate(img_normed_2, img_mod_mlp, modulate_index)
 
         txt_normed_2 = self.txt_norm2(text)
         txt_modulated_2, txt_gate_2 = self._modulate(txt_normed_2, txt_mod_mlp)
@@ -331,6 +355,7 @@ class QwenImageDiT(PreTrainedModel):
     def __init__(
         self,
         num_layers: int = 60,
+        zero_cond_t: bool = False,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -351,6 +376,7 @@ def __init__(
                     dim=3072,
                     num_attention_heads=24,
                     attention_head_dim=128,
+                    zero_cond_t=zero_cond_t,
                     device=device,
                     dtype=dtype,
                 )
@@ -359,6 +385,7 @@
         )
         self.norm_out = AdaLayerNorm(3072, device=device, dtype=dtype)
         self.proj_out = nn.Linear(3072, 64, device=device, dtype=dtype)
+        self.zero_cond_t = zero_cond_t
 
     def patchify(self, hidden_states):
         hidden_states = rearrange(hidden_states, "B C (H P) (W Q) -> B (H W) (C P Q)", P=2, Q=2)
@@ -461,6 +488,9 @@ def forward(
                 use_cfg=use_cfg,
             ),
         ):
+            if self.zero_cond_t:
+                timestep = torch.cat([timestep, timestep * 0], dim=0)
+            modulate_index = None
             conditioning = self.time_text_embed(timestep, image.dtype)
             video_fhw = [(1, h // 2, w // 2)]  # frame, height, width
             text_seq_len = text_seq_lens.max().item()
@@ -478,7 +508,12 @@ def forward(
                 img = self.patchify(img)
                 image = torch.cat([image, img], dim=1)
                 video_fhw += [(1, edit_h // 2, edit_w // 2)]
-
+            if self.zero_cond_t:
+                modulate_index = torch.tensor(
+                    [[0] * prod(sample[0]) + [1] * sum([prod(s) for s in sample[1:]]) for sample in [video_fhw]],
+                    device=timestep.device,
+                    dtype=torch.int,
+                )
             rotary_emb = self.pos_embed(video_fhw, text_seq_len, image.device)
 
             image = self.img_in(image)
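For intuition, the list comprehension above marks target-image tokens with 0 and edit-image tokens with 1. A quick illustration with made-up latent sizes (one target plus one edit image):

from math import prod

video_fhw = [(1, 16, 16), (1, 8, 8)]  # hypothetical (frame, h//2, w//2): target first, then edit image
sample = video_fhw
index_row = [0] * prod(sample[0]) + [1] * sum(prod(s) for s in sample[1:])
print(len(index_row), index_row.count(0), index_row.count(1))  # 320 256 64
# 256 zeros mark the target-image tokens, 64 ones mark the edit-image tokens.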
@@ -510,7 +545,10 @@ def forward(
                     rotary_emb=rotary_emb,
                     attn_mask=attn_mask,
                     attn_kwargs=attn_kwargs,
+                    modulate_index=modulate_index,
                 )
+            if self.zero_cond_t:
+                conditioning = conditioning.chunk(2, dim=0)[0]
             image = self.norm_out(image, conditioning)
             image = self.proj_out(image)
             (image,) = sequence_parallel_unshard((image,), seq_dims=(1,), seq_lens=(image_seq_len,))
@@ -527,8 +565,9 @@ def from_state_dict(
         device: str,
         dtype: torch.dtype,
         num_layers: int = 60,
+        use_zero_cond_t: bool = False,
     ):
-        model = cls(device="meta", dtype=dtype, num_layers=num_layers)
+        model = cls(device="meta", dtype=dtype, num_layers=num_layers, zero_cond_t=use_zero_cond_t)
         model = model.requires_grad_(False)
         model.load_state_dict(state_dict, assign=True)
         model.to(device=device, dtype=dtype, non_blocking=True)
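A hedged sketch of constructing the DiT with the new flag via from_state_dict, assuming the import path follows the file above; safetensors is used here only as an example loader, and the checkpoint path is a placeholder. In the repo itself this wiring normally happens inside the pipeline (see the next file).

import torch
from safetensors.torch import load_file  # example loader, not part of this repo
from diffsynth_engine.models.qwen_image.qwen_image_dit import QwenImageDiT  # assumed import path

state_dict = load_file("/path/to/qwen_image_dit_2511.safetensors")  # placeholder checkpoint
dit = QwenImageDiT.from_state_dict(
    state_dict,
    device="cuda:0",
    dtype=torch.bfloat16,
    use_zero_cond_t=True,  # enables the zero-conditioning-timestep path added in this commit
)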

diffsynth_engine/pipelines/qwen_image.py

Lines changed: 4 additions & 4 deletions
@@ -2,7 +2,6 @@
 import torch
 import torch.distributed as dist
 import math
-import sys
 from typing import Callable, List, Dict, Tuple, Optional, Union
 from tqdm import tqdm
 from einops import rearrange
@@ -45,7 +44,6 @@
 logger = logging.get_logger(__name__)
 
 
-
 class QwenImageLoRAConverter(LoRAStateDictConverter):
     def _from_diffsynth(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, torch.Tensor]]:
         dit_dict = {}
@@ -205,7 +203,7 @@ def _setup_nunchaku_config(
         else:
             config.use_nunchaku_attn = False
             logger.info("Disable nunchaku attention quantization.")
-
+
     else:
         config.use_nunchaku = False
 
@@ -318,6 +316,7 @@ def _from_state_dict(cls, state_dicts: QwenImageStateDicts, config: QwenImagePip
         elif config.use_nunchaku:
             if not NUNCHAKU_AVAILABLE:
                 from diffsynth_engine.utils.flag import NUNCHAKU_IMPORT_ERROR
+
                 raise ImportError(NUNCHAKU_IMPORT_ERROR)
 
             from diffsynth_engine.models.qwen_image import QwenImageDiTNunchaku
@@ -337,6 +336,7 @@ def _from_state_dict(cls, state_dicts: QwenImageStateDicts, config: QwenImagePip
             state_dicts.model,
             device=("cpu" if config.use_fsdp else init_device),
             dtype=config.model_dtype,
+            use_zero_cond_t=config.use_zero_cond_t,
         )
         if config.use_fp8_linear and not config.use_nunchaku:
             enable_fp8_linear(dit)
@@ -704,7 +704,7 @@ def __call__(
 
         context_latents = None
         for param in controlnet_params:
-            self.load_lora(param.model, param.scale, fused=False, save_original_weight=False)
+            self.load_lora(param.model, param.scale, fused=True, save_original_weight=False)
             if param.control_type == QwenImageControlType.in_context:
                 width, height = param.image.size
                 self.validate_image_size(height, width, minimum=64, multiple_of=16)
