
Commit d322b51

fix(qwen-image):
- cond cache registry
- attention backend argument
- fix copies
1 parent f644e9b commit d322b51

File tree

4 files changed: +62 -1 lines changed


src/diffusers/hooks/_helpers.py

Lines changed: 10 additions & 0 deletions
@@ -138,6 +138,7 @@ def _register_transformer_blocks_metadata():
     )
     from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
     from ..models.transformers.transformer_mochi import MochiTransformerBlock
+    from ..models.transformers.transformer_qwenimage import QwenImageTransformerBlock
     from ..models.transformers.transformer_wan import WanTransformerBlock
 
     # BasicTransformerBlock
@@ -240,6 +241,15 @@ def _register_transformer_blocks_metadata():
         ),
     )
 
+    # QwenImage
+    TransformerBlockRegistry.register(
+        model_class=QwenImageTransformerBlock,
+        metadata=TransformerBlockMetadata(
+            return_hidden_states_index=1,
+            return_encoder_hidden_states_index=0,
+        ),
+    )
+
 
 # fmt: off
 def _skip_attention___ret___hidden_states(self, *args, **kwargs):
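For reference, this registry metadata is what the hooks layer (e.g. the conditional/"cond" cache) uses to locate the two streams in a block's return tuple: QwenImageTransformerBlock returns (encoder_hidden_states, hidden_states), hence indices 0 and 1. Below is a minimal, self-contained sketch of how such indices are consumed; the helper and the toy tensors are illustrative only, not part of the diffusers hooks API.

import torch

# Indices copied from the registration above: QwenImageTransformerBlock returns
# (encoder_hidden_states, hidden_states).
RETURN_ENCODER_HIDDEN_STATES_INDEX = 0
RETURN_HIDDEN_STATES_INDEX = 1


def split_block_output(block_output: tuple):
    """Pull the two streams out of a block's return tuple using the registered indices."""
    hidden_states = block_output[RETURN_HIDDEN_STATES_INDEX]
    encoder_hidden_states = block_output[RETURN_ENCODER_HIDDEN_STATES_INDEX]
    return hidden_states, encoder_hidden_states


# Toy output in QwenImage order: (text stream, image stream).
dummy_output = (torch.randn(1, 77, 64), torch.randn(1, 256, 64))
hidden_states, encoder_hidden_states = split_block_output(dummy_output)
assert hidden_states.shape[1] == 256 and encoder_hidden_states.shape[1] == 77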

src/diffusers/models/transformers/transformer_qwenimage.py

Lines changed: 7 additions & 1 deletion
@@ -310,7 +310,13 @@ def __call__(
 
         # Compute joint attention
         joint_hidden_states = dispatch_attention_fn(
-            joint_query, joint_key, joint_value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+            joint_query,
+            joint_key,
+            joint_value,
+            attn_mask=attention_mask,
+            dropout_p=0.0,
+            is_causal=False,
+            backend=self._attention_backend,
         )
 
         # Reshape back
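The only functional change here is threading backend=self._attention_backend through, so a backend selected on the model is respected by the QwenImage attention processor instead of always falling back to the default. As a rough, self-contained illustration of that kind of dispatch (the function and backend names below are illustrative, not the internals of diffusers' dispatch_attention_fn):

from typing import Optional

import torch
import torch.nn.functional as F


def naive_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False):
    # Plain softmax(QK^T / sqrt(d)) V fallback; dropout is ignored for brevity and
    # attn_mask is assumed to be an additive float mask.
    scores = q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5
    if attn_mask is not None:
        scores = scores + attn_mask
    if is_causal:
        causal = torch.ones(scores.shape[-2:], dtype=torch.bool, device=q.device).tril()
        scores = scores.masked_fill(~causal, float("-inf"))
    return scores.softmax(dim=-1) @ v


def dispatch_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False,
                       backend: Optional[str] = None):
    # Route to an implementation based on the requested backend; None means the default path.
    if backend in (None, "sdpa"):
        return F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal
        )
    if backend == "naive":
        return naive_attention(q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
    raise ValueError(f"Unknown attention backend: {backend!r}")


q = k = v = torch.randn(1, 8, 16, 32)  # (batch, heads, seq_len, head_dim)
out_default = dispatch_attention(q, k, v)                    # default (SDPA) path
out_explicit = dispatch_attention(q, k, v, backend="naive")  # explicitly selected backend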

src/diffusers/utils/dummy_pt_objects.py

Lines changed: 30 additions & 0 deletions
@@ -423,6 +423,21 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch"])
 
 
+class AutoencoderKLQwenImage(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class AutoencoderKLTemporalDecoder(metaclass=DummyObject):
     _backends = ["torch"]

@@ -1038,6 +1053,21 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch"])
 
 
+class QwenImageTransformer2DModel(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class SanaControlNetModel(metaclass=DummyObject):
     _backends = ["torch"]
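These dummy classes keep from diffusers imports of AutoencoderKLQwenImage and QwenImageTransformer2DModel working in environments without PyTorch; the error fires only when the class is actually used, and it names the missing backend. A stripped-down sketch of the pattern, with a simplified requires_backends standing in for the diffusers utility:

import importlib.util


def requires_backends(obj, backends):
    # Simplified stand-in for diffusers.utils.requires_backends: raise a readable
    # ImportError naming any backend that is not installed.
    missing = [b for b in backends if importlib.util.find_spec(b) is None]
    if missing:
        name = obj.__name__ if isinstance(obj, type) else type(obj).__name__
        raise ImportError(f"{name} requires the following backend(s): {', '.join(missing)}")


class DummyObject(type):
    # Simplified metaclass: attribute access on the class re-checks the backends,
    # so even Cls.some_method fails loudly when a backend is missing.
    def __getattr__(cls, name):
        requires_backends(cls, cls._backends)


class AutoencoderKLQwenImage(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])


# Importing the name always works; the ImportError appears only on use.
try:
    AutoencoderKLQwenImage()
except ImportError as err:
    print(err)  # e.g. "AutoencoderKLQwenImage requires the following backend(s): torch"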

src/diffusers/utils/dummy_torch_and_transformers_objects.py

Lines changed: 15 additions & 0 deletions
@@ -1682,6 +1682,21 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch", "transformers"])
 
 
+class QwenImagePipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
 class ReduxImageEncoder(metaclass=DummyObject):
     _backends = ["torch", "transformers"]
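With torch and transformers installed, the real pipeline class is exported in place of this dummy. A hedged usage sketch, assuming the standard DiffusionPipeline interface; the checkpoint id below is an assumption for illustration, not taken from this commit.

import torch
from diffusers import QwenImagePipeline

# Checkpoint id is an assumption; substitute the officially published QwenImage weights.
pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = pipe(
    prompt="a watercolor painting of a lighthouse at dawn",
    num_inference_steps=50,
).images[0]
image.save("qwen_image_sample.png")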
