
Commit 0590444

Fix/qwen image (#197)
* several fixes for qwen image
* fix gelu
* fix batch_cfg with padding
1 parent 4ae8f2c commit 0590444

7 files changed: 134 additions & 35 deletions

diffsynth_engine/configs/pipeline.py

Lines changed: 2 additions & 7 deletions
@@ -262,16 +262,11 @@ class QwenImagePipelineConfig(AttentionConfig, OptimizationConfig, ParallelConfi
     encoder_dtype: torch.dtype = torch.bfloat16
     vae_dtype: torch.dtype = torch.float32

+    load_encoder: bool = True
+
     # override OptimizationConfig
     fbcache_relative_l1_threshold = 0.009

-    # override BaseConfig
-    vae_tiled: bool = True
-    vae_tile_size: Tuple[int, int] = (34, 34)
-    vae_tile_stride: Tuple[int, int] = (18, 16)
-
-    load_encoder: bool = True
-
     @classmethod
     def basic_config(
         cls,

diffsynth_engine/models/basic/transformer_helper.py

Lines changed: 36 additions & 2 deletions
@@ -1,5 +1,6 @@
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 import math


@@ -91,8 +92,8 @@ class NewGELUActivation(nn.Module):
     the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
     """

-    def forward(self, input: "torch.Tensor") -> "torch.Tensor":
-        return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))


 class ApproximateGELU(nn.Module):
@@ -115,3 +116,36 @@ def __init__(
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = self.proj(x)
         return x * torch.sigmoid(1.702 * x)
+
+
+class GELU(nn.Module):
+    r"""
+    GELU activation function with tanh approximation support with `approximate="tanh"`.
+
+    Parameters:
+        dim_in (`int`): The number of channels in the input.
+        dim_out (`int`): The number of channels in the output.
+        approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation.
+        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
+    """
+
+    def __init__(
+        self,
+        dim_in: int,
+        dim_out: int,
+        approximate: str = "none",
+        bias: bool = True,
+        device: str = "cuda:0",
+        dtype: torch.dtype = torch.float16,
+    ):
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out, bias=bias, device=device, dtype=dtype)
+        self.approximate = approximate
+
+    def gelu(self, gate: torch.Tensor) -> torch.Tensor:
+        return F.gelu(gate, approximate=self.approximate)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.proj(x)
+        x = self.gelu(x)
+        return x

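Note on the activation change: the DiT feed-forward previously used ApproximateGELU (the sigmoid form x * torch.sigmoid(1.702 * x)) and now uses GELU(..., approximate="tanh"), which matches the tanh formula in NewGELUActivation above. A minimal numerical check (the input range is illustrative):

import math
import torch
import torch.nn.functional as F

x = torch.linspace(-4.0, 4.0, steps=257)

# Formula from NewGELUActivation, which is also what F.gelu(..., approximate="tanh") computes
tanh_gelu = 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
print(torch.allclose(tanh_gelu, F.gelu(x, approximate="tanh"), atol=1e-6))  # True

# The old sigmoid approximation differs from the tanh form by roughly 1e-2 at its peak,
# so the two activations are not numerically interchangeable.
print((x * torch.sigmoid(1.702 * x) - F.gelu(x, approximate="tanh")).abs().max())
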
diffsynth_engine/models/qwen_image/qwen_image_dit.py

Lines changed: 13 additions & 17 deletions
@@ -6,7 +6,7 @@
 from diffsynth_engine.models.base import StateDictConverter, PreTrainedModel
 from diffsynth_engine.models.basic import attention as attention_ops
 from diffsynth_engine.models.basic.timestep import TimestepEmbeddings
-from diffsynth_engine.models.basic.transformer_helper import AdaLayerNorm, ApproximateGELU, RMSNorm
+from diffsynth_engine.models.basic.transformer_helper import AdaLayerNorm, GELU, RMSNorm
 from diffsynth_engine.utils.gguf import gguf_inference
 from diffsynth_engine.utils.fp8_linear import fp8_inference
 from diffsynth_engine.utils.parallel import (
@@ -144,7 +144,7 @@ def __init__(
         super().__init__()
         inner_dim = int(dim * 4)
         self.net = nn.ModuleList([])
-        self.net.append(ApproximateGELU(dim, inner_dim, device=device, dtype=dtype))
+        self.net.append(GELU(dim, inner_dim, approximate="tanh", device=device, dtype=dtype))
         self.net.append(nn.Dropout(dropout))
         self.net.append(nn.Linear(inner_dim, dim_out, device=device, dtype=dtype))

@@ -155,8 +155,8 @@ def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:


 def apply_rotary_emb_qwen(x: torch.Tensor, freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]]):
-    x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
-    x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
+    x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))  # (b, s, h, d) -> (b, s, h, d/2, 2)
+    x_out = torch.view_as_real(x_rotated * freqs_cis.unsqueeze(1)).flatten(3)  # (b, s, h, d/2, 2) -> (b, s, h, d)
     return x_out.type_as(x)


@@ -200,13 +200,13 @@ def forward(
         img_q, img_k, img_v = self.to_q(image), self.to_k(image), self.to_v(image)
         txt_q, txt_k, txt_v = self.add_q_proj(text), self.add_k_proj(text), self.add_v_proj(text)

-        img_q = rearrange(img_q, "b s (h d) -> b h s d", h=self.num_heads)
-        img_k = rearrange(img_k, "b s (h d) -> b h s d", h=self.num_heads)
-        img_v = rearrange(img_v, "b s (h d) -> b h s d", h=self.num_heads)
+        img_q = rearrange(img_q, "b s (h d) -> b s h d", h=self.num_heads)
+        img_k = rearrange(img_k, "b s (h d) -> b s h d", h=self.num_heads)
+        img_v = rearrange(img_v, "b s (h d) -> b s h d", h=self.num_heads)

-        txt_q = rearrange(txt_q, "b s (h d) -> b h s d", h=self.num_heads)
-        txt_k = rearrange(txt_k, "b s (h d) -> b h s d", h=self.num_heads)
-        txt_v = rearrange(txt_v, "b s (h d) -> b h s d", h=self.num_heads)
+        txt_q = rearrange(txt_q, "b s (h d) -> b s h d", h=self.num_heads)
+        txt_k = rearrange(txt_k, "b s (h d) -> b s h d", h=self.num_heads)
+        txt_v = rearrange(txt_v, "b s (h d) -> b s h d", h=self.num_heads)

         img_q, img_k = self.norm_q(img_q), self.norm_k(img_k)
         txt_q, txt_k = self.norm_added_q(txt_q), self.norm_added_k(txt_k)
@@ -218,13 +218,9 @@ def forward(
         txt_q = apply_rotary_emb_qwen(txt_q, txt_freqs)
         txt_k = apply_rotary_emb_qwen(txt_k, txt_freqs)

-        joint_q = torch.cat([txt_q, img_q], dim=2)
-        joint_k = torch.cat([txt_k, img_k], dim=2)
-        joint_v = torch.cat([txt_v, img_v], dim=2)
-
-        joint_q = joint_q.transpose(1, 2)
-        joint_k = joint_k.transpose(1, 2)
-        joint_v = joint_v.transpose(1, 2)
+        joint_q = torch.cat([txt_q, img_q], dim=1)
+        joint_k = torch.cat([txt_k, img_k], dim=1)
+        joint_v = torch.cat([txt_v, img_v], dim=1)

         attn_kwargs = attn_kwargs if attn_kwargs is not None else {}
         joint_attn_out = attention_ops.attention(joint_q, joint_k, joint_v, attn_mask=attn_mask, **attn_kwargs)

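With this change the q/k/v tensors stay in (b, s, h, d) layout end to end: the joint text+image concatenation happens along dim=1 and the extra transpose is gone, while apply_rotary_emb_qwen broadcasts freqs_cis over the head dimension via unsqueeze(1). A shape check of the updated helper, assuming freqs_cis is a complex tensor of shape (s, d/2) (its exact construction is not part of this diff):

import torch

def apply_rotary_emb_qwen(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # Mirrors the updated helper: x is (b, s, h, d); freqs_cis broadcasts over heads.
    x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))  # (b, s, h, d/2) complex
    x_out = torch.view_as_real(x_rotated * freqs_cis.unsqueeze(1)).flatten(3)   # back to (b, s, h, d)
    return x_out.type_as(x)

b, s, h, d = 2, 16, 4, 8
x = torch.randn(b, s, h, d)
# Assumed rotary table: one complex rotation per (position, channel pair).
angles = torch.outer(torch.arange(s, dtype=torch.float32), torch.arange(d // 2, dtype=torch.float32))
freqs_cis = torch.polar(torch.ones_like(angles), angles)  # (s, d/2), complex64
print(apply_rotary_emb_qwen(x, freqs_cis).shape)  # torch.Size([2, 16, 4, 8])
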
diffsynth_engine/pipelines/qwen_image.py

Lines changed: 13 additions & 5 deletions
@@ -24,7 +24,7 @@
 from diffsynth_engine.models.qwen_image import QwenImageVAE
 from diffsynth_engine.tokenizers import Qwen2TokenizerFast, Qwen2VLProcessor
 from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter
-from diffsynth_engine.pipelines.utils import calculate_shift
+from diffsynth_engine.pipelines.utils import calculate_shift, pad_and_concat
 from diffsynth_engine.algorithm.noise_scheduler import RecifitedFlowScheduler
 from diffsynth_engine.algorithm.sampler import FlowMatchEulerSampler
 from diffsynth_engine.utils.constants import (
@@ -148,9 +148,17 @@ def __init__(
         self.prompt_template_encode_start_idx = 34
         # qwen image edit
         self.edit_system_prompt = "Describe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate."
-        self.edit_prompt_template_encode = "<|im_start|>system\n" + self.edit_system_prompt + "<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
+        self.edit_prompt_template_encode = (
+            "<|im_start|>system\n"
+            + self.edit_system_prompt
+            + "<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
+        )
         # qwen image edit plus
-        self.edit_plus_prompt_template_encode = "<|im_start|>system\n" + self.edit_system_prompt + "<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
+        self.edit_plus_prompt_template_encode = (
+            "<|im_start|>system\n"
+            + self.edit_system_prompt
+            + "<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
+        )

         self.edit_prompt_template_encode_start_idx = 64

@@ -490,8 +498,8 @@ def predict_noise_with_cfg(
         else:
             # cfg by predict noise in one batch
             bs, _, h, w = latents.shape
-            prompt_emb = torch.cat([prompt_emb, negative_prompt_emb], dim=0)
-            prompt_emb_mask = torch.cat([prompt_emb_mask, negative_prompt_emb_mask], dim=0)
+            prompt_emb = pad_and_concat(prompt_emb, negative_prompt_emb)
+            prompt_emb_mask = pad_and_concat(prompt_emb_mask, negative_prompt_emb_mask)
             if entity_prompt_embs is not None:
                 entity_prompt_embs = [
                     torch.cat([x, y], dim=0) for x, y in zip(entity_prompt_embs, negative_entity_prompt_embs)

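Why pad_and_concat: the one-batch CFG path stacks the positive and negative prompt embeddings along the batch dimension, but the two prompts usually tokenize to different sequence lengths, so the previous torch.cat call can fail. A small illustration with made-up shapes:

import torch

prompt_emb = torch.randn(1, 20, 3584)           # positive prompt embedding (lengths/dims illustrative)
negative_prompt_emb = torch.randn(1, 7, 3584)   # negative prompt embedding, shorter sequence

try:
    torch.cat([prompt_emb, negative_prompt_emb], dim=0)  # old behaviour
except RuntimeError as err:
    print("torch.cat rejects mismatched sequence lengths:", err)

pad_and_concat (added in diffsynth_engine/pipelines/utils.py below) zero-pads the shorter tensor along the sequence dimension first; padding prompt_emb_mask the same way keeps the padded positions masked out.
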
diffsynth_engine/pipelines/utils.py

Lines changed: 52 additions & 0 deletions
@@ -1,3 +1,7 @@
+import torch
+import torch.nn.functional as F
+
+
 def accumulate(result, new_item):
     if result is None:
         return new_item
@@ -17,3 +21,51 @@ def calculate_shift(
     b = base_shift - m * base_seq_len
     mu = image_seq_len * m + b
     return mu
+
+
+def pad_and_concat(
+    tensor1: torch.Tensor,
+    tensor2: torch.Tensor,
+    concat_dim: int = 0,
+    pad_dim: int = 1,
+) -> torch.Tensor:
+    """
+    Concatenate two tensors along a specified dimension after padding along another dimension.
+
+    Assumes input tensors have shape (b, s, d), where:
+    - b: batch dimension
+    - s: sequence dimension (may differ)
+    - d: feature dimension
+
+    Args:
+        tensor1: First tensor with shape (b1, s1, d)
+        tensor2: Second tensor with shape (b2, s2, d)
+        concat_dim: Dimension to concatenate along, default is 0 (batch dimension)
+        pad_dim: Dimension to pad along, default is 1 (sequence dimension)
+
+    Returns:
+        Concatenated tensor, shape depends on concat_dim and pad_dim choices
+    """
+    assert tensor1.dim() == tensor2.dim(), "Both tensors must have the same number of dimensions"
+    assert concat_dim != pad_dim, "concat_dim and pad_dim cannot be the same"
+
+    len1, len2 = tensor1.shape[pad_dim], tensor2.shape[pad_dim]
+    max_len = max(len1, len2)
+
+    # Calculate the position of pad_dim in the padding list
+    # Padding format: from the last dimension, each pair represents (dim_n_left, dim_n_right, ..., dim_0_left, dim_0_right)
+    ndim = tensor1.dim()
+    padding = [0] * (2 * ndim)
+    pad_right_idx = -2 * pad_dim - 1
+
+    if len1 < max_len:
+        pad_len = max_len - len1
+        padding[pad_right_idx] = pad_len
+        tensor1 = F.pad(tensor1, padding, mode="constant", value=0)
+    elif len2 < max_len:
+        pad_len = max_len - len2
+        padding[pad_right_idx] = pad_len
+        tensor2 = F.pad(tensor2, padding, mode="constant", value=0)
+
+    # Concatenate along the specified dimension
+    return torch.cat([tensor1, tensor2], dim=concat_dim)

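A usage sketch of the new helper, using the same import path as qwen_image.py above (shapes are illustrative):

import torch
from diffsynth_engine.pipelines.utils import pad_and_concat

pos = torch.randn(1, 20, 3584)  # (b, s1, d)
neg = torch.randn(1, 7, 3584)   # (b, s2, d), shorter sequence

out = pad_and_concat(pos, neg)        # neg is right-padded with zeros along dim 1, then concatenated on dim 0
print(out.shape)                      # torch.Size([2, 20, 3584])
print(out[1, 7:].abs().sum().item())  # 0.0 -> the padded region is all zeros
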
diffsynth_engine/tokenizers/base.py

Lines changed: 6 additions & 0 deletions
@@ -1,10 +1,16 @@
 # Modified from transformers.tokenization_utils_base
 from typing import Dict, List, Union, overload
+from enum import Enum


 TOKENIZER_CONFIG_FILE = "tokenizer_config.json"


+class PaddingStrategy(str, Enum):
+    LONGEST = "longest"
+    MAX_LENGTH = "max_length"
+
+
 class BaseTokenizer:
     SPECIAL_TOKENS_ATTRIBUTES = [
         "bos_token",

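Because PaddingStrategy subclasses str, a plain string such as "longest" compares equal to the corresponding enum member, so call sites can pass either form (as qwen2.py below does). A quick check:

from enum import Enum

class PaddingStrategy(str, Enum):  # same definition as above
    LONGEST = "longest"
    MAX_LENGTH = "max_length"

print(PaddingStrategy.LONGEST == "longest")                          # True: str subclass compares equal to raw strings
print(PaddingStrategy("max_length") is PaddingStrategy.MAX_LENGTH)   # True: members can be looked up by value
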
diffsynth_engine/tokenizers/qwen2.py

Lines changed: 12 additions & 4 deletions
@@ -4,7 +4,7 @@
 from typing import Dict, List, Union, Optional
 from tokenizers import Tokenizer as TokenizerFast, AddedToken

-from diffsynth_engine.tokenizers.base import BaseTokenizer, TOKENIZER_CONFIG_FILE
+from diffsynth_engine.tokenizers.base import BaseTokenizer, PaddingStrategy, TOKENIZER_CONFIG_FILE


 VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}
@@ -165,22 +165,28 @@ def __call__(
         texts: Union[str, List[str]],
         max_length: Optional[int] = None,
         padding_side: Optional[str] = None,
+        padding_strategy: Union[PaddingStrategy, str] = "longest",
         **kwargs,
     ) -> Dict[str, "torch.Tensor"]:
         """
         Tokenize text and prepare for model inputs.

         Args:
-            text (`str`, `List[str]`, *optional*):
+            texts (`str`, `List[str]`):
                 The sequence or batch of sequences to be encoded.

             max_length (`int`, *optional*):
-                Each encoded sequence will be truncated or padded to max_length.
+                Maximum length of the encoded sequences.

             padding_side (`str`, *optional*):
                 The side on which the padding should be applied. Should be selected between `"right"` and `"left"`.
                 Defaults to `"right"`.

+            padding_strategy (`PaddingStrategy`, `str`, *optional*):
+                If `"longest"`, will pad the sequences to the longest sequence in the batch.
+                If `"max_length"`, will pad the sequences to the `max_length` argument.
+                Defaults to `"longest"`.
+
         Returns:
             `Dict[str, "torch.Tensor"]`: tensor dict compatible with model_input_names.
         """
@@ -190,7 +196,9 @@ def __call__(

         batch_ids = self.batch_encode(texts)
         ids_lens = [len(ids_) for ids_ in batch_ids]
-        max_length = max_length if max_length is not None else min(max(ids_lens), self.model_max_length)
+        max_length = max_length if max_length is not None else self.model_max_length
+        if padding_strategy == PaddingStrategy.LONGEST:
+            max_length = min(max(ids_lens), max_length)
         padding_side = padding_side if padding_side is not None else self.padding_side

         encoded = torch.zeros(len(texts), max_length, dtype=torch.long)

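Effect of the new padding_strategy argument: with the default "longest" the batch is padded to its longest sequence (capped by max_length or model_max_length), while "max_length" pads every sequence to the requested length. A standalone sketch of that length resolution (token counts and model_max_length are invented for illustration):

# Mirrors the updated length-resolution logic in the tokenizer's __call__ above.
ids_lens = [9, 23, 17]   # token counts per sequence in the batch (illustrative)
model_max_length = 1024  # illustrative

def resolve_pad_length(max_length=None, padding_strategy="longest"):
    max_length = max_length if max_length is not None else model_max_length
    if padding_strategy == "longest":
        max_length = min(max(ids_lens), max_length)
    return max_length

print(resolve_pad_length())                                              # 23 -> pad to the longest sequence
print(resolve_pad_length(max_length=64, padding_strategy="max_length"))  # 64 -> fixed-length padding
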