 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
-from ...models.normalization import RMSNorm
 from ..controlnets.controlnet import zero_module
 from ..modeling_utils import ModelMixin
-from ..transformers.transformer_z_image import ZImageTransformer2DModel, ZImageTransformerBlock, RopeEmbedder, TimestepEmbedder, SEQ_MULTI_OF, ADALN_EMBED_DIM
+from ..transformers.transformer_z_image import (
+    SEQ_MULTI_OF,
+    ZImageTransformer2DModel,
+    ZImageTransformerBlock,
+)
 
 
 class ZImageControlTransformerBlock(ZImageTransformerBlock):
     def __init__(
-        self,
+        self,
         layer_id: int,
         dim: int,
         n_heads: int,
         n_kv_heads: int,
         norm_eps: float,
         qk_norm: bool,
         modulation=True,
-        block_id=0
+        block_id=0,
     ):
         super().__init__(layer_id, dim, n_heads, n_kv_heads, norm_eps, qk_norm, modulation)
         self.block_id = block_id
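Note on the hunk above: the control block subclasses the base transformer block, forwards the constructor arguments, and only adds a `block_id` marking its slot in the control stack. A minimal sketch of the same pattern with a stand-in parent (the names here are illustrative, not the real `ZImageTransformerBlock` API):

```python
import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    """Stand-in for ZImageTransformerBlock (illustrative only)."""

    def __init__(self, layer_id: int, dim: int):
        super().__init__()
        self.layer_id = layer_id
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.proj(x)


class ToyControlBlock(ToyBlock):
    """Subclass pattern from the diff: forward ctor args, record a block_id."""

    def __init__(self, layer_id: int, dim: int, block_id: int = 0):
        super().__init__(layer_id, dim)
        self.block_id = block_id  # the block's position in the control stack
```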
@@ -57,7 +60,8 @@ def forward(self, c: torch.Tensor, x: torch.Tensor, **kwargs):
         all_c += [c_skip, c]
         c = torch.stack(all_c)
         return c
-
+
+
 class ZImageControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
     _supports_gradient_checkpointing = True
 
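The forward tail in this hunk (`all_c += [c_skip, c]`, `torch.stack`) implements a stack-and-carry scheme: each control block appends a projected skip plus the running state into one stacked tensor, which the next block unbinds. A self-contained sketch of that flow, assuming a `before_proj`/`after_proj` pair and a first-block branch that are not shown in this hunk:

```python
import torch
import torch.nn as nn

dim = 8
before_proj = nn.Linear(dim, dim)  # assumed, not visible in this hunk
after_proj = nn.Linear(dim, dim)   # assumed, not visible in this hunk


def control_step(c: torch.Tensor, x: torch.Tensor, first: bool) -> torch.Tensor:
    if first:
        c = before_proj(c) + x         # fold hidden states into the control stream
        all_c = []
    else:
        all_c = list(torch.unbind(c))  # previous skips + running state
        c = all_c.pop(-1)              # running state is always last
    c_skip = after_proj(c)
    all_c += [c_skip, c]               # as in the hunk above
    return torch.stack(all_c)          # repack for the next control block


x = torch.randn(2, 16, dim)
c = control_step(x.clone(), x, first=True)
c = control_step(c, x, first=False)
hints = torch.unbind(c)[:-1]           # two skips; the running state is dropped
```

This is consistent with the last hunk's consumption of `c`: everything on the stack except the final running state is a per-layer hint.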
@@ -72,7 +76,7 @@ def __init__(
         n_kv_heads=30,
         norm_eps=1e-5,
         qk_norm=True,
-        control_layers_places: List[int]=None,
+        control_layers_places: List[int] = None,
         control_in_dim=None,
     ):
         super().__init__()
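A side note on `control_layers_places: List[int] = None`: using `None` rather than a list literal keeps the default immutable, with the concrete list presumably resolved inside `__init__` (not shown here). Strictly the annotation would be `Optional[List[int]]`; the usual resolution pattern looks like:

```python
from typing import List, Optional


def resolve_places(control_layers_places: Optional[List[int]] = None) -> List[int]:
    # A mutable default ([...]) would be shared across calls; None is not.
    if control_layers_places is None:
        control_layers_places = []
    return control_layers_places
```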
@@ -84,15 +88,7 @@ def __init__(
         # control blocks
         self.control_layers = nn.ModuleList(
             [
-                ZImageControlTransformerBlock(
-                    i,
-                    dim,
-                    n_heads,
-                    n_kv_heads,
-                    norm_eps,
-                    qk_norm,
-                    block_id=i
-                )
+                ZImageControlTransformerBlock(i, dim, n_heads, n_kv_heads, norm_eps, qk_norm, block_id=i)
                 for i in self.control_layers_places
             ]
         )
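The collapsed constructor above builds one control block per configured layer index, with `block_id=i` tying each block to the transformer layer it attaches to. A toy equivalent of the construction (stand-in module, illustrative only):

```python
import torch.nn as nn


class ToyControlBlock(nn.Module):
    def __init__(self, block_id: int):
        super().__init__()
        self.block_id = block_id
        self.proj = nn.Linear(8, 8)


control_layers_places = [0, 5, 10]
control_layers = nn.ModuleList(
    [ToyControlBlock(i) for i in control_layers_places]
)
# each block remembers the transformer layer index it hooks into
assert [b.block_id for b in control_layers] == [0, 5, 10]
```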
@@ -425,7 +421,9 @@ def forward(
 
         if torch.is_grad_enabled() and self.gradient_checkpointing:
             for layer in self.control_noise_refiner:
-                control_context = self._gradient_checkpointing_func(layer, control_context, x_attn_mask, x_freqs_cis, adaln_input)
+                control_context = self._gradient_checkpointing_func(
+                    layer, control_context, x_attn_mask, x_freqs_cis, adaln_input
+                )
         else:
             for layer in self.control_noise_refiner:
                 control_context = layer(control_context, x_attn_mask, x_freqs_cis, adaln_input)
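The branch above trades memory for compute: under gradient checkpointing, each refiner layer's activations are recomputed during backward instead of stored. In diffusers, `self._gradient_checkpointing_func` is typically a wrapper around `torch.utils.checkpoint.checkpoint`; a minimal standalone sketch of the same branch structure:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

layers = nn.ModuleList(nn.Linear(8, 8) for _ in range(3))
gradient_checkpointing = True

h = torch.randn(2, 8, requires_grad=True)
if torch.is_grad_enabled() and gradient_checkpointing:
    for layer in layers:
        # recompute layer(h) during backward instead of caching activations
        h = checkpoint(layer, h, use_reentrant=False)
else:
    for layer in layers:
        h = layer(h)
h.sum().backward()
```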
@@ -440,14 +438,14 @@ def forward(
         control_context_unified = pad_sequence(control_context_unified, batch_first=True, padding_value=0.0)
         c = control_context_unified
 
-        new_kwargs = dict(x=unified, attn_mask=unified_attn_mask, freqs_cis=unified_freqs_cis, adaln_input=adaln_input)
-
+        new_kwargs = {"x": unified, "attn_mask": unified_attn_mask, "freqs_cis": unified_freqs_cis, "adaln_input": adaln_input}
+
         for layer in self.control_layers:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 c = self._gradient_checkpointing_func(layer, c, **new_kwargs)
             else:
                 c = layer(c, **new_kwargs)
-
+
         hints = [h * conditioning_scale for h in torch.unbind(c)[:-1]]
         controlnet_block_samples = {}
         for layer_idx in range(self.n_layers):
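As written, multiplying the tuple from `torch.unbind` by a float would raise a `TypeError`, hence the per-hint scaling above. The loop body over `self.n_layers` falls outside this hunk, but given `control_layers_places` and the per-place hints it presumably builds a layer-index-to-hint mapping, with `None` for uncontrolled layers. A hypothetical sketch of that bookkeeping; the dict layout is an assumption, not code from this commit:

```python
# Hypothetical reconstruction; the real loop body is not part of this diff.
control_layers_places = [0, 5, 10]       # layers that have control blocks
hints = ["hint_0", "hint_5", "hint_10"]  # stand-ins for the unbound tensors
n_layers = 12

controlnet_block_samples = {}
for layer_idx in range(n_layers):
    if layer_idx in control_layers_places:
        controlnet_block_samples[layer_idx] = hints[control_layers_places.index(layer_idx)]
    else:
        controlnet_block_samples[layer_idx] = None
```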