# Copyright 2025 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, BaseOutput, logging, scale_lora_layers, unscale_lora_layers
from ..attention_processor import AttentionProcessor
from ..cache_utils import CacheMixin
from ..controlnets.controlnet import zero_module
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..transformers.transformer_qwenimage import (
    QwenEmbedRope,
    QwenImageTransformerBlock,
    QwenTimestepProjEmbeddings,
    RMSNorm,
)


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@dataclass
class QwenImageControlNetOutput(BaseOutput):
    controlnet_block_samples: Tuple[torch.Tensor]


class QwenImageControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 64,
        out_channels: Optional[int] = 16,
        num_layers: int = 60,
        attention_head_dim: int = 128,
        num_attention_heads: int = 24,
        joint_attention_dim: int = 3584,
        axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
        extra_condition_channels: int = 0,  # for controlnet-inpainting
    ):
        super().__init__()
        self.out_channels = out_channels or in_channels
        self.inner_dim = num_attention_heads * attention_head_dim

        self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)

        self.time_text_embed = QwenTimestepProjEmbeddings(embedding_dim=self.inner_dim)

        self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6)

        self.img_in = nn.Linear(in_channels, self.inner_dim)
        self.txt_in = nn.Linear(joint_attention_dim, self.inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                QwenImageTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                )
                for _ in range(num_layers)
            ]
        )

        # controlnet_blocks
        self.controlnet_blocks = nn.ModuleList([])
        for _ in range(len(self.transformer_blocks)):
            self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
        self.controlnet_x_embedder = zero_module(
            torch.nn.Linear(in_channels + extra_condition_channels, self.inner_dim)
        )

        self.gradient_checkpointing = False
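
        # Note: `controlnet_blocks` and `controlnet_x_embedder` are zero-initialized, so a freshly constructed
        # ControlNet is a no-op: every residual it emits is zero until training moves those weights away from zero.
        # This is also why `from_transformer` below can safely seed the remaining weights from the base model.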

    @property
    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight names.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the
                processor for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)
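
    # Hedged usage sketch for the two helpers above (names such as `controlnet` and the exact key paths are
    # illustrative). Reading the processors and setting them back is a safe round trip; dict keys must cover
    # every attention layer:
    #
    #     procs = controlnet.attn_processors            # e.g. {"transformer_blocks.0.attn.processor": <processor>, ...}
    #     controlnet.set_attn_processor(dict(procs))    # pass a copy, since dict inputs are consumed via .pop()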

    @classmethod
    def from_transformer(
        cls,
        transformer,
        num_layers: int = 5,
        attention_head_dim: int = 128,
        num_attention_heads: int = 24,
        load_weights_from_transformer=True,
        extra_condition_channels: int = 0,
    ):
        config = dict(transformer.config)
        config["num_layers"] = num_layers
        config["attention_head_dim"] = attention_head_dim
        config["num_attention_heads"] = num_attention_heads
        config["extra_condition_channels"] = extra_condition_channels

        controlnet = cls.from_config(config)

        if load_weights_from_transformer:
            controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
            controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
            controlnet.img_in.load_state_dict(transformer.img_in.state_dict())
            controlnet.txt_in.load_state_dict(transformer.txt_in.state_dict())
            controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
            controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder)

        return controlnet
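
    # Hedged usage sketch for `from_transformer` (the base-model class and checkpoint id below are assumptions
    # for illustration, not guaranteed identifiers):
    #
    #     from diffusers import QwenImageTransformer2DModel
    #
    #     transformer = QwenImageTransformer2DModel.from_pretrained("Qwen/Qwen-Image", subfolder="transformer")
    #     controlnet = QwenImageControlNetModel.from_transformer(transformer, num_layers=5)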

    def forward(
        self,
        hidden_states: torch.Tensor,
        controlnet_cond: torch.Tensor,
        conditioning_scale: float = 1.0,
        encoder_hidden_states: torch.Tensor = None,
        encoder_hidden_states_mask: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_shapes: Optional[List[Tuple[int, int, int]]] = None,
        txt_seq_lens: Optional[List[int]] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
        """
        The [`QwenImageControlNetModel`] forward method.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
                Input `hidden_states` (packed image latents).
            controlnet_cond (`torch.Tensor`):
                The packed latents of the control image, of shape
                `(batch_size, image_sequence_length, in_channels + extra_condition_channels)`.
            conditioning_scale (`float`, defaults to `1.0`):
                The scale factor for ControlNet outputs.
            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`):
                Attention mask for the text embeddings.
            timestep (`torch.LongTensor`):
                Used to indicate the denoising step.
            img_shapes (`List[Tuple[int, int, int]]`, *optional*):
                The latent frame, height, and width of each image, used to compute the rotary position embeddings.
            txt_seq_lens (`List[int]`, *optional*):
                The sequence length of each text prompt, used to compute the rotary position embeddings.
            joint_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`QwenImageControlNetOutput`] instead of a plain tuple.

        Returns:
            If `return_dict` is `True`, a [`QwenImageControlNetOutput`] is returned, otherwise a `tuple` whose first
            element is the tuple of ControlNet block samples.
        """
        if joint_attention_kwargs is not None:
            joint_attention_kwargs = joint_attention_kwargs.copy()
            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
                )

        hidden_states = self.img_in(hidden_states)

        # add the (zero-initialized) control-image embedding to the packed image latents
        hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)

        temb = self.time_text_embed(timestep, hidden_states)

        image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)

        timestep = timestep.to(hidden_states.dtype)
        encoder_hidden_states = self.txt_norm(encoder_hidden_states)
        encoder_hidden_states = self.txt_in(encoder_hidden_states)

        block_samples = ()
        for index_block, block in enumerate(self.transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    encoder_hidden_states_mask,
                    temb,
                    image_rotary_emb,
                )
            else:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_hidden_states_mask=encoder_hidden_states_mask,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                    joint_attention_kwargs=joint_attention_kwargs,
                )
            block_samples = block_samples + (hidden_states,)

        # controlnet block
        controlnet_block_samples = ()
        for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks):
            block_sample = controlnet_block(block_sample)
            controlnet_block_samples = controlnet_block_samples + (block_sample,)

        # scaling
        controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
        controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (controlnet_block_samples,)

        return QwenImageControlNetOutput(
            controlnet_block_samples=controlnet_block_samples,
        )
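
    # Hedged usage note (variable names are illustrative): a pipeline typically computes these residuals once per
    # denoising step and hands them to the base transformer, which adds each tensor to the hidden states of the
    # corresponding transformer block, e.g.
    #
    #     out = controlnet(
    #         hidden_states=latents,
    #         controlnet_cond=control_latents,
    #         conditioning_scale=0.7,
    #         encoder_hidden_states=prompt_embeds,
    #         encoder_hidden_states_mask=prompt_embeds_mask,
    #         timestep=timestep,
    #         img_shapes=img_shapes,
    #         txt_seq_lens=txt_seq_lens,
    #     )
    #     noise_pred = transformer(..., controlnet_block_samples=out.controlnet_block_samples)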


class QwenImageMultiControlNetModel(ModelMixin):
    r"""
    `QwenImageMultiControlNetModel` wrapper class for multiple `QwenImageControlNetModel` instances.

    This module is a wrapper for multiple instances of `QwenImageControlNetModel`. The `forward()` API is designed
    to be compatible with `QwenImageControlNetModel`.

    Args:
        controlnets (`List[QwenImageControlNetModel]`):
            Provides additional conditioning to the transformer during the denoising process. You must set multiple
            `QwenImageControlNetModel` instances as a list.
    """

    def __init__(self, controlnets):
        super().__init__()
        self.nets = nn.ModuleList(controlnets)

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        controlnet_cond: List[torch.Tensor],
        conditioning_scale: List[float],
        encoder_hidden_states: torch.Tensor = None,
        encoder_hidden_states_mask: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_shapes: Optional[List[Tuple[int, int, int]]] = None,
        txt_seq_lens: Optional[List[int]] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[QwenImageControlNetOutput, Tuple]:
        # ControlNet-Union with multiple conditions: only a single ControlNet is loaded to save memory,
        # and every condition is run through that one network.
        if len(self.nets) == 1:
            controlnet = self.nets[0]

            for i, (image, scale) in enumerate(zip(controlnet_cond, conditioning_scale)):
                block_samples = controlnet(
                    hidden_states=hidden_states,
                    controlnet_cond=image,
                    conditioning_scale=scale,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_hidden_states_mask=encoder_hidden_states_mask,
                    timestep=timestep,
                    img_shapes=img_shapes,
                    txt_seq_lens=txt_seq_lens,
                    joint_attention_kwargs=joint_attention_kwargs,
                    return_dict=return_dict,
                )

                # merge samples
                if i == 0:
                    control_block_samples = block_samples
                else:
                    if block_samples is not None and control_block_samples is not None:
                        control_block_samples = [
                            control_block_sample + block_sample
                            for control_block_sample, block_sample in zip(control_block_samples, block_samples)
                        ]
        else:
            raise ValueError("QwenImageMultiControlNetModel only supports controlnet-union now.")

        return control_block_samples
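

# Hedged usage sketch for the multi-ControlNet wrapper (variable names are illustrative). Even with multiple
# control images, only a single underlying `QwenImageControlNetModel` (controlnet-union style) is supported:
#
#     multi_controlnet = QwenImageMultiControlNetModel([controlnet])
#     block_samples = multi_controlnet(
#         hidden_states=latents,
#         controlnet_cond=[control_latents_a, control_latents_b],
#         conditioning_scale=[1.0, 0.6],
#         encoder_hidden_states=prompt_embeds,
#         encoder_hidden_states_mask=prompt_embeds_mask,
#         timestep=timestep,
#         img_shapes=img_shapes,
#         txt_seq_lens=txt_seq_lens,
#     )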