
Commit b7eecf0: support qwen-image-cn-union

1 parent 8c628eb

File tree

7 files changed: +1254 -0 lines changed

src/diffusers/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -218,6 +218,8 @@
             "OmniGenTransformer2DModel",
             "PixArtTransformer2DModel",
             "PriorTransformer",
+            "QwenImageControlNetModel",
+            "QwenImageMultiControlNetModel",
             "QwenImageTransformer2DModel",
             "SanaControlNetModel",
             "SanaTransformer2DModel",
@@ -885,6 +887,8 @@
         OmniGenTransformer2DModel,
         PixArtTransformer2DModel,
         PriorTransformer,
+        QwenImageControlNetModel,
+        QwenImageMultiControlNetModel,
         QwenImageTransformer2DModel,
         SanaControlNetModel,
         SanaTransformer2DModel,
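
With these exports in place, both new classes become available from the package root. A quick check (not part of the diff, assuming diffusers is installed from this revision):

from diffusers import QwenImageControlNetModel, QwenImageMultiControlNetModel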

src/diffusers/models/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -52,6 +52,7 @@
         "HunyuanDiT2DControlNetModel",
         "HunyuanDiT2DMultiControlNetModel",
     ]
+    _import_structure["controlnets.controlnet_qwenimage"] = ["QwenImageControlNetModel", "QwenImageMultiControlNetModel"]
     _import_structure["controlnets.controlnet_sana"] = ["SanaControlNetModel"]
     _import_structure["controlnets.controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
     _import_structure["controlnets.controlnet_sparsectrl"] = ["SparseControlNetModel"]

src/diffusers/models/controlnets/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -9,6 +9,7 @@
         HunyuanDiT2DControlNetModel,
         HunyuanDiT2DMultiControlNetModel,
     )
+    from .controlnet_qwenimage import QwenImageControlNetModel, QwenImageMultiControlNetModel
     from .controlnet_sana import SanaControlNetModel
     from .controlnet_sd3 import SD3ControlNetModel, SD3ControlNetOutput, SD3MultiControlNetModel
     from .controlnet_sparsectrl import (

src/diffusers/models/controlnets/controlnet_qwenimage.py

Lines changed: 352 additions & 0 deletions

@@ -0,0 +1,352 @@
# Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
from ..attention_processor import AttentionProcessor
from ..cache_utils import CacheMixin
from ..controlnets.controlnet import zero_module
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..transformers.transformer_qwenimage import (
    QwenEmbedRope,
    QwenImageTransformerBlock,
    QwenTimestepProjEmbeddings,
    RMSNorm,
)


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class QwenImageControlNetOutput(BaseOutput):
    controlnet_block_samples: Tuple[torch.Tensor]


class QwenImageControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 64,
        out_channels: Optional[int] = 16,
        num_layers: int = 60,
        attention_head_dim: int = 128,
        num_attention_heads: int = 24,
        joint_attention_dim: int = 3584,
        axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
        extra_condition_channels: int = 0,  # for controlnet-inpainting
    ):
        super().__init__()
        self.out_channels = out_channels or in_channels
        self.inner_dim = num_attention_heads * attention_head_dim

        self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)

        self.time_text_embed = QwenTimestepProjEmbeddings(embedding_dim=self.inner_dim)

        self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6)

        self.img_in = nn.Linear(in_channels, self.inner_dim)
        self.txt_in = nn.Linear(joint_attention_dim, self.inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                QwenImageTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                )
                for _ in range(num_layers)
            ]
        )

        # zero-initialized ControlNet output projections, one per transformer block
        self.controlnet_blocks = nn.ModuleList([])
        for _ in range(len(self.transformer_blocks)):
            self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
        self.controlnet_x_embedder = zero_module(
            torch.nn.Linear(in_channels + extra_condition_channels, self.inner_dim)
        )

        self.gradient_checkpointing = False

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self):
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight names.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the
                processor for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    @classmethod
    def from_transformer(
        cls,
        transformer,
        num_layers: int = 5,
        attention_head_dim: int = 128,
        num_attention_heads: int = 24,
        load_weights_from_transformer=True,
        extra_condition_channels: int = 0,
    ):
        config = dict(transformer.config)
        config["num_layers"] = num_layers
        config["attention_head_dim"] = attention_head_dim
        config["num_attention_heads"] = num_attention_heads
        config["extra_condition_channels"] = extra_condition_channels

        controlnet = cls.from_config(config)

        if load_weights_from_transformer:
            controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
            controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
            controlnet.img_in.load_state_dict(transformer.img_in.state_dict())
            controlnet.txt_in.load_state_dict(transformer.txt_in.state_dict())
            controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
            controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder)

        return controlnet

    def forward(
        self,
        hidden_states: torch.Tensor,
        controlnet_cond: torch.Tensor,
        conditioning_scale: float = 1.0,
        encoder_hidden_states: torch.Tensor = None,
        encoder_hidden_states_mask: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_shapes: Optional[List[Tuple[int, int, int]]] = None,
        txt_seq_lens: Optional[List[int]] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
        """
        The [`QwenImageControlNetModel`] forward method.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
                Input `hidden_states` (packed image latents).
            controlnet_cond (`torch.Tensor`):
                The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
            conditioning_scale (`float`, defaults to `1.0`):
                The scale factor for ControlNet outputs.
            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, sequence_len, embed_dims)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            encoder_hidden_states_mask (`torch.Tensor`, *optional*):
                Attention mask for `encoder_hidden_states`.
            timestep (`torch.LongTensor`):
                Used to indicate the denoising step.
            img_shapes (`list` of `tuple`, *optional*):
                The shapes of the packed image latents, used to compute the rotary position embeddings.
            txt_seq_lens (`list` of `int`, *optional*):
                The text sequence lengths, used to compute the rotary position embeddings.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`QwenImageControlNetOutput`] instead of a plain tuple.

        Returns:
            If `return_dict` is True, a [`QwenImageControlNetOutput`] is returned, otherwise the ControlNet block
            samples are returned directly.
        """
        if joint_attention_kwargs is not None:
            joint_attention_kwargs = joint_attention_kwargs.copy()
            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
                )
        hidden_states = self.img_in(hidden_states)

        # add the ControlNet condition embedding to the image stream
        hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)

        temb = self.time_text_embed(timestep, hidden_states)

        image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)

        timestep = timestep.to(hidden_states.dtype)
        encoder_hidden_states = self.txt_norm(encoder_hidden_states)
        encoder_hidden_states = self.txt_in(encoder_hidden_states)

        block_samples = ()
        for index_block, block in enumerate(self.transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    encoder_hidden_states_mask,
                    temb,
                    image_rotary_emb,
                )

            else:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_hidden_states_mask=encoder_hidden_states_mask,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                    joint_attention_kwargs=joint_attention_kwargs,
                )
            block_samples = block_samples + (hidden_states,)

        # project each block's hidden states through its (zero-initialized) controlnet block
        controlnet_block_samples = ()
        for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks):
            block_sample = controlnet_block(block_sample)
            controlnet_block_samples = controlnet_block_samples + (block_sample,)

        # scaling
        controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
        controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (controlnet_block_samples)

        return QwenImageControlNetOutput(
            controlnet_block_samples=controlnet_block_samples,
        )


class QwenImageMultiControlNetModel(ModelMixin):
    r"""
    `QwenImageMultiControlNetModel` wrapper class for Multi-QwenImageControlNetModel

    This module is a wrapper for multiple instances of `QwenImageControlNetModel`. The `forward()` API is designed
    to be compatible with `QwenImageControlNetModel`.

    Args:
        controlnets (`List[QwenImageControlNetModel]`):
            Provides additional conditioning to the transformer during the denoising process. You must set multiple
            `QwenImageControlNetModel` as a list.
    """

    def __init__(self, controlnets):
        super().__init__()
        self.nets = nn.ModuleList(controlnets)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        controlnet_cond: List[torch.Tensor],
        conditioning_scale: List[float],
        encoder_hidden_states: torch.Tensor = None,
        encoder_hidden_states_mask: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_shapes: Optional[List[Tuple[int, int, int]]] = None,
        txt_seq_lens: Optional[List[int]] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[QwenImageControlNetOutput, Tuple]:
        # ControlNet-Union with multiple conditions
        # only load one ControlNet to save memory
        if len(self.nets) == 1:
            controlnet = self.nets[0]

            for i, (image, scale) in enumerate(zip(controlnet_cond, conditioning_scale)):
                block_samples = controlnet(
                    hidden_states=hidden_states,
                    controlnet_cond=image,
                    conditioning_scale=scale,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_hidden_states_mask=encoder_hidden_states_mask,
                    timestep=timestep,
                    img_shapes=img_shapes,
                    txt_seq_lens=txt_seq_lens,
                    joint_attention_kwargs=joint_attention_kwargs,
                    return_dict=return_dict,
                )

                # merge samples
                if i == 0:
                    control_block_samples = block_samples
                else:
                    if block_samples is not None and control_block_samples is not None:
                        control_block_samples = [
                            control_block_sample + block_sample
                            for control_block_sample, block_sample in zip(control_block_samples, block_samples)
                        ]
        else:
            raise ValueError("QwenImageMultiControlNetModel only supports controlnet-union now.")

        return control_block_samples
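
For orientation, a minimal usage sketch (not part of this commit; the `transformer` variable is assumed to be an already loaded `QwenImageTransformer2DModel`): `from_transformer` copies the base config, overrides the block count, reuses the weights of the first `num_layers` transformer blocks, and zero-initializes the condition embedder, while `QwenImageMultiControlNetModel` wraps the result so the pipeline-facing `forward()` can take lists of conditions and scales.

from diffusers import QwenImageControlNetModel, QwenImageMultiControlNetModel

# Derive a 5-block ControlNet from the base transformer; the weights of the first
# 5 transformer blocks are copied over and controlnet_x_embedder starts zeroed.
controlnet = QwenImageControlNetModel.from_transformer(
    transformer,
    num_layers=5,
    extra_condition_channels=0,  # > 0 would add channels for inpainting-style conditions
)

# Wrap it for the multi-ControlNet interface: with a single "union" ControlNet, every
# (condition, scale) pair passed to forward() is routed through this one network and
# the resulting block samples are summed.
multi_controlnet = QwenImageMultiControlNetModel([controlnet])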
