Skip to content

Commit d7f3241

Browse files
authored
qwen_image: propagate attention mask. (#11966)
1 parent 09a2e67 commit d7f3241

File tree

2 files changed, with 13 additions and 1 deletion.

comfy/ldm/qwen_image/model.py

Lines changed: 10 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -170,8 +170,14 @@ def forward(
         joint_query = apply_rope1(joint_query, image_rotary_emb)
         joint_key = apply_rope1(joint_key, image_rotary_emb)

+        if encoder_hidden_states_mask is not None:
+            attn_mask = torch.zeros((batch_size, 1, seq_txt + seq_img), dtype=hidden_states.dtype, device=hidden_states.device)
+            attn_mask[:, 0, :seq_txt] = encoder_hidden_states_mask
+        else:
+            attn_mask = None
+
         joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
-                                                         attention_mask, transformer_options=transformer_options,
+                                                         attn_mask, transformer_options=transformer_options,
                                                          skip_reshape=True)

         txt_attn_output = joint_hidden_states[:, :seq_txt, :]
@@ -430,6 +436,9 @@ def _forward(
         encoder_hidden_states = context
         encoder_hidden_states_mask = attention_mask

+        if encoder_hidden_states_mask is not None and not torch.is_floating_point(encoder_hidden_states_mask):
+            encoder_hidden_states_mask = (encoder_hidden_states_mask - 1).to(x.dtype) * torch.finfo(x.dtype).max
+
         hidden_states, img_ids, orig_shape = self.process_img(x)
         num_embeds = hidden_states.shape[1]

comfy/model_base.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1578,6 +1578,9 @@ def __init__(self, model_config, model_type=ModelType.FLUX, device=None):

     def extra_conds(self, **kwargs):
         out = super().extra_conds(**kwargs)
+        attention_mask = kwargs.get("attention_mask", None)
+        if attention_mask is not None:
+            out['attention_mask'] = comfy.conds.CONDRegular(attention_mask)
         cross_attn = kwargs.get("cross_attn", None)
         if cross_attn is not None:
             out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)

0 commit comments

Comments
 (0)