Commit 0acf4cb

Feature/qwen image control (#176)
* support qwen image controlnet
* fix qwen image attn
* update scheduler interface
* del attr not in __init__ after restore scheduler
* fix base scheduler
* feat: standardize scheduler configuration interface
* fix tab
1 parent 33d00d0 commit 0acf4cb

File tree: 27 files changed, +594 -194 lines changed


diffsynth_engine/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -14,6 +14,9 @@
     AttnImpl,
     ControlNetParams,
     ControlType,
+    QwenImageControlNetParams,
+    QwenImageControlType,
+    LoraConfig,
 )
 from .pipelines import (
     SDImagePipeline,
@@ -58,6 +61,8 @@
     "AttnImpl",
     "ControlNetParams",
     "ControlType",
+    "QwenImageControlNetParams",
+    "QwenImageControlType",
     "SDImagePipeline",
     "SDControlNet",
     "SDXLImagePipeline",
@@ -74,6 +79,7 @@
     "FluxIPAdapterRefTool",
     "FluxReplaceByControlTool",
     "FluxReduxRefTool",
+    "LoraConfig",
     "fetch_model",
     "fetch_modelscope_model",
     "register_fetch_modelscope_model",

diffsynth_engine/algorithm/noise_scheduler/base_scheduler.py

Lines changed: 19 additions & 0 deletions
@@ -6,5 +6,24 @@ def append_zero(x):
 
 
 class BaseScheduler:
+    def __init__(self):
+        self._stored_config = {}
+
+    def store_config(self):
+        self._stored_config = {
+            config_name: config_value
+            for config_name, config_value in vars(self).items()
+            if not config_name.startswith("_")
+        }
+
+    def update_config(self, config_dict):
+        for config_name, new_value in config_dict.items():
+            if hasattr(self, config_name):
+                setattr(self, config_name, new_value)
+
+    def restore_config(self):
+        for config_name, config_value in self._stored_config.items():
+            setattr(self, config_name, config_value)
+
     def schedule(self, num_inference_steps: int):
         raise NotImplementedError()
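
The store/update/restore trio lets a scheduler snapshot its public attributes, take temporary overrides, and roll back afterwards. A minimal sketch of the intended call pattern, using a hypothetical ToyScheduler subclass (only the three methods above come from this commit; the import path follows the file location shown here):

from diffsynth_engine.algorithm.noise_scheduler.base_scheduler import BaseScheduler


class ToyScheduler(BaseScheduler):
    def __init__(self, shift=1.0):
        super().__init__()
        self.shift = shift
        self.store_config()  # snapshot every non-underscore attribute

    def schedule(self, num_inference_steps: int):
        return [self.shift] * num_inference_steps


scheduler = ToyScheduler(shift=1.0)
scheduler.update_config({"shift": 3.0, "unknown": 42})  # only existing attributes are touched
assert scheduler.shift == 3.0 and not hasattr(scheduler, "unknown")
scheduler.restore_config()  # roll back to the values captured in __init__
assert scheduler.shift == 1.0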

diffsynth_engine/algorithm/noise_scheduler/flow_match/recifited_flow.py

Lines changed: 22 additions & 6 deletions
@@ -12,16 +12,23 @@ class RecifitedFlowScheduler(BaseScheduler):
     def __init__(
         self,
         shift=1.0,
-        sigma_min=0.001,
-        sigma_max=1.0,
+        sigma_min=None,
+        sigma_max=None,
         num_train_timesteps=1000,
         use_dynamic_shifting=False,
+        shift_terminal=None,
+        exponential_shift_mu=None,
     ):
+        super().__init__()
         self.shift = shift
         self.sigma_min = sigma_min
         self.sigma_max = sigma_max
         self.num_train_timesteps = num_train_timesteps
         self.use_dynamic_shifting = use_dynamic_shifting
+        self.shift_terminal = shift_terminal
+        # static mu for distill model
+        self.exponential_shift_mu = exponential_shift_mu
+        self.store_config()
 
     def _sigma_to_t(self, sigma):
         return sigma * self.num_train_timesteps
@@ -35,21 +42,30 @@ def _time_shift(self, mu: float, sigma: float, t: torch.Tensor):
     def _shift_sigma(self, sigma: torch.Tensor, shift: float):
         return shift * sigma / (1 + (shift - 1) * sigma)
 
+    def _stretch_shift_to_terminal(self, sigma: torch.Tensor):
+        one_minus_z = 1 - sigma
+        scale_factor = one_minus_z[-1] / (1 - self.shift_terminal)
+        return 1 - (one_minus_z / scale_factor)
+
     def schedule(
         self,
         num_inference_steps: int,
         mu: float | None = None,
-        sigma_min: float | None = None,
-        sigma_max: float | None = None,
+        sigma_min: float = 0.001,
+        sigma_max: float = 1.0,
         append_value: float = 0,
     ):
-        sigma_min = self.sigma_min if sigma_min is None else sigma_min
-        sigma_max = self.sigma_max if sigma_max is None else sigma_max
+        sigma_min = sigma_min if self.sigma_min is None else self.sigma_min
+        sigma_max = sigma_max if self.sigma_max is None else self.sigma_max
         sigmas = torch.linspace(sigma_max, sigma_min, num_inference_steps)
+        if self.exponential_shift_mu is not None:
+            mu = self.exponential_shift_mu
         if self.use_dynamic_shifting:
            sigmas = self._time_shift(mu, 1.0, sigmas)  # FLUX
         else:
             sigmas = self._shift_sigma(sigmas, self.shift)
+        if self.shift_terminal is not None:
+            sigmas = self._stretch_shift_to_terminal(sigmas)
         timesteps = sigmas * self.num_train_timesteps
         sigmas = append(sigmas, append_value)
         return sigmas, timesteps
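
The sigma schedule produced by the updated class can be reproduced in a few lines of tensor code: a linear grid is shifted by _shift_sigma, and the new _stretch_shift_to_terminal step rescales it so the final sigma lands exactly on shift_terminal (when exponential_shift_mu is set, it replaces the caller-supplied mu before dynamic shifting). A standalone numeric sketch; the shift and shift_terminal values below are illustrative only:

import torch

num_inference_steps = 5
sigma_max, sigma_min = 1.0, 0.001        # the new call-site defaults of schedule()
shift, shift_terminal = 3.0, 0.02        # illustrative values

sigmas = torch.linspace(sigma_max, sigma_min, num_inference_steps)
sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)   # _shift_sigma

# _stretch_shift_to_terminal: rescale so the last sigma equals shift_terminal
one_minus_z = 1 - sigmas
scale_factor = one_minus_z[-1] / (1 - shift_terminal)
sigmas = 1 - (one_minus_z / scale_factor)

timesteps = sigmas * 1000                # num_train_timesteps
print(sigmas, timesteps)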

diffsynth_engine/configs/__init__.py

Lines changed: 11 additions & 2 deletions
@@ -17,9 +17,15 @@
     WanStateDicts,
     WanS2VStateDicts,
     QwenImageStateDicts,
+    LoraConfig,
     AttnImpl,
 )
-from .controlnet import ControlType, ControlNetParams
+from .controlnet import (
+    ControlType,
+    ControlNetParams,
+    QwenImageControlNetParams,
+    QwenImageControlType,
+)
 
 __all__ = [
     "BaseConfig",
@@ -40,7 +46,10 @@
     "WanStateDicts",
     "WanS2VStateDicts",
     "QwenImageStateDicts",
-    "AttnImpl",
+    "QwenImageControlType",
+    "QwenImageControlNetParams",
     "ControlType",
     "ControlNetParams",
+    "LoraConfig",
+    "AttnImpl",
 ]

diffsynth_engine/configs/controlnet.py

Lines changed: 13 additions & 0 deletions
@@ -34,3 +34,16 @@ class ControlNetParams:
     control_start: float = 0
     control_end: float = 1
     processor_name: Optional[str] = None  # only used for sdxl controlnet union now
+
+
+class QwenImageControlType(Enum):
+    eligen = "eligen"
+    in_context = "in_context"
+
+
+@dataclass
+class QwenImageControlNetParams:
+    image: ImageType
+    model: str
+    control_type: QwenImageControlType
+    scale: float = 1.0
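
A hedged construction example for the new dataclass. The file paths are placeholders, and passing a PIL image for the ImageType field is an assumption; how the pipeline consumes the params object is outside this excerpt:

from PIL import Image
from diffsynth_engine import QwenImageControlNetParams, QwenImageControlType

control = QwenImageControlNetParams(
    image=Image.open("pose_condition.png"),             # placeholder conditioning image
    model="path/to/qwen_image_controlnet.safetensors",  # placeholder model path
    control_type=QwenImageControlType.in_context,
    scale=1.0,
)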

diffsynth_engine/configs/pipeline.py

Lines changed: 6 additions & 0 deletions
@@ -365,3 +365,9 @@ def init_parallel_config(config: FluxPipelineConfig | QwenImagePipelineConfig |
         config.tp_degree = 1
     else:
         raise ValueError("sp_ulysses_degree and sp_ring_degree must be specified together")
+
+
+@dataclass
+class LoraConfig:
+    scale: float
+    scheduler_config: Optional[Dict] = None
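
The scheduler_config field pairs naturally with the new store/update/restore interface on BaseScheduler. The wiring below is a sketch of one plausible use, not taken from this diff; the import path is inferred from the file locations above and the parameter values are illustrative:

from diffsynth_engine import LoraConfig
from diffsynth_engine.algorithm.noise_scheduler.flow_match.recifited_flow import RecifitedFlowScheduler

scheduler = RecifitedFlowScheduler(shift=1.0)
lora = LoraConfig(scale=0.8, scheduler_config={"shift": 3.0})  # illustrative values

if lora.scheduler_config is not None:
    scheduler.update_config(lora.scheduler_config)   # temporary override while this LoRA is active
sigmas, timesteps = scheduler.schedule(num_inference_steps=20)
scheduler.restore_config()                           # drop the override afterwards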

diffsynth_engine/models/hunyuan3d/dino_image_encoder.py

Lines changed: 4 additions & 2 deletions
@@ -2,7 +2,7 @@
 import torchvision.transforms as transforms
 import collections.abc
 import math
-from typing import Optional, Tuple, Dict
+from typing import Optional, Dict
 
 import torch
 from diffsynth_engine.models.base import PreTrainedModel, StateDictConverter
@@ -112,7 +112,9 @@ class Dinov2SelfAttention(nn.Module):
     def __init__(self, hidden_size: int, num_attention_heads: int, qkv_bias: bool) -> None:
         super().__init__()
         if hidden_size % num_attention_heads != 0:
-            raise ValueError(f"hidden_size {hidden_size} is not a multiple of num_attention_heads {num_attention_heads}.")
+            raise ValueError(
+                f"hidden_size {hidden_size} is not a multiple of num_attention_heads {num_attention_heads}."
+            )
 
         self.num_attention_heads = num_attention_heads
         self.attention_head_size = int(hidden_size / num_attention_heads)

diffsynth_engine/models/qwen_image/qwen_image_dit.py

Lines changed: 106 additions & 18 deletions
@@ -1,6 +1,6 @@
 import torch
 import torch.nn as nn
-from typing import Any, Dict, Tuple, Union, Optional
+from typing import Any, Dict, List, Tuple, Union, Optional
 from einops import rearrange
 
 from diffsynth_engine.models.base import StateDictConverter, PreTrainedModel
@@ -190,7 +190,8 @@ def forward(
         self,
         image: torch.FloatTensor,
         text: torch.FloatTensor,
-        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attn_mask: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
         img_q, img_k, img_v = self.to_q(image), self.to_k(image), self.to_v(image)
         txt_q, txt_k, txt_v = self.add_q_proj(text), self.add_k_proj(text), self.add_v_proj(text)
@@ -206,8 +207,8 @@ def forward(
         img_q, img_k = self.norm_q(img_q), self.norm_k(img_k)
         txt_q, txt_k = self.norm_added_q(txt_q), self.norm_added_k(txt_k)
 
-        if image_rotary_emb is not None:
-            img_freqs, txt_freqs = image_rotary_emb
+        if rotary_emb is not None:
+            img_freqs, txt_freqs = rotary_emb
             img_q = apply_rotary_emb_qwen(img_q, img_freqs)
             img_k = apply_rotary_emb_qwen(img_k, img_freqs)
             txt_q = apply_rotary_emb_qwen(txt_q, txt_freqs)
@@ -221,7 +222,7 @@ def forward(
         joint_k = joint_k.transpose(1, 2)
         joint_v = joint_v.transpose(1, 2)
 
-        joint_attn_out = attention_ops.attention(joint_q, joint_k, joint_v, **self.attn_kwargs)
+        joint_attn_out = attention_ops.attention(joint_q, joint_k, joint_v, attn_mask=attn_mask, **self.attn_kwargs)
 
         joint_attn_out = rearrange(joint_attn_out, "b s h d -> b s (h d)").to(joint_q.dtype)
 
@@ -285,7 +286,8 @@ def forward(
         image: torch.Tensor,
         text: torch.Tensor,
         temb: torch.Tensor,
-        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attn_mask: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1)  # [B, 3*dim] each
         txt_mod_attn, txt_mod_mlp = self.txt_mod(temb).chunk(2, dim=-1)  # [B, 3*dim] each
@@ -299,7 +301,8 @@ def forward(
         img_attn_out, txt_attn_out = self.attn(
             image=img_modulated,
             text=txt_modulated,
-            image_rotary_emb=image_rotary_emb,
+            rotary_emb=rotary_emb,
+            attn_mask=attn_mask,
         )
 
         image = image + img_gate * img_attn_out
@@ -368,13 +371,74 @@ def unpatchify(self, hidden_states, height, width):
         )
         return hidden_states
 
+    def process_entity_masks(
+        self,
+        text: torch.Tensor,
+        text_seq_lens: torch.LongTensor,
+        rotary_emb: Tuple[torch.Tensor, torch.Tensor],
+        video_fhw: List[Tuple[int, int, int]],
+        entity_text: List[torch.Tensor],
+        entity_seq_lens: List[torch.LongTensor],
+        entity_masks: List[torch.Tensor],
+        device: str,
+        dtype: torch.dtype,
+    ):
+        entity_seq_lens = [seq_lens.max().item() for seq_lens in entity_seq_lens]
+        text_seq_lens = entity_seq_lens + [text_seq_lens.max().item()]
+        entity_text = [
+            self.txt_in(self.txt_norm(text[:, :seq_len])) for text, seq_len in zip(entity_text, entity_seq_lens)
+        ]
+        text = torch.cat(entity_text + [text], dim=1)
+
+        entity_txt_freqs = [self.pos_embed(video_fhw, seq_len, device)[1] for seq_len in entity_seq_lens]
+        img_freqs, txt_freqs = rotary_emb
+        txt_freqs = torch.cat(entity_txt_freqs + [txt_freqs], dim=0)
+        rotary_emb = (img_freqs, txt_freqs)
+
+        global_mask = torch.ones_like(entity_masks[0], device=device, dtype=dtype)
+        patched_masks = [self.patchify(mask) for mask in entity_masks + [global_mask]]
+        batch_size, image_seq_len = patched_masks[0].shape[:2]
+        total_seq_len = sum(text_seq_lens) + image_seq_len
+        attention_mask = torch.ones((batch_size, total_seq_len, total_seq_len), device=device, dtype=torch.bool)
+
+        # text-image attention mask
+        img_start, img_end = sum(text_seq_lens), total_seq_len
+        cumsum = [0]
+        for seq_len in text_seq_lens:
+            cumsum.append(cumsum[-1] + seq_len)
+        for i, patched_mask in enumerate(patched_masks):
+            txt_start, txt_end = cumsum[i], cumsum[i + 1]
+            mask = torch.sum(patched_mask, dim=-1) > 0
+            mask = mask.unsqueeze(1).repeat(1, text_seq_lens[i], 1)
+            # text-to-image attention
+            attention_mask[:, txt_start:txt_end, img_start:img_end] = mask
+            # image-to-text attention
+            attention_mask[:, img_start:img_end, txt_start:txt_end] = mask.transpose(1, 2)
+        # entity text tokens should not attend to each other
+        for i in range(len(text_seq_lens)):
+            for j in range(len(text_seq_lens)):
+                if i == j:
+                    continue
+                i_start, i_end = cumsum[i], cumsum[i + 1]
+                j_start, j_end = cumsum[j], cumsum[j + 1]
+                attention_mask[:, i_start:i_end, j_start:j_end] = False
+
+        attn_mask = torch.zeros_like(attention_mask, device=device, dtype=dtype)
+        attn_mask[~attention_mask] = -torch.inf
+        attn_mask = attn_mask.unsqueeze(1)
+        return text, rotary_emb, attn_mask
+
     def forward(
         self,
         image: torch.Tensor,
         edit: torch.Tensor = None,
-        text: torch.Tensor = None,
         timestep: torch.LongTensor = None,
-        txt_seq_lens: torch.LongTensor = None,
+        text: torch.Tensor = None,
+        text_seq_lens: torch.LongTensor = None,
+        context_latents: Optional[torch.Tensor] = None,
+        entity_text: Optional[List[torch.Tensor]] = None,
+        entity_seq_lens: Optional[List[torch.LongTensor]] = None,
+        entity_masks: Optional[List[torch.Tensor]] = None,
     ):
         h, w = image.shape[-2:]
         fp8_linear_enabled = getattr(self, "fp8_linear_enabled", False)
@@ -386,35 +450,59 @@ def forward(
                 (
                     image,
                     edit,
-                    text,
                     timestep,
-                    txt_seq_lens,
+                    text,
+                    text_seq_lens,
+                    *(entity_text if entity_text is not None else ()),
+                    *(entity_seq_lens if entity_seq_lens is not None else ()),
+                    *(entity_masks if entity_masks is not None else ()),
+                    context_latents,
                 ),
                 use_cfg=use_cfg,
             ),
         ):
             conditioning = self.time_text_embed(timestep, image.dtype)
             video_fhw = [(1, h // 2, w // 2)]  # frame, height, width
-            max_length = txt_seq_lens.max().item()
+            text_seq_len = text_seq_lens.max().item()
            image = self.patchify(image)
             image_seq_len = image.shape[1]
+            if context_latents is not None:
+                context_latents = context_latents.to(dtype=image.dtype)
+                context_latents = self.patchify(context_latents)
+                image = torch.cat([image, context_latents], dim=1)
+                video_fhw += [(1, h // 2, w // 2)]
             if edit is not None:
                 edit = edit.to(dtype=image.dtype)
                 edit = self.patchify(edit)
                 image = torch.cat([image, edit], dim=1)
-                video_fhw += video_fhw
+                video_fhw += [(1, h // 2, w // 2)]
 
-            image_rotary_emb = self.pos_embed(video_fhw, max_length, image.device)
+            rotary_emb = self.pos_embed(video_fhw, text_seq_len, image.device)
 
             image = self.img_in(image)
-            text = self.txt_in(self.txt_norm(text[:, :max_length]))
+            text = self.txt_in(self.txt_norm(text[:, :text_seq_len]))
+
+            attn_mask = None
+            if entity_text is not None:
+                text, rotary_emb, attn_mask = self.process_entity_masks(
+                    text,
+                    text_seq_lens,
+                    rotary_emb,
+                    video_fhw,
+                    entity_text,
+                    entity_seq_lens,
+                    entity_masks,
+                    image.device,
+                    image.dtype,
+                )
 
             for block in self.transformer_blocks:
-                text, image = block(image=image, text=text, temb=conditioning, image_rotary_emb=image_rotary_emb)
+                text, image = block(
+                    image=image, text=text, temb=conditioning, rotary_emb=rotary_emb, attn_mask=attn_mask
+                )
             image = self.norm_out(image, conditioning)
             image = self.proj_out(image)
-            if edit is not None:
-                image = image[:, :image_seq_len]
+            image = image[:, :image_seq_len]
 
             image = self.unpatchify(image, h, w)
 
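
process_entity_masks is the core of the eligen-style entity control added above: each entity prompt may only attend to the image tokens covered by its spatial mask, image tokens attend back to the prompts that cover them, and the text segments are blocked from attending to one another. A self-contained toy version of that mask construction, with made-up sequence lengths and coverage and the batch dimension omitted:

import torch

text_seq_lens = [3, 3, 4]     # entity 1, entity 2, global prompt (toy lengths)
image_seq_len = 6
total = sum(text_seq_lens) + image_seq_len

# True where a text segment's spatial mask covers an image token
coverage = torch.tensor([
    [1, 1, 0, 0, 0, 0],       # entity 1
    [0, 0, 1, 1, 0, 0],       # entity 2
    [1, 1, 1, 1, 1, 1],       # global prompt covers everything
], dtype=torch.bool)

allowed = torch.ones(total, total, dtype=torch.bool)
img_start = sum(text_seq_lens)
cumsum = [0]
for n in text_seq_lens:
    cumsum.append(cumsum[-1] + n)
for i in range(len(text_seq_lens)):
    t0, t1 = cumsum[i], cumsum[i + 1]
    allowed[t0:t1, img_start:] = coverage[i]                  # text -> image
    allowed[img_start:, t0:t1] = coverage[i].unsqueeze(-1)    # image -> text
    for j in range(len(text_seq_lens)):                       # text segments ignore each other
        if i != j:
            allowed[t0:t1, cumsum[j]:cumsum[j + 1]] = False

# additive bias: 0 where attention is allowed, -inf where it is blocked
attn_bias = torch.zeros(total, total)
attn_bias[~allowed] = float("-inf")

In the real method the masks are patchified latent-space tensors, and the resulting bias is unsqueezed to add a head dimension before being passed through the transformer blocks as attn_mask.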