Commit 1e2009d

passed transformer
1 parent 2354fda commit 1e2009d

2 files changed (+15, -84 lines)

src/diffusers/models/controlnets/controlnet_z_image.py

Lines changed: 14 additions & 84 deletions
@@ -23,7 +23,7 @@
 from ...models.normalization import RMSNorm
 from ..controlnets.controlnet import zero_module
 from ..modeling_utils import ModelMixin
-from ..transformers.transformer_z_image import ZImageTransformerBlock, RopeEmbedder, TimestepEmbedder, SEQ_MULTI_OF, ADALN_EMBED_DIM
+from ..transformers.transformer_z_image import ZImageTransformer2DModel, ZImageTransformerBlock, RopeEmbedder, TimestepEmbedder, SEQ_MULTI_OF, ADALN_EMBED_DIM
 
 
 class ZImageControlTransformerBlock(ZImageTransformerBlock):
@@ -66,87 +66,16 @@ def __init__(
         self,
         all_patch_size=(2,),
         all_f_patch_size=(1,),
-        in_channels=16,
         dim=3840,
-        n_layers=30,
         n_refiner_layers=2,
         n_heads=30,
         n_kv_heads=30,
         norm_eps=1e-5,
         qk_norm=True,
-        cap_feat_dim=2560,
-        rope_theta=256.0,
-        t_scale=1000.0,
-        axes_dims=[32, 48, 48],
-        axes_lens=[1024, 512, 512],
         control_layers_places: List[int]=None,
         control_in_dim=None,
     ):
         super().__init__()
-        self.in_channels = in_channels
-        self.out_channels = in_channels
-        self.all_patch_size = all_patch_size
-        self.all_f_patch_size = all_f_patch_size
-        self.dim = dim
-        self.n_heads = n_heads
-
-        self.rope_theta = rope_theta
-        self.t_scale = t_scale
-        self.gradient_checkpointing = False
-        self.n_layers = n_layers
-
-        assert len(all_patch_size) == len(all_f_patch_size)
-
-        all_x_embedder = {}
-        for patch_idx, (patch_size, f_patch_size) in enumerate(zip(all_patch_size, all_f_patch_size)):
-            x_embedder = nn.Linear(f_patch_size * patch_size * patch_size * in_channels, dim, bias=True)
-            all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
-
-        self.all_x_embedder = nn.ModuleDict(all_x_embedder)
-        self.noise_refiner = nn.ModuleList(
-            [
-                ZImageTransformerBlock(
-                    1000 + layer_id,
-                    dim,
-                    n_heads,
-                    n_kv_heads,
-                    norm_eps,
-                    qk_norm,
-                    modulation=True,
-                )
-                for layer_id in range(n_refiner_layers)
-            ]
-        )
-        self.context_refiner = nn.ModuleList(
-            [
-                ZImageTransformerBlock(
-                    layer_id,
-                    dim,
-                    n_heads,
-                    n_kv_heads,
-                    norm_eps,
-                    qk_norm,
-                    modulation=False,
-                )
-                for layer_id in range(n_refiner_layers)
-            ]
-        )
-        self.t_embedder = TimestepEmbedder(min(dim, ADALN_EMBED_DIM), mid_size=1024)
-        self.cap_embedder = nn.Sequential(
-            RMSNorm(cap_feat_dim, eps=norm_eps),
-            nn.Linear(cap_feat_dim, dim, bias=True),
-        )
-
-        self.x_pad_token = nn.Parameter(torch.empty((1, dim)))
-        self.cap_pad_token = nn.Parameter(torch.empty((1, dim)))
-
-        self.axes_dims = axes_dims
-        self.axes_lens = axes_lens
-
-        self.rope_embedder = RopeEmbedder(theta=rope_theta, axes_dims=axes_dims, axes_lens=axes_lens)
-
-        ## Original Control layers
-
         self.control_layers_places = control_layers_places
         self.control_in_dim = control_in_dim
 
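What this hunk removes is duplication, not functionality: the deleted embedders, pad tokens, and refiner blocks already exist on ZImageTransformer2DModel, and after this commit the ControlNet borrows them from the transformer at call time instead of keeping its own copies (see the forward changes below). A minimal, runnable sketch of the pattern (BaseModel and ControlAdapter are illustrative stand-ins, not the real diffusers classes):

import torch
import torch.nn as nn

class BaseModel(nn.Module):
    # Owns the shared pieces (stand-ins for t_embedder, all_x_embedder, ...).
    def __init__(self, dim=64):
        super().__init__()
        self.t_embedder = nn.Linear(1, dim)
        self.x_embedder = nn.Linear(16, dim)

class ControlAdapter(nn.Module):
    # Keeps only control-specific parameters; everything shared is taken
    # from the base model that the caller passes in.
    def __init__(self, dim=64):
        super().__init__()
        self.control_proj = nn.Linear(dim, dim)

    def forward(self, base: BaseModel, x, t):
        h = base.x_embedder(x) + base.t_embedder(t)  # reuse, don't duplicate
        return self.control_proj(h)

base = BaseModel()
adapter = ControlAdapter()
out = adapter(base, torch.randn(2, 16), torch.randn(2, 1))
print(out.shape)  # torch.Size([2, 64])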
@@ -366,6 +295,7 @@ def patchify_and_embed(
 
     def forward(
         self,
+        transformer: ZImageTransformer2DModel,
         x: List[torch.Tensor],
         cap_feats: List[torch.Tensor],
         control_context: List[torch.Tensor],
@@ -380,7 +310,7 @@ def forward(
         bsz = len(x)
         device = x[0].device
         t = t * self.t_scale
-        t = self.t_embedder(t)
+        t = transformer.t_embedder(t)
 
         (
             x,
@@ -398,13 +328,13 @@ def forward(
         x_max_item_seqlen = max(x_item_seqlens)
 
         x = torch.cat(x, dim=0)
-        x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)
+        x = transformer.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)
 
         # Match t_embedder output dtype to x for layerwise casting compatibility
         adaln_input = t.type_as(x)
-        x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
+        x[torch.cat(x_inner_pad_mask)] = transformer.x_pad_token
         x = list(x.split(x_item_seqlens, dim=0))
-        x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
+        x_freqs_cis = list(transformer.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
 
         x = pad_sequence(x, batch_first=True, padding_value=0.0)
         x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)
@@ -413,10 +343,10 @@ def forward(
             x_attn_mask[i, :seq_len] = 1
 
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-            for layer in self.noise_refiner:
+            for layer in transformer.noise_refiner:
                 x = self._gradient_checkpointing_func(layer, x, x_attn_mask, x_freqs_cis, adaln_input)
         else:
-            for layer in self.noise_refiner:
+            for layer in transformer.noise_refiner:
                 x = layer(x, x_attn_mask, x_freqs_cis, adaln_input)
 
         # cap embed & refine
@@ -425,10 +355,10 @@ def forward(
         cap_max_item_seqlen = max(cap_item_seqlens)
 
         cap_feats = torch.cat(cap_feats, dim=0)
-        cap_feats = self.cap_embedder(cap_feats)
-        cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
+        cap_feats = transformer.cap_embedder(cap_feats)
+        cap_feats[torch.cat(cap_inner_pad_mask)] = transformer.cap_pad_token
         cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
-        cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0))
+        cap_freqs_cis = list(transformer.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0))
 
         cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
         cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
@@ -437,10 +367,10 @@ def forward(
             cap_attn_mask[i, :seq_len] = 1
 
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-            for layer in self.context_refiner:
+            for layer in transformer.context_refiner:
                 cap_feats = self._gradient_checkpointing_func(layer, cap_feats, cap_attn_mask, cap_freqs_cis)
         else:
-            for layer in self.context_refiner:
+            for layer in transformer.context_refiner:
                 cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis)
 
         # unified
@@ -485,7 +415,7 @@ def forward(
         adaln_input = t.type_as(control_context)
         control_context[torch.cat(x_inner_pad_mask)] = self.x_pad_token
         control_context = list(control_context.split(x_item_seqlens, dim=0))
-        x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
+        x_freqs_cis = list(transformer.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
 
         control_context = pad_sequence(control_context, batch_first=True, padding_value=0.0)
         x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)

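A practical consequence of the refactor, shown with the stand-in classes from the sketch above: the control module's state_dict no longer carries copies of the shared weights, so those weights are stored, loaded, and trained exactly once, on the transformer.

# Continuing the sketch above: only control-specific tensors remain in the
# adapter's state_dict; the embedder weights live solely on the base model.
print(sorted(ControlAdapter().state_dict()))
# ['control_proj.bias', 'control_proj.weight']
print(sorted(BaseModel().state_dict()))
# ['t_embedder.bias', 't_embedder.weight', 'x_embedder.bias', 'x_embedder.weight']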
src/diffusers/pipelines/z_image/pipeline_z_image_controlnet.py

Lines changed: 1 addition & 0 deletions
@@ -594,6 +594,7 @@ def __call__(
                 latent_model_input_list = list(latent_model_input.unbind(dim=0))
 
                 controlnet_block_samples = self.controlnet(
+                    self.transformer,
                     latent_model_input_list,
                     prompt_embeds_model_input,
                     control_image,

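On the pipeline side the only change is threading the shared transformer into the controlnet call as its first positional argument. A small sketch of that wiring, reusing the stand-in classes from the first sketch (TinyPipeline and denoise_step are hypothetical names, not the real pipeline API):

import torch

class TinyPipeline:
    # The pipeline owns both modules and hands the shared transformer to the
    # controlnet on every step, mirroring `self.controlnet(self.transformer, ...)`
    # in the diff above.
    def __init__(self, transformer, controlnet):
        self.transformer = transformer
        self.controlnet = controlnet

    def denoise_step(self, latents, t):
        return self.controlnet(self.transformer, latents, t)

pipe = TinyPipeline(BaseModel(), ControlAdapter())
print(pipe.denoise_step(torch.randn(2, 16), torch.randn(2, 1)).shape)  # torch.Size([2, 64])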