Skip to content

Commit 7290eb6

Browse files
ciaranbor and Cursor Assistant authored
Align Qwen-Image with diffusers reference (#362)
Co-authored-by: Cursor Assistant <assistant@cursor.com>
1 parent bcdf4af commit 7290eb6

File tree

11 files changed

+60
-52
lines changed

11 files changed

+60
-52
lines changed

src/mflux/models/common/config/model_config.py

Lines changed: 18 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,11 @@ def __init__(
2323
requires_sigma_shift: bool | None,
2424
transformer_overrides: dict | None = None,
2525
text_encoder_overrides: dict | None = None,
26+
sigma_base_shift: float = 0.5,
27+
sigma_max_shift: float = 1.15,
28+
sigma_base_seq_len: int = 256,
29+
sigma_max_seq_len: int = 4096,
30+
sigma_shift_terminal: float | None = None,
2631
):
2732
self.aliases = aliases
2833
self.model_name = model_name
@@ -36,6 +41,11 @@ def __init__(
3641
self.priority = priority
3742
self.transformer_overrides = transformer_overrides or {}
3843
self.text_encoder_overrides = text_encoder_overrides or {}
44+
self.sigma_base_shift = sigma_base_shift
45+
self.sigma_max_shift = sigma_max_shift
46+
self.sigma_base_seq_len = sigma_base_seq_len
47+
self.sigma_max_seq_len = sigma_max_seq_len
48+
self.sigma_shift_terminal = sigma_shift_terminal
3949

4050
@staticmethod
4151
@lru_cache
@@ -411,7 +421,10 @@ def from_name(
411421
num_train_steps=None,
412422
max_sequence_length=None,
413423
supports_guidance=None,
414-
requires_sigma_shift=None,
424+
requires_sigma_shift=True,
425+
sigma_max_shift=0.9,
426+
sigma_max_seq_len=8192,
427+
sigma_shift_terminal=0.02,
415428
),
416429
"qwen-image-edit": ModelConfig(
417430
priority=16,
@@ -423,7 +436,10 @@ def from_name(
423436
num_train_steps=None,
424437
max_sequence_length=None,
425438
supports_guidance=None,
426-
requires_sigma_shift=None,
439+
requires_sigma_shift=True,
440+
sigma_max_shift=0.9,
441+
sigma_max_seq_len=8192,
442+
sigma_shift_terminal=0.02,
427443
),
428444
"fibo": ModelConfig(
429445
priority=17,

src/mflux/models/common/schedulers/linear_scheduler.py

Lines changed: 15 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -24,23 +24,26 @@ def timesteps(self) -> mx.array:
2424

2525
def _get_sigmas(self) -> mx.array:
2626
model_config = self.config.model_config
27-
sigmas = mx.linspace(
28-
1.0,
29-
1.0 / self.config.num_inference_steps,
30-
self.config.num_inference_steps,
31-
)
27+
num_steps = self.config.num_inference_steps
28+
sigmas = mx.linspace(1.0, 1.0 / num_steps, num_steps)
3229
sigmas = mx.array(sigmas).astype(mx.float32)
3330
sigmas = mx.concatenate([sigmas, mx.zeros(1)])
3431
if model_config.requires_sigma_shift:
35-
y1 = 0.5
36-
x1 = 256
37-
m = (1.15 - y1) / (4096 - x1)
38-
b = y1 - m * x1
32+
m = (model_config.sigma_max_shift - model_config.sigma_base_shift) / (
33+
model_config.sigma_max_seq_len - model_config.sigma_base_seq_len
34+
)
35+
b = model_config.sigma_base_shift - m * model_config.sigma_base_seq_len
3936
mu = m * self.config.width * self.config.height / 256 + b
4037
mu = mx.array(mu)
41-
shifted_sigmas = mx.exp(mu) / (mx.exp(mu) + (1 / sigmas - 1))
42-
shifted_sigmas[-1] = 0
43-
return shifted_sigmas
38+
39+
shifted = mx.exp(mu) / (mx.exp(mu) + (1 / sigmas[:-1] - 1))
40+
41+
if model_config.sigma_shift_terminal is not None:
42+
one_minus = 1.0 - shifted
43+
scale = one_minus[-1] / (1.0 - model_config.sigma_shift_terminal)
44+
shifted = 1.0 - (one_minus / scale)
45+
46+
return mx.concatenate([shifted, mx.zeros(1)])
4447
else:
4548
return sigmas
4649

src/mflux/models/qwen/model/qwen_text_encoder/qwen_attention.py

Lines changed: 6 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -61,15 +61,12 @@ def __call__(
6161
key_states = QwenAttention._repeat_kv(key_states, self.num_key_value_groups)
6262
value_states = QwenAttention._repeat_kv(value_states, self.num_key_value_groups)
6363

64-
attn_weights = mx.matmul(query_states, key_states.transpose(0, 1, 3, 2)) * self.scaling
65-
66-
if attention_mask is not None:
67-
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
68-
attn_weights = attn_weights + causal_mask
69-
70-
# Softmax and output
71-
attn_weights = mx.softmax(attn_weights.astype(mx.float32), axis=-1).astype(query_states.dtype)
72-
attn_output = mx.matmul(attn_weights, value_states)
64+
mask = attention_mask[:, :, :, : key_states.shape[-2]].astype(query_states.dtype) if attention_mask is not None else None
65+
attn_output = mx.fast.scaled_dot_product_attention(
66+
query_states, key_states, value_states,
67+
scale=self.scaling,
68+
mask=mask,
69+
)
7370
attn_output = attn_output.transpose(0, 2, 1, 3).reshape(bsz, q_len, self.hidden_size)
7471
attn_output = self.o_proj(attn_output)
7572
return attn_output

src/mflux/models/qwen/model/qwen_text_encoder/qwen_vision_attention.py

Lines changed: 18 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -18,10 +18,12 @@ def _rotate_half(self, x: mx.array) -> mx.array:
1818
return mx.concatenate([-x2, x1], axis=-1)
1919

2020
def _apply_rope(self, x: mx.array, cos: mx.array, sin: mx.array) -> mx.array:
21-
cos_expanded = mx.expand_dims(cos, axis=0)
22-
sin_expanded = mx.expand_dims(sin, axis=0)
21+
orig_dtype = x.dtype
22+
x = x.astype(mx.float32)
23+
cos_expanded = mx.expand_dims(cos, axis=0).astype(mx.float32)
24+
sin_expanded = mx.expand_dims(sin, axis=0).astype(mx.float32)
2325
rotated = (x * cos_expanded) + (self._rotate_half(x) * sin_expanded)
24-
return rotated
26+
return rotated.astype(orig_dtype)
2527

2628
def __call__(self, x: mx.array, position_embeddings=None, cu_seqlens=None) -> mx.array:
2729
seq_len, embed_dim = x.shape
@@ -42,41 +44,30 @@ def __call__(self, x: mx.array, position_embeddings=None, cu_seqlens=None) -> mx
4244
q = self._apply_rope(q, cos_emb, sin_emb)
4345
k = self._apply_rope(k, cos_emb, sin_emb)
4446

47+
scale = 1.0 / (self.head_dim**0.5)
48+
4549
# Process attention chunks if cu_seqlens is provided (windowed attention)
4650
if cu_seqlens is not None and len(cu_seqlens) > 2:
47-
# Split Q, K, V into chunks based on cu_seqlens
48-
# cu_seqlens is cumulative, so lengths[i] = cu_seqlens[i+1] - cu_seqlens[i]
4951
lengths = [int((cu_seqlens[i + 1] - cu_seqlens[i]).item()) for i in range(len(cu_seqlens) - 1)]
5052

51-
# Split tensors (q,k,v are [heads, seq, head_dim])
52-
q_chunks = []
53-
k_chunks = []
54-
v_chunks = []
53+
attn_outputs = []
5554
offset = 0
5655
for length in lengths:
57-
q_chunks.append(q[:, offset : offset + length, :])
58-
k_chunks.append(k[:, offset : offset + length, :])
59-
v_chunks.append(v[:, offset : offset + length, :])
56+
q_chunk = mx.expand_dims(q[:, offset : offset + length, :], axis=0)
57+
k_chunk = mx.expand_dims(k[:, offset : offset + length, :], axis=0)
58+
v_chunk = mx.expand_dims(v[:, offset : offset + length, :], axis=0)
6059
offset += length
60+
out = mx.fast.scaled_dot_product_attention(q_chunk, k_chunk, v_chunk, scale=scale)
61+
attn_outputs.append(out.squeeze(0))
6162

62-
# Process each chunk separately
63-
attn_outputs = []
64-
scale = 1.0 / (self.head_dim**0.5)
65-
for q_chunk, k_chunk, v_chunk in zip(q_chunks, k_chunks, v_chunks):
66-
# Compute attention for this chunk
67-
scores = mx.matmul(q_chunk, k_chunk.transpose(0, 2, 1)) * scale
68-
attn_weights = mx.softmax(scores, axis=-1)
69-
attn_chunk = mx.matmul(attn_weights, v_chunk) # [heads, chunk_len, head_dim]
70-
attn_outputs.append(attn_chunk)
71-
72-
# Concatenate chunks back together
7363
attn_output = mx.concatenate(attn_outputs, axis=1) # [heads, seq, head_dim]
7464
else:
7565
# Full attention (no chunking)
76-
scale = 1.0 / (self.head_dim**0.5)
77-
scores = mx.matmul(q, k.transpose(0, 2, 1)) * scale # [heads, seq, seq]
78-
attn_weights = mx.softmax(scores, axis=-1)
79-
attn_output = mx.matmul(attn_weights, v) # [heads, seq, head_dim]
66+
q_4d = mx.expand_dims(q, axis=0)
67+
k_4d = mx.expand_dims(k, axis=0)
68+
v_4d = mx.expand_dims(v, axis=0)
69+
attn_output = mx.fast.scaled_dot_product_attention(q_4d, k_4d, v_4d, scale=scale)
70+
attn_output = attn_output.squeeze(0) # [heads, seq, head_dim]
8071

8172
# Reshape and project
8273
attn_output = attn_output.transpose(1, 0, 2).reshape(seq_len, embed_dim) # [seq, embed_dim]

src/mflux/models/qwen/qwen_initializer.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -58,7 +58,7 @@ def init_edit(
5858
model.tokenizers["qwen_vl"] = QwenVisionLanguageTokenizer(
5959
processor=processor,
6060
max_length=1024,
61-
use_picture_prefix=True,
61+
use_picture_prefix=False,
6262
)
6363
model.qwen_vl_encoder = QwenVisionLanguageEncoder(encoder=model.text_encoder.encoder)
6464

src/mflux/models/qwen/weights/qwen_weight_definition.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -41,7 +41,8 @@ def get_tokenizers() -> List[TokenizerDefinition]:
4141
hf_subdir="tokenizer",
4242
tokenizer_class="Qwen2Tokenizer",
4343
encoder_class=LanguageTokenizer,
44-
max_length=1024,
44+
max_length=1058,
45+
padding="longest",
4546
template="<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n",
4647
download_patterns=["tokenizer/**", "added_tokens.json", "chat_template.jinja"],
4748
),
-4.79 KB
Loading
-24.4 KB
Loading
4.51 KB
Loading
28.1 KB
Loading

0 commit comments

Comments (0)