@@ -41,6 +41,11 @@ def __init__(
4141 ffn_dim_multiplier : float = (8.0 / 3.0 ),
4242 norm_eps : float = 1e-5 ,
4343 qk_norm : bool = True ,
44+ n_control_layers = 6 ,
45+ control_in_dim = 16 ,
46+ additional_in_dim = 0 ,
47+ broken = False ,
48+ refiner_control = False ,
4449 dtype = None ,
4550 device = None ,
4651 operations = None ,
@@ -49,10 +54,11 @@ def __init__(
4954 super ().__init__ ()
5055 operation_settings = {"operations" : operations , "device" : device , "dtype" : dtype }
5156
52- self .additional_in_dim = 0
53- self .control_in_dim = 16
57+ self .broken = broken
58+ self .additional_in_dim = additional_in_dim
59+ self .control_in_dim = control_in_dim
5460 n_refiner_layers = 2
55- self .n_control_layers = 6
61+ self .n_control_layers = n_control_layers
5662 self .control_layers = nn .ModuleList (
5763 [
5864 ZImageControlTransformerBlock (
@@ -74,28 +80,49 @@ def __init__(
7480 all_x_embedder = {}
7581 patch_size = 2
7682 f_patch_size = 1
77- x_embedder = operations .Linear (f_patch_size * patch_size * patch_size * self .control_in_dim , dim , bias = True , device = device , dtype = dtype )
83+ x_embedder = operations .Linear (f_patch_size * patch_size * patch_size * ( self .control_in_dim + self . additional_in_dim ) , dim , bias = True , device = device , dtype = dtype )
7884 all_x_embedder [f"{ patch_size } -{ f_patch_size } " ] = x_embedder
7985
86+ self .refiner_control = refiner_control
87+
8088 self .control_all_x_embedder = nn .ModuleDict (all_x_embedder )
81- self .control_noise_refiner = nn .ModuleList (
82- [
83- JointTransformerBlock (
84- layer_id ,
85- dim ,
86- n_heads ,
87- n_kv_heads ,
88- multiple_of ,
89- ffn_dim_multiplier ,
90- norm_eps ,
91- qk_norm ,
92- modulation = True ,
93- z_image_modulation = True ,
94- operation_settings = operation_settings ,
95- )
96- for layer_id in range (n_refiner_layers )
97- ]
98- )
89+ if self .refiner_control :
90+ self .control_noise_refiner = nn .ModuleList (
91+ [
92+ ZImageControlTransformerBlock (
93+ layer_id ,
94+ dim ,
95+ n_heads ,
96+ n_kv_heads ,
97+ multiple_of ,
98+ ffn_dim_multiplier ,
99+ norm_eps ,
100+ qk_norm ,
101+ block_id = layer_id ,
102+ operation_settings = operation_settings ,
103+ )
104+ for layer_id in range (n_refiner_layers )
105+ ]
106+ )
107+ else :
108+ self .control_noise_refiner = nn .ModuleList (
109+ [
110+ JointTransformerBlock (
111+ layer_id ,
112+ dim ,
113+ n_heads ,
114+ n_kv_heads ,
115+ multiple_of ,
116+ ffn_dim_multiplier ,
117+ norm_eps ,
118+ qk_norm ,
119+ modulation = True ,
120+ z_image_modulation = True ,
121+ operation_settings = operation_settings ,
122+ )
123+ for layer_id in range (n_refiner_layers )
124+ ]
125+ )
99126
100127 def forward (self , cap_feats , control_context , x_freqs_cis , adaln_input ):
101128 patch_size = 2
@@ -105,9 +132,29 @@ def forward(self, cap_feats, control_context, x_freqs_cis, adaln_input):
105132 control_context = self .control_all_x_embedder [f"{ patch_size } -{ f_patch_size } " ](control_context .view (B , C , H // pH , pH , W // pW , pW ).permute (0 , 2 , 4 , 3 , 5 , 1 ).flatten (3 ).flatten (1 , 2 ))
106133
107134 x_attn_mask = None
108- for layer in self .control_noise_refiner :
109- control_context = layer (control_context , x_attn_mask , x_freqs_cis [:control_context .shape [0 ], :control_context .shape [1 ]], adaln_input )
135+ if not self .refiner_control :
136+ for layer in self .control_noise_refiner :
137+ control_context = layer (control_context , x_attn_mask , x_freqs_cis [:control_context .shape [0 ], :control_context .shape [1 ]], adaln_input )
138+
110139 return control_context
111140
def forward_noise_refiner_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
    """Run the control branch that accompanies one noise-refiner step.

    The path taken depends on two flags stored on the instance:

    * ``refiner_control`` false: nothing is executed and
      ``(None, control_context)`` is returned with the context untouched.
    * ``refiner_control`` true, ``broken`` false: the result of
      ``control_noise_refiner[layer_id]`` is returned as-is.
    * ``refiner_control`` true, ``broken`` true: ``layer_id == 0`` calls
      ``control_layers[0]``; ``layer_id > 0`` runs every entry of
      ``control_layers`` except the first, feeding each returned context
      into the next call, and returns ``(first non-None output, final
      context)``. Any other ``layer_id`` falls through and returns ``None``.

    NOTE(review): ``broken`` presumably reproduces the wiring of an earlier
    faulty implementation so matching weights keep working - confirm.

    Args:
        layer_id: Index of the refiner step being processed.
        control_context: Current control state; must expose ``.shape`` with
            at least two dimensions.
        x: Passed through to the layer as its second positional argument.
        x_attn_mask: Forwarded to the layer as ``x_mask``.
        x_freqs_cis: Indexable by two slices; cropped to the first two
            dimensions of the context before every layer call.
        adaln_input: Forwarded to the layer unchanged.

    Returns:
        Whatever the selected layer returns (unpacked as a 2-tuple in the
        looped path), ``(None, control_context)`` when refiner control is
        off, or ``None`` in the fall-through case described above.
    """

    def _apply(block, ctx):
        # The crop is recomputed per call because ctx is replaced between
        # iterations of the looped path.
        return block(ctx, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:ctx.shape[0], :ctx.shape[1]], adaln_input=adaln_input)

    if not self.refiner_control:
        return (None, control_context)

    if not self.broken:
        return _apply(self.control_noise_refiner[layer_id], control_context)

    if layer_id == 0:
        return _apply(self.control_layers[layer_id], control_context)

    if layer_id > 0:
        first_output = None
        remaining = iter(self.control_layers)
        next(remaining, None)  # the first control layer is handled by layer_id == 0
        for block in remaining:
            output, control_context = _apply(block, control_context)
            if first_output is None:
                first_output = output
        return (first_output, control_context)

    # Broken mode with a layer_id that is neither 0 nor positive: no result.
    return None
158+
def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
    """Apply the control layer at ``layer_id`` and return its result.

    Args:
        layer_id: Index into ``self.control_layers``.
        control_context: Current control state; must expose ``.shape`` with
            at least two dimensions.
        x: Passed to the layer as its second positional argument.
        x_attn_mask: Forwarded to the layer as ``x_mask``.
        x_freqs_cis: Indexable by two slices; cropped to the first two
            dimensions of ``control_context`` before the call.
        adaln_input: Forwarded to the layer unchanged.

    Returns:
        Whatever the selected control layer returns.
    """
    block = self.control_layers[layer_id]
    dim0, dim1 = control_context.shape[0], control_context.shape[1]
    cropped_freqs = x_freqs_cis[:dim0, :dim1]
    return block(
        control_context,
        x,
        x_mask=x_attn_mask,
        freqs_cis=cropped_freqs,
        adaln_input=adaln_input,
    )
0 commit comments