
Commit e880a24

naykun and yiyixuxu authored and committed

Qwen-Image (huggingface#12055)

* (feat): qwen-image integration
* fix(qwen-image): remove unused logic related to controlnet/ip-adapter
* fix(qwen-image): compatibility with the attention dispatcher; cond cache support
* fix(qwen-image): cond cache registry; attention backend argument; fix copies
* fix(qwen-image): remove local test
* Update src/diffusers/models/transformers/transformer_qwenimage.py

Co-authored-by: YiYi Xu <[email protected]>

1 parent e5d8001 · commit e880a24

File tree

13 files changed: +2950 -0 lines changed


src/diffusers/__init__.py

Lines changed: 6 additions & 0 deletions

@@ -153,6 +153,7 @@
        "AutoencoderKLLTXVideo",
        "AutoencoderKLMagvit",
        "AutoencoderKLMochi",
+       "AutoencoderKLQwenImage",
        "AutoencoderKLTemporalDecoder",
        "AutoencoderKLWan",
        "AutoencoderOobleck",
@@ -194,6 +195,7 @@
        "OmniGenTransformer2DModel",
        "PixArtTransformer2DModel",
        "PriorTransformer",
+       "QwenImageTransformer2DModel",
        "SanaControlNetModel",
        "SanaTransformer2DModel",
        "SD3ControlNetModel",
@@ -443,6 +445,7 @@
        "PixArtAlphaPipeline",
        "PixArtSigmaPAGPipeline",
        "PixArtSigmaPipeline",
+       "QwenImagePipeline",
        "ReduxImageEncoder",
        "SanaControlNetPipeline",
        "SanaPAGPipeline",
@@ -767,6 +770,7 @@
        AutoencoderKLLTXVideo,
        AutoencoderKLMagvit,
        AutoencoderKLMochi,
+       AutoencoderKLQwenImage,
        AutoencoderKLTemporalDecoder,
        AutoencoderKLWan,
        AutoencoderOobleck,
@@ -808,6 +812,7 @@
        OmniGenTransformer2DModel,
        PixArtTransformer2DModel,
        PriorTransformer,
+       QwenImageTransformer2DModel,
        SanaControlNetModel,
        SanaTransformer2DModel,
        SD3ControlNetModel,
@@ -1036,6 +1041,7 @@
        PixArtAlphaPipeline,
        PixArtSigmaPAGPipeline,
        PixArtSigmaPipeline,
+       QwenImagePipeline,
        ReduxImageEncoder,
        SanaControlNetPipeline,
        SanaPAGPipeline,
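
With these entries in place, the new classes are importable straight from the package root. A minimal text-to-image sketch (the "Qwen/Qwen-Image" checkpoint id, dtype, and call arguments are assumptions for illustration, not part of this diff):

import torch
from diffusers import QwenImagePipeline

# Checkpoint id is an assumption; any compatible Qwen-Image checkpoint works.
pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = pipe(prompt="a watercolor fox in a bamboo forest", num_inference_steps=50).images[0]
image.save("qwen_image_sample.png")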

src/diffusers/hooks/_helpers.py

Lines changed: 292 additions & 0 deletions
@@ -0,0 +1,292 @@

# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from dataclasses import dataclass
from typing import Any, Callable, Dict, Type


@dataclass
class AttentionProcessorMetadata:
    skip_processor_output_fn: Callable[[Any], Any]


@dataclass
class TransformerBlockMetadata:
    return_hidden_states_index: int = None
    return_encoder_hidden_states_index: int = None

    _cls: Type = None
    _cached_parameter_indices: Dict[str, int] = None

    def _get_parameter_from_args_kwargs(self, identifier: str, args=(), kwargs=None):
        kwargs = kwargs or {}
        if identifier in kwargs:
            return kwargs[identifier]
        if self._cached_parameter_indices is not None:
            return args[self._cached_parameter_indices[identifier]]
        if self._cls is None:
            raise ValueError("Model class is not set for metadata.")
        parameters = list(inspect.signature(self._cls.forward).parameters.keys())
        parameters = parameters[1:]  # skip `self`
        self._cached_parameter_indices = {param: i for i, param in enumerate(parameters)}
        if identifier not in self._cached_parameter_indices:
            raise ValueError(f"Parameter '{identifier}' not found in function signature but was requested.")
        index = self._cached_parameter_indices[identifier]
        if index >= len(args):
            raise ValueError(f"Expected {index} arguments but got {len(args)}.")
        return args[index]


class AttentionProcessorRegistry:
    _registry = {}
    # TODO(aryan): this is only required for the time being because we need to do the registrations
    # for classes. If we do it eagerly, i.e. call the functions in global scope, we will get circular
    # import errors because of the models imported in this file.
    _is_registered = False

    @classmethod
    def register(cls, model_class: Type, metadata: AttentionProcessorMetadata):
        cls._register()
        cls._registry[model_class] = metadata

    @classmethod
    def get(cls, model_class: Type) -> AttentionProcessorMetadata:
        cls._register()
        if model_class not in cls._registry:
            raise ValueError(f"Model class {model_class} not registered.")
        return cls._registry[model_class]

    @classmethod
    def _register(cls):
        if cls._is_registered:
            return
        cls._is_registered = True
        _register_attention_processors_metadata()


class TransformerBlockRegistry:
    _registry = {}
    # TODO(aryan): this is only required for the time being because we need to do the registrations
    # for classes. If we do it eagerly, i.e. call the functions in global scope, we will get circular
    # import errors because of the models imported in this file.
    _is_registered = False

    @classmethod
    def register(cls, model_class: Type, metadata: TransformerBlockMetadata):
        cls._register()
        metadata._cls = model_class
        cls._registry[model_class] = metadata

    @classmethod
    def get(cls, model_class: Type) -> TransformerBlockMetadata:
        cls._register()
        if model_class not in cls._registry:
            raise ValueError(f"Model class {model_class} not registered.")
        return cls._registry[model_class]

    @classmethod
    def _register(cls):
        if cls._is_registered:
            return
        cls._is_registered = True
        _register_transformer_blocks_metadata()


def _register_attention_processors_metadata():
    from ..models.attention_processor import AttnProcessor2_0
    from ..models.transformers.transformer_cogview4 import CogView4AttnProcessor
    from ..models.transformers.transformer_flux import FluxAttnProcessor
    from ..models.transformers.transformer_wan import WanAttnProcessor2_0

    # AttnProcessor2_0
    AttentionProcessorRegistry.register(
        model_class=AttnProcessor2_0,
        metadata=AttentionProcessorMetadata(
            skip_processor_output_fn=_skip_proc_output_fn_Attention_AttnProcessor2_0,
        ),
    )

    # CogView4AttnProcessor
    AttentionProcessorRegistry.register(
        model_class=CogView4AttnProcessor,
        metadata=AttentionProcessorMetadata(
            skip_processor_output_fn=_skip_proc_output_fn_Attention_CogView4AttnProcessor,
        ),
    )

    # WanAttnProcessor2_0
    AttentionProcessorRegistry.register(
        model_class=WanAttnProcessor2_0,
        metadata=AttentionProcessorMetadata(
            skip_processor_output_fn=_skip_proc_output_fn_Attention_WanAttnProcessor2_0,
        ),
    )

    # FluxAttnProcessor
    AttentionProcessorRegistry.register(
        model_class=FluxAttnProcessor,
        metadata=AttentionProcessorMetadata(skip_processor_output_fn=_skip_proc_output_fn_Attention_FluxAttnProcessor),
    )


def _register_transformer_blocks_metadata():
    from ..models.attention import BasicTransformerBlock
    from ..models.transformers.cogvideox_transformer_3d import CogVideoXBlock
    from ..models.transformers.transformer_cogview4 import CogView4TransformerBlock
    from ..models.transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
    from ..models.transformers.transformer_hunyuan_video import (
        HunyuanVideoSingleTransformerBlock,
        HunyuanVideoTokenReplaceSingleTransformerBlock,
        HunyuanVideoTokenReplaceTransformerBlock,
        HunyuanVideoTransformerBlock,
    )
    from ..models.transformers.transformer_ltx import LTXVideoTransformerBlock
    from ..models.transformers.transformer_mochi import MochiTransformerBlock
    from ..models.transformers.transformer_qwenimage import QwenImageTransformerBlock
    from ..models.transformers.transformer_wan import WanTransformerBlock

    # BasicTransformerBlock
    TransformerBlockRegistry.register(
        model_class=BasicTransformerBlock,
        metadata=TransformerBlockMetadata(
            return_hidden_states_index=0,
            return_encoder_hidden_states_index=None,
        ),
    )

    # CogVideoX
    TransformerBlockRegistry.register(
        model_class=CogVideoXBlock,
        metadata=TransformerBlockMetadata(
            return_hidden_states_index=0,
            return_encoder_hidden_states_index=1,
        ),
    )

    # CogView4
    TransformerBlockRegistry.register(
        model_class=CogView4TransformerBlock,
        metadata=TransformerBlockMetadata(
            return_hidden_states_index=0,
            return_encoder_hidden_states_index=1,
        ),
    )

    # Flux
    TransformerBlockRegistry.register(
        model_class=FluxTransformerBlock,
        metadata=TransformerBlockMetadata(
            return_hidden_states_index=1,
            return_encoder_hidden_states_index=0,
        ),
    )
    TransformerBlockRegistry.register(
        model_class=FluxSingleTransformerBlock,
        metadata=TransformerBlockMetadata(
            return_hidden_states_index=1,
            return_encoder_hidden_states_index=0,
        ),
    )

    # HunyuanVideo
    TransformerBlockRegistry.register(
        model_class=HunyuanVideoTransformerBlock,
        metadata=TransformerBlockMetadata(
            return_hidden_states_index=0,
            return_encoder_hidden_states_index=1,
        ),
    )
    TransformerBlockRegistry.register(
        model_class=HunyuanVideoSingleTransformerBlock,
        metadata=TransformerBlockMetadata(
            return_hidden_states_index=0,
            return_encoder_hidden_states_index=1,
        ),
    )
    TransformerBlockRegistry.register(
        model_class=HunyuanVideoTokenReplaceTransformerBlock,
        metadata=TransformerBlockMetadata(
            return_hidden_states_index=0,
            return_encoder_hidden_states_index=1,
        ),
    )
    TransformerBlockRegistry.register(
        model_class=HunyuanVideoTokenReplaceSingleTransformerBlock,
        metadata=TransformerBlockMetadata(
            return_hidden_states_index=0,
            return_encoder_hidden_states_index=1,
        ),
    )

    # LTXVideo
    TransformerBlockRegistry.register(
        model_class=LTXVideoTransformerBlock,
        metadata=TransformerBlockMetadata(
            return_hidden_states_index=0,
            return_encoder_hidden_states_index=None,
        ),
    )

    # Mochi
    TransformerBlockRegistry.register(
        model_class=MochiTransformerBlock,
        metadata=TransformerBlockMetadata(
            return_hidden_states_index=0,
            return_encoder_hidden_states_index=1,
        ),
    )

    # Wan
    TransformerBlockRegistry.register(
        model_class=WanTransformerBlock,
        metadata=TransformerBlockMetadata(
            return_hidden_states_index=0,
            return_encoder_hidden_states_index=None,
        ),
    )

    # QwenImage
    TransformerBlockRegistry.register(
        model_class=QwenImageTransformerBlock,
        metadata=TransformerBlockMetadata(
            return_hidden_states_index=1,
            return_encoder_hidden_states_index=0,
        ),
    )


# fmt: off
def _skip_attention___ret___hidden_states(self, *args, **kwargs):
    hidden_states = kwargs.get("hidden_states", None)
    if hidden_states is None and len(args) > 0:
        hidden_states = args[0]
    return hidden_states


def _skip_attention___ret___hidden_states___encoder_hidden_states(self, *args, **kwargs):
    hidden_states = kwargs.get("hidden_states", None)
    encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
    if hidden_states is None and len(args) > 0:
        hidden_states = args[0]
    if encoder_hidden_states is None and len(args) > 1:
        encoder_hidden_states = args[1]
    return hidden_states, encoder_hidden_states


_skip_proc_output_fn_Attention_AttnProcessor2_0 = _skip_attention___ret___hidden_states
_skip_proc_output_fn_Attention_CogView4AttnProcessor = _skip_attention___ret___hidden_states___encoder_hidden_states
_skip_proc_output_fn_Attention_WanAttnProcessor2_0 = _skip_attention___ret___hidden_states
# not sure what this is yet.
_skip_proc_output_fn_Attention_FluxAttnProcessor = _skip_attention___ret___hidden_states
# fmt: on
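
Together, these registries let model-agnostic hooks (cond caching, layer skipping, and similar) look up where a block returns hidden_states/encoder_hidden_states and pull named parameters out of positional args. A toy sketch of that flow; MyBlock is hypothetical and not part of the commit, while the registry calls mirror the file above:

import torch

from diffusers.hooks._helpers import TransformerBlockMetadata, TransformerBlockRegistry

# Hypothetical block with a Wan/LTX-style signature (returns hidden_states only).
class MyBlock(torch.nn.Module):
    def forward(self, hidden_states, temb=None):
        return hidden_states * 2.0

TransformerBlockRegistry.register(
    model_class=MyBlock,
    metadata=TransformerBlockMetadata(
        return_hidden_states_index=0,
        return_encoder_hidden_states_index=None,
    ),
)

metadata = TransformerBlockRegistry.get(MyBlock)
args = (torch.ones(1, 4, 8),)
# Resolve "hidden_states" from positional args via the cached signature indices.
hidden_states = metadata._get_parameter_from_args_kwargs("hidden_states", args=args, kwargs={})
print(hidden_states.shape)  # torch.Size([1, 4, 8])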

src/diffusers/models/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -37,6 +37,7 @@
    _import_structure["autoencoders.autoencoder_kl_ltx"] = ["AutoencoderKLLTXVideo"]
    _import_structure["autoencoders.autoencoder_kl_magvit"] = ["AutoencoderKLMagvit"]
    _import_structure["autoencoders.autoencoder_kl_mochi"] = ["AutoencoderKLMochi"]
+   _import_structure["autoencoders.autoencoder_kl_qwenimage"] = ["AutoencoderKLQwenImage"]
    _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
    _import_structure["autoencoders.autoencoder_kl_wan"] = ["AutoencoderKLWan"]
    _import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"]
@@ -87,6 +88,7 @@
    _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
    _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
    _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
+   _import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"]
    _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
    _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
    _import_structure["transformers.transformer_wan"] = ["WanTransformer3DModel"]
@@ -123,6 +125,7 @@
        AutoencoderKLLTXVideo,
        AutoencoderKLMagvit,
        AutoencoderKLMochi,
+       AutoencoderKLQwenImage,
        AutoencoderKLTemporalDecoder,
        AutoencoderKLWan,
        AutoencoderOobleck,
@@ -174,6 +177,7 @@
        OmniGenTransformer2DModel,
        PixArtTransformer2DModel,
        PriorTransformer,
+       QwenImageTransformer2DModel,
        SanaTransformer2DModel,
        SD3Transformer2DModel,
        StableAudioDiTModel,
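
The _import_structure entries feed diffusers' lazy-import machinery (its _LazyModule helper), so heavy model modules are only imported when a name is first accessed. A generic sketch of the same pattern via PEP 562 module-level __getattr__, not diffusers' exact implementation:

import importlib

# Submodule -> exported names, mirroring the _import_structure mapping above.
_import_structure = {"transformers.transformer_qwenimage": ["QwenImageTransformer2DModel"]}

def __getattr__(name):
    # PEP 562: called on `package.Name` when `Name` is not yet bound; import lazily.
    for submodule, names in _import_structure.items():
        if name in names:
            module = importlib.import_module(f".{submodule}", __name__)
            return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")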

src/diffusers/models/autoencoders/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -8,6 +8,7 @@
 from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
 from .autoencoder_kl_magvit import AutoencoderKLMagvit
 from .autoencoder_kl_mochi import AutoencoderKLMochi
+from .autoencoder_kl_qwenimage import AutoencoderKLQwenImage
 from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
 from .autoencoder_kl_wan import AutoencoderKLWan
 from .autoencoder_oobleck import AutoencoderOobleck
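
Once exported, the new VAE loads like any other diffusers model. A minimal loading sketch; the repo id and subfolder layout are assumptions about the published checkpoint, not part of this diff:

import torch
from diffusers import AutoencoderKLQwenImage

# Repo id and subfolder are assumptions about the checkpoint layout.
vae = AutoencoderKLQwenImage.from_pretrained(
    "Qwen/Qwen-Image", subfolder="vae", torch_dtype=torch.float32
)
print(vae.config)  # inspect latent channels / scaling factors before use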

0 commit comments