
Commit 558eddb

Glaceon-Hyyakaitsuki-iiqzzz95 authored

Feature/qwen image edit (#152)

* support qwen image edit
* fix Qwen Image Edit resolution
* rope max size 4096 -> 10000
* fix

---------

Co-authored-by: zhuguoxuan.zgx <[email protected]>
Co-authored-by: dujiancong.djc <[email protected]>

1 parent 7e0a4c2 commit 558eddb

File tree

21 files changed, +827 -129 lines changed

.gitignore

Lines changed: 2 additions & 1 deletion

@@ -8,4 +8,5 @@ dist/
 *.egg-info/
 .DS_Store/
 .pytest_cache/
-.ruff_cache/
+.ruff_cache/
+CLAUDE.md

diffsynth_engine/conf/models/qwen_image/qwen2_5_vl_config.json

Lines changed: 2 additions & 1 deletion

@@ -21,5 +21,6 @@
   "vision_start_token_id": 151652,
   "vision_end_token_id": 151653,
   "image_token_id": 151655,
-  "video_token_id": 151656
+  "video_token_id": 151656,
+  "attn_impl": "sdpa"
 }
Lines changed: 29 additions & 0 deletions

@@ -0,0 +1,29 @@
+{
+  "do_convert_rgb": true,
+  "do_normalize": true,
+  "do_rescale": true,
+  "do_resize": true,
+  "image_mean": [
+    0.48145466,
+    0.4578275,
+    0.40821073
+  ],
+  "image_processor_type": "Qwen2VLImageProcessor",
+  "image_std": [
+    0.26862954,
+    0.26130258,
+    0.27577711
+  ],
+  "max_pixels": 12845056,
+  "merge_size": 2,
+  "min_pixels": 3136,
+  "patch_size": 14,
+  "processor_class": "Qwen2_5_VLProcessor",
+  "resample": 3,
+  "rescale_factor": 0.00392156862745098,
+  "size": {
+    "longest_edge": 12845056,
+    "shortest_edge": 3136
+  },
+  "temporal_patch_size": 2
+}
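In the new Qwen2VLImageProcessor config above, min_pixels, max_pixels, patch_size, and merge_size jointly bound the resolution an input image is resized to before patchification: the snapping factor is patch_size * merge_size = 28, and the total pixel budget is 3,136 to 12,845,056 pixels. The sketch below is patterned after the Qwen2-VL "smart resize" convention; the smart_resize helper and its exact rounding are assumptions for illustration, not code from this commit.

import math

def smart_resize(height: int, width: int, factor: int = 28,
                 min_pixels: int = 3136, max_pixels: int = 12845056):
    # Round each side to a multiple of patch_size * merge_size (14 * 2 = 28 here),
    # then rescale so the total pixel count lands inside [min_pixels, max_pixels].
    h = round(height / factor) * factor
    w = round(width / factor) * factor
    if h * w > max_pixels:
        beta = math.sqrt((height * width) / max_pixels)
        h = math.floor(height / beta / factor) * factor
        w = math.floor(width / beta / factor) * factor
    elif h * w < min_pixels:
        beta = math.sqrt(min_pixels / (height * width))
        h = math.ceil(height * beta / factor) * factor
        w = math.ceil(width * beta / factor) * factor
    return h, w

print(smart_resize(1024, 1024))  # (1036, 1036): snapped to the 28-pixel grid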

diffsynth_engine/models/basic/attention.py

Lines changed: 3 additions & 3 deletions

@@ -1,9 +1,9 @@
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from einops import rearrange, repeat
 from typing import Optional
 
-import torch.nn.functional as F
 from diffsynth_engine.utils import logging
 from diffsynth_engine.utils.flag import (
     FLASH_ATTN_3_AVAILABLE,
@@ -42,11 +42,11 @@ def xformers_attn(q, k, v, attn_mask=None, scale=None):
 
 if SDPA_AVAILABLE:
 
-    def sdpa_attn(q, k, v, attn_mask=None, scale=None):
+    def sdpa_attn(q, k, v, attn_mask=None, is_causal=False, scale=None):
         q = q.transpose(1, 2)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
-        out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, scale=scale)
+        out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=is_causal, scale=scale)
         return out.transpose(1, 2)
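To make the new signature concrete, here is a standalone usage sketch of the updated sdpa_attn wrapper: the body is the one from the hunk above (minus the SDPA_AVAILABLE guard), and the shapes are arbitrary example values. It expects (batch, seq, heads, head_dim) tensors and handles the transposes internally.

import torch
import torch.nn.functional as F

def sdpa_attn(q, k, v, attn_mask=None, is_causal=False, scale=None):
    # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim) for SDPA, then back.
    q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=is_causal, scale=scale)
    return out.transpose(1, 2)

b, s, n, d = 2, 16, 4, 32
q, k, v = (torch.randn(b, s, n, d) for _ in range(3))
out = sdpa_attn(q, k, v, is_causal=True)  # decoder-style causal attention
print(out.shape)  # torch.Size([2, 16, 4, 32])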

diffsynth_engine/models/qwen_image/qwen2_5_vl.py

Lines changed: 41 additions & 57 deletions

@@ -7,7 +7,7 @@
 
 from diffsynth_engine.models.base import PreTrainedModel
 from diffsynth_engine.models.basic.transformer_helper import RMSNorm
-from diffsynth_engine.models.basic.attention import attention
+from diffsynth_engine.models.basic import attention as attention_ops
 from diffsynth_engine.models.utils import no_init_weights
 from diffsynth_engine.utils.cache import Cache, DynamicCache
 from diffsynth_engine.utils import logging
@@ -152,17 +152,15 @@ def __init__(
         self,
         dim: int = 80,
         theta: float = 10000.0,
-        device: str = "cuda:0",
-        dtype: torch.dtype = torch.bfloat16,
     ):
         super().__init__()
-        with torch.device(device):
-            inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        with torch.device("cpu"):
+            self.inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
 
-    def forward(self, seqlen: int) -> torch.Tensor:
-        seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
-        freqs = torch.outer(seq, self.inv_freq)
+    def forward(self, seqlen: int, device: str) -> torch.Tensor:
+        inv_freq = self.inv_freq.to(device=device)
+        seq = torch.arange(seqlen, device=inv_freq.device, dtype=inv_freq.dtype)
+        freqs = torch.outer(seq, inv_freq)
         return freqs
 
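The hunk above builds the vision RoPE inverse frequencies once on CPU and threads the target device through forward, instead of allocating them on a fixed CUDA device at construction time. A minimal standalone sketch of the new call pattern (a plain function instead of the nn.Module, with dim = 40 chosen as an example value):

import torch

dim, theta = 40, 10000.0
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))  # built once, on CPU

def rope_freqs(seqlen: int, device: str) -> torch.Tensor:
    # Move the cached frequencies to the caller's device, then take the outer
    # product with the position indices, mirroring the diffed forward().
    f = inv_freq.to(device=device)
    seq = torch.arange(seqlen, device=f.device, dtype=f.dtype)
    return torch.outer(seq, f)  # (seqlen, dim // 2)

print(rope_freqs(16, "cpu").shape)  # torch.Size([16, 20])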

@@ -222,7 +220,7 @@ def forward(
         q = rearrange(q, "s n d -> 1 s n d")
         k = rearrange(k, "s n d -> 1 s n d")
         v = rearrange(v, "s n d -> 1 s n d")
-        out = attention(q, k, v, attn_impl=self.attn_impl, attn_mask=attention_mask)
+        out = attention_ops.attention(q, k, v, attn_impl=self.attn_impl, attn_mask=attention_mask)
         out = rearrange(out, "1 s n d -> s (n d)")
         out = self.proj(out)
         return out
@@ -301,7 +299,7 @@ def __init__(self, config: Qwen2_5_VLVisionConfig, device: str = "cuda:0", dtype
             dtype=dtype,
         )
         head_dim = config.hidden_size // config.num_heads
-        self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2, device=device, dtype=dtype)
+        self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
         self.blocks = nn.ModuleList(
             [
                 Qwen2_5_VisionBlock(
@@ -348,7 +346,7 @@ def rot_pos_emb(self, grid_thw):
             pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
         pos_ids = torch.cat(pos_ids, dim=0)
         max_grid_size = grid_thw[:, 1:].max()
-        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
+        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size, device=grid_thw.device)
         rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
         return rotary_pos_emb
 
@@ -488,7 +486,6 @@ def __init__(
         hidden_size: int = 3584,
         num_attention_heads: int = 28,
         num_key_value_heads: int = 4,
-        # dropout: float = 0.0,
         mrope_section: List[int] = [16, 24, 24],
         attn_impl: Optional[str] = None,
         device: str = "cuda:0",
@@ -501,7 +498,6 @@ def __init__(
         self.head_dim = hidden_size // num_attention_heads
         self.num_key_value_heads = num_key_value_heads
        self.num_key_value_groups = num_attention_heads // num_key_value_heads
-        # self.dropout = dropout
         self.mrope_section = mrope_section
         self.attn_impl = attn_impl
 
@@ -521,8 +517,6 @@ def __init__(
             self.num_attention_heads * self.head_dim, self.hidden_size, bias=False, device=device, dtype=dtype
         )
 
-        self.rotary_emb = Qwen2_5_VLRotaryEmbedding(dim=self.head_dim, device=device, dtype=dtype)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -556,14 +550,18 @@ def forward(
         if attention_mask is not None:  # no matter the length, we just slice it
             causal_mask = attention_mask[:, :, :, : key_states.shape[1]]
 
-        # TODO: attention_mask for flash attention 2
-        out = attention(
-            query_states,
-            key_states,
-            value_states,
-            attn_impl=self.attn_impl,
-            attn_mask=causal_mask,
-        )
+        # TODO: use is_causal when attention mask is causal
+        if self.attn_impl == "sdpa":
+            out = attention_ops.sdpa_attn(query_states, key_states, value_states, is_causal=True)
+        else:
+            # TODO: attention_mask for flash attention 2
+            out = attention_ops.attention(
+                query_states,
+                key_states,
+                value_states,
+                attn_impl=self.attn_impl,
+                attn_mask=causal_mask,
+            )
         out = rearrange(out, "b s n d -> b s (n d)")
         out = self.o_proj(out)
         return out, past_key_values
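The SDPA branch above drops the sliced causal_mask and asks the kernel for causal masking directly via is_causal=True. As a sanity sketch (not repo code), for a plain lower-triangular mask the two formulations give the same result:

import torch
import torch.nn.functional as F

b, h, s, d = 1, 2, 8, 16
q, k, v = (torch.randn(b, h, s, d) for _ in range(3))

# Boolean mask where True means "may attend": lower-triangular == causal.
causal_mask = torch.tril(torch.ones(s, s, dtype=torch.bool))
out_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=causal_mask)
out_flag = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(out_mask, out_flag, atol=1e-5))  # True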
@@ -647,29 +645,29 @@ def forward(
 
 
 class Qwen2_5_VLRotaryEmbedding(nn.Module):
-    def __init__(self, dim: int = 128, device: str = "cuda:0", dtype: torch.dtype = torch.bfloat16):
+    def __init__(self, dim: int = 128):
         super().__init__()
-        with torch.device(device):
-            inv_freq = self.compute_rope(dim)  # default rope without dynamic frequency
-        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        with torch.device("cpu"):
+            self.inv_freq = self.compute_rope(dim)  # default rope without dynamic frequency
 
     def compute_rope(self, dim: int, theta: float = 1000000.0):
         inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
         return inv_freq
 
     @torch.no_grad()
-    def forward(self, x, position_ids):
+    def forward(self, position_ids: torch.LongTensor, device: str, dtype: torch.dtype):
         # In contrast to other models, Qwen2_5_VL has different position ids for the grids
         # So we expand the inv_freq to shape (3, ...)
-        inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
+        inv_freq = self.inv_freq.to(device=device)
+        inv_freq_expanded = inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
         position_ids_expanded = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)
 
-        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
+        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(2, 3)
         emb = torch.cat((freqs, freqs), dim=-1)
         cos = emb.cos()
         sin = emb.sin()
 
-        return cos.to(device=x.device, dtype=x.dtype), sin.to(device=x.device, dtype=x.dtype)
+        return cos.to(device=device, dtype=dtype), sin.to(device=device, dtype=dtype)
 
 
 class Qwen2_5_VLModel(nn.Module):
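Qwen2_5_VLRotaryEmbedding.forward now takes position_ids plus an explicit device and dtype rather than a hidden-states tensor. Below is a rough standalone sketch of that computation with the shapes spelled out; it follows the diffed code but is an illustration, not the module itself:

import torch

dim, theta = 128, 1000000.0
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))  # kept on CPU, float32

def mrope_cos_sin(position_ids: torch.LongTensor, device: str, dtype: torch.dtype):
    # position_ids: (3, batch, seq) -- one grid each for temporal / height / width.
    f = inv_freq.to(device=device)
    f_expanded = f[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
    pos_expanded = position_ids[:, :, None, :].float()       # (3, batch, 1, seq)
    freqs = (f_expanded @ pos_expanded).transpose(2, 3)      # (3, batch, seq, dim // 2)
    emb = torch.cat((freqs, freqs), dim=-1)                  # (3, batch, seq, dim)
    return emb.cos().to(device=device, dtype=dtype), emb.sin().to(device=device, dtype=dtype)

pos = torch.arange(10).view(1, 1, -1).expand(3, 2, -1)       # (3, batch=2, seq=10)
cos, sin = mrope_cos_sin(pos, "cpu", torch.bfloat16)
print(cos.shape)  # torch.Size([3, 2, 10, 128])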
@@ -702,7 +700,7 @@ def __init__(self, config: Qwen2_5_VLConfig, device: str = "cuda:0", dtype: torc
         )
         self.norm = Qwen2_5_RMSNorm(config.hidden_size, config.rms_norm_eps, device=device, dtype=dtype)
         head_dim = config.hidden_size // config.num_attention_heads
-        self.rotary_emb = Qwen2_5_VLRotaryEmbedding(dim=head_dim, device=device, dtype=dtype)
+        self.rotary_emb = Qwen2_5_VLRotaryEmbedding(dim=head_dim)
 
     def get_input_embeddings(self):
         return self.embed_tokens
@@ -749,7 +747,7 @@ def forward(
         hidden_states = inputs_embeds
 
         # create position embeddings to be shared across the decoder layers
-        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+        position_embeddings = self.rotary_emb(position_ids, device=hidden_states.device, dtype=hidden_states.dtype)
 
         # decoder layers
         for decoder_layer in self.layers:
@@ -940,8 +938,7 @@ def from_state_dict(
         with torch.device("meta"), no_init_weights():
             model = cls(vision_config=vision_config, config=config, device=device, dtype=dtype)
         model.load_state_dict(state_dict, assign=True)
-        for param in model.parameters():  # skip buffers
-            param.data = param.data.to(device=device, dtype=dtype, non_blocking=True)
+        model.to(device=device, dtype=dtype, non_blocking=True)
         return model
 
     def get_input_embeddings(self):
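from_state_dict now relies on a single model.to(...) call instead of looping over parameters. A consistent reading of this together with the inv_freq changes above is that the RoPE frequencies are now plain tensor attributes, which Module.to() neither moves nor downcasts, so they stay in float32; the tiny sketch below (illustrative, not repo code) demonstrates that behavior:

import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(4, 4))
        self.register_buffer("buf", torch.zeros(4))
        self.inv_freq = torch.ones(4)  # plain attribute, not a parameter or buffer

m = Demo().to(dtype=torch.bfloat16)
# Parameters and registered buffers are cast; the plain attribute is left alone.
print(m.weight.dtype, m.buf.dtype, m.inv_freq.dtype)
# torch.bfloat16 torch.bfloat16 torch.float32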
@@ -1202,27 +1199,14 @@ def forward(
         if position_ids is None:
             assert attention_mask is None or attention_mask.ndim == 2, "attention mask must be 2D"
             # calculate RoPE index once per generation in the pre-fill stage only
-            if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
-                position_ids, rope_deltas = self.get_rope_index(
-                    input_ids,
-                    image_grid_thw,
-                    video_grid_thw,
-                    second_per_grid_ts,
-                    attention_mask,
-                )
-                self.rope_deltas = rope_deltas
-            # then use the prev pre-calculated rope-deltas to get the correct position ids
-            else:
-                batch_size, seq_length, _ = inputs_embeds.shape
-                delta = (
-                    (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) if cache_position is not None else 0
-                )
-                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
-                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
-                if cache_position is not None:  # otherwise `deltas` is an int `0`
-                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
-                position_ids = position_ids.add(delta)
-                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
+            position_ids, rope_deltas = self.get_rope_index(
+                input_ids,
+                image_grid_thw,
+                video_grid_thw,
+                second_per_grid_ts,
+                attention_mask,
+            )
+            self.rope_deltas = rope_deltas
 
         hidden_states, present_key_values = self.model(
             input_ids=None,