 from ...models.normalization import RMSNorm
 from ..controlnets.controlnet import zero_module
 from ..modeling_utils import ModelMixin
-from ..transformers.transformer_z_image import ZImageTransformerBlock, RopeEmbedder, TimestepEmbedder, SEQ_MULTI_OF, ADALN_EMBED_DIM
+from ..transformers.transformer_z_image import ZImageTransformer2DModel, ZImageTransformerBlock, RopeEmbedder, TimestepEmbedder, SEQ_MULTI_OF, ADALN_EMBED_DIM
 
 
 class ZImageControlTransformerBlock(ZImageTransformerBlock):
@@ -66,87 +66,16 @@ def __init__(
         self,
         all_patch_size=(2,),
         all_f_patch_size=(1,),
-        in_channels=16,
         dim=3840,
-        n_layers=30,
         n_refiner_layers=2,
         n_heads=30,
         n_kv_heads=30,
         norm_eps=1e-5,
         qk_norm=True,
-        cap_feat_dim=2560,
-        rope_theta=256.0,
-        t_scale=1000.0,
-        axes_dims=[32, 48, 48],
-        axes_lens=[1024, 512, 512],
         control_layers_places: List[int] = None,
         control_in_dim=None,
     ):
         super().__init__()
-        self.in_channels = in_channels
-        self.out_channels = in_channels
-        self.all_patch_size = all_patch_size
-        self.all_f_patch_size = all_f_patch_size
-        self.dim = dim
-        self.n_heads = n_heads
-
-        self.rope_theta = rope_theta
-        self.t_scale = t_scale
-        self.gradient_checkpointing = False
-        self.n_layers = n_layers
-
-        assert len(all_patch_size) == len(all_f_patch_size)
-
-        all_x_embedder = {}
-        for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)):
-            x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True)
-            all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
-
-        self.all_x_embedder = nn.ModuleDict(all_x_embedder)
-        self.noise_refiner = nn.ModuleList(
-            [
-                ZImageTransformerBlock(
-                    1000 + layer_id,
-                    dim,
-                    n_heads,
-                    n_kv_heads,
-                    norm_eps,
-                    qk_norm,
-                    modulation=True,
-                )
-                for layer_id in range(n_refiner_layers)
-            ]
-        )
-        self.context_refiner = nn.ModuleList(
-            [
-                ZImageTransformerBlock(
-                    layer_id,
-                    dim,
-                    n_heads,
-                    n_kv_heads,
-                    norm_eps,
-                    qk_norm,
-                    modulation=False,
-                )
-                for layer_id in range(n_refiner_layers)
-            ]
-        )
-        self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024)
-        self.cap_embedder = nn.Sequential(
-            RMSNorm(cap_feat_dim, eps=norm_eps),
-            nn.Linear(cap_feat_dim, dim, bias=True),
-        )
-
-        self.x_pad_token = nn.Parameter(torch.empty((1, dim)))
-        self.cap_pad_token = nn.Parameter(torch.empty((1, dim)))
-
-        self.axes_dims = axes_dims
-        self.axes_lens = axes_lens
-
-        self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens)
-
-        ## Original Control layers
-
         self.control_layers_places = control_layers_places
         self.control_in_dim = control_in_dim
 
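The constructor above now keeps only control-specific state (`control_layers_places`, `control_in_dim`, and the control blocks defined further down in the file); every submodule that also exists on the base `ZImageTransformer2DModel` (patch embedders, refiner blocks, timestep and caption embedders, RoPE embedder, pad tokens) is borrowed from that model at call time rather than re-created here. Below is a minimal, self-contained sketch of this delegation pattern; `BaseModel` and `ControlModule` are hypothetical stand-ins, not the diffusers classes touched by this diff.

# Schematic of the delegation pattern only; names and sizes are illustrative.
import torch
from torch import nn


class BaseModel(nn.Module):
    def __init__(self, dim: int = 8):
        super().__init__()
        # Shared submodule lives only on the base model.
        self.embed = nn.Linear(dim, dim)


class ControlModule(nn.Module):
    def __init__(self, dim: int = 8, control_in_dim: int = 8):
        super().__init__()
        # Only control-specific weights are owned here.
        self.control_proj = nn.Linear(control_in_dim, dim)

    def forward(self, base: BaseModel, x: torch.Tensor, control: torch.Tensor) -> torch.Tensor:
        # Reuse the base model's embedder at call time instead of duplicating it.
        hidden = base.embed(x)
        return hidden + self.control_proj(control)


base = BaseModel()
ctrl = ControlModule()
out = ctrl(base, torch.randn(2, 8), torch.randn(2, 8))  # shape: (2, 8)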
@@ -366,6 +295,7 @@ def patchify_and_embed(
 
     def forward(
         self,
+        transformer: ZImageTransformer2DModel,
         x: List[torch.Tensor],
         cap_feats: List[torch.Tensor],
         control_context: List[torch.Tensor],
@@ -380,7 +310,7 @@ def forward(
         bsz = len(x)
         device = x[0].device
         t = t * self.t_scale
-        t = self.t_embedder(t)
+        t = transformer.t_embedder(t)
 
         (
             x,
@@ -398,13 +328,13 @@ def forward(
         x_max_item_seqlen = max(x_item_seqlens)
 
         x = torch.cat(x, dim=0)
-        x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)
+        x = transformer.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)
 
         # Match t_embedder output dtype to x for layerwise casting compatibility
         adaln_input = t.type_as(x)
-        x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
+        x[torch.cat(x_inner_pad_mask)] = transformer.x_pad_token
         x = list(x.split(x_item_seqlens, dim=0))
-        x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
+        x_freqs_cis = list(transformer.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
 
         x = pad_sequence(x, batch_first=True, padding_value=0.0)
         x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
@@ -413,10 +343,10 @@ def forward(
             x_attn_mask[i, :seq_len] = 1
 
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-            for layer in self.noise_refiner:
+            for layer in transformer.noise_refiner:
                 x = self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cis, adaln_input)
         else:
-            for layer in self.noise_refiner:
+            for layer in transformer.noise_refiner:
                 x = layer(x, x_attn_mask, x_freqs_cis, adaln_input)
 
         # cap embed & refine
@@ -425,10 +355,10 @@ def forward(
         cap_max_item_seqlen = max(cap_item_seqlens)
 
         cap_feats = torch.cat(cap_feats, dim=0)
-        cap_feats = self.cap_embedder(cap_feats)
-        cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
+        cap_feats = transformer.cap_embedder(cap_feats)
+        cap_feats[torch.cat(cap_inner_pad_mask)] = transformer.cap_pad_token
         cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
-        cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0))
+        cap_freqs_cis = list(transformer.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0))
 
         cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
         cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
@@ -437,10 +367,10 @@ def forward(
             cap_attn_mask[i, :seq_len] = 1
 
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-            for layer in self.context_refiner:
+            for layer in transformer.context_refiner:
                 cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cis)
         else:
-            for layer in self.context_refiner:
+            for layer in transformer.context_refiner:
                 cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis)
 
         # unified
@@ -485,7 +415,7 @@ def forward(
         adaln_input = t.type_as(control_context)
         control_context[torch.cat(x_inner_pad_mask)] = self.x_pad_token
         control_context = list(control_context.split(x_item_seqlens, dim=0))
-        x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
+        x_freqs_cis = list(transformer.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
 
         control_context = pad_sequence(control_context, batch_first=True, padding_value=0.0)
         x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
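Taken together, these hunks thread the base `ZImageTransformer2DModel` into the control module's `forward` as its first argument, so its `t_embedder`, `all_x_embedder`, `cap_embedder`, `rope_embedder`, pad tokens, and refiner blocks are reused rather than duplicated. A hedged call sketch follows; it assumes `transformer` and `controlnet` were constructed elsewhere, and the tensor shapes and the positional timestep argument after `control_context` are illustrative assumptions, not taken from this diff.

# Hypothetical call sketch. Assumes `transformer` is an instantiated
# ZImageTransformer2DModel and `controlnet` is the control module defined in this file.
import torch

latents = [torch.randn(4096, 64)]          # one packed latent sequence per sample (size assumed from the former defaults: 2x2x1 patches of 16 channels)
cap_feats = [torch.randn(77, 2560)]        # one caption-feature sequence per sample (cap_feat_dim assumed 2560, the former default)
control_context = [torch.randn(4096, 64)]  # one control-token sequence per sample
t = torch.rand(1)                          # timestep in [0, 1]; scaled inside forward by self.t_scale

control_out = controlnet(
    transformer,      # base model passed in so shared embedders and refiners are reused
    latents,
    cap_feats,
    control_context,
    t,
)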