
Commit 7df350d

modified main model forward, freqs_cis left
1 parent 3e74bb2 commit 7df350d

1 file changed: +95 additions, -101 deletions


src/diffusers/models/transformers/transformer_z_image.py

Lines changed: 95 additions & 101 deletions
@@ -27,6 +27,7 @@
 except ImportError:
     flash_attn_varlen_func = None
 
+# todo see how other teams do this
 try:
     from apex.normalization import FusedRMSNorm as RMSNorm
 except ImportError:
@@ -61,10 +62,6 @@ def __init__(self, out_size, mid_size=None, frequency_embedding_size=256):
                 bias=True,
             ),
         )
-        nn.init.normal_(self.mlp[0].weight, std=0.02)
-        nn.init.zeros_(self.mlp[0].bias)
-        nn.init.normal_(self.mlp[2].weight, std=0.02)
-        nn.init.zeros_(self.mlp[2].bias)
 
         self.frequency_embedding_size = frequency_embedding_size
 
@@ -573,9 +570,9 @@ def patchify_and_embed(
         all_cap_pad_mask = []
         all_cap_feats_out = []
 
-        for i, image in enumerate(all_image):
-            ### LLM Text Encoder
-            cap_ori_len = len(all_cap_feats[i])
+        for i, (image, cap_feat) in enumerate(zip(all_image, all_cap_feats)):
+            ### Process Caption
+            cap_ori_len = len(cap_feat)
             cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF
             # padded position ids
             cap_padded_pos_ids = self.create_coordinate_grid(
@@ -596,7 +593,7 @@ def patchify_and_embed(
             )
             # padded feature
             cap_padded_feat = torch.cat(
-                [all_cap_feats[i], all_cap_feats[i][-1:].repeat(cap_padding_len, 1)],
+                [cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)],
                 dim=0,
             )
             all_cap_feats_out.append(cap_padded_feat)
@@ -677,126 +674,123 @@ def forward(
             x_size,
             x_pos_ids,
             cap_pos_ids,
-            x_pad_mask,
-            cap_pad_mask,
+            x_inner_pad_mask,
+            cap_inner_pad_mask,
         ) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)
 
         # x embed & refine
         x_item_seqlens = [len(_) for _ in x]
         assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens)
         x_max_item_seqlen = max(x_item_seqlens)
-        x_cu_seqlens = F.pad(
-            torch.cumsum(
-                torch.tensor(x_item_seqlens, dtype=torch.int32, device=device),
-                dim=0,
-                dtype=torch.int32,
-            ),
-            (1, 0),
+
+        x = torch.cat(x, dim=0)
+        x = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)
+        x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
+        x = x.split(x_item_seqlens, dim=0)
+        x_freqs_cis = self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0) # todo
+
+        pad_tensor = torch.zeros(
+            (1, self.dim),
+            dtype=x[0].dtype,
+            device=device,
         )
-        x_src_ids = [
-            torch.full((count,), i, dtype=torch.int32, device=device) for i, count in enumerate(x_item_seqlens)
-        ]
-        x_freqs_cis = self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0)
-
-        x_shard = torch.cat(x, dim=0)
-        x_src_ids_shard = torch.cat(x_src_ids, dim=0)
-        x_freqs_cis_shard = torch.cat(x_freqs_cis, dim=0)
-        x_pad_mask_shard = torch.cat(x_pad_mask, dim=0)
-        del x
-
-        x_shard = self.all_x_embedder[f"{patch_size}-{f_patch_size}"](x_shard)
-        x_shard[x_pad_mask_shard] = self.x_pad_token
+        x_pad_mask = torch.zeros(
+            (bsz, x_max_item_seqlen),
+            dtype=torch.bool,
+            device=device
+        )
+        for i, item in enumerate(x):
+            seq_len = x_item_seqlens[i]
+            x[i] = torch.cat([item, pad_tensor.repeat(x_max_item_seqlen - seq_len, 1)])
+            x_pad_mask[i, seq_len:] = 1
+        x = torch.stack(x)
+
         for layer in self.noise_refiner:
-            x_shard = layer(
-                x_shard,
-                x_src_ids_shard,
-                x_freqs_cis_shard,
-                x_cu_seqlens,
-                x_max_item_seqlen,
+            x = layer(
+                x,
+                x_pad_mask,
+                x_freqs_cis,
                 adaln_input,
-            )
-        x_flatten = x_shard
+            ) # todo
 
         # cap embed & refine
         cap_item_seqlens = [len(_) for _ in cap_feats]
         assert all(_ % SEQ_MULTI_OF == 0 for _ in cap_item_seqlens)
         cap_max_item_seqlen = max(cap_item_seqlens)
-        cap_cu_seqlens = F.pad(
-            torch.cumsum(
-                torch.tensor(cap_item_seqlens, dtype=torch.int32, device=device),
-                dim=0,
-                dtype=torch.int32,
-            ),
-            (1, 0),
+
+        cap_feats = torch.cat(cap_feats, dim=0)
+        cap_feats = self.cap_embedder(cap_feats)
+        cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
+        cap_feats = cap_feats.split(cap_item_seqlens, dim=0)
+        cap_freqs_cis = self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0) # todo
+
+        pad_tensor = torch.zeros(
+            (1, self.dim),
+            dtype=x[0].dtype,
+            device=device,
         )
-        cap_src_ids = [
-            torch.full((count,), i, dtype=torch.int32, device=device) for i, count in enumerate(cap_item_seqlens)
-        ]
-        cap_freqs_cis = self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0)
-
-        cap_shard = torch.cat(cap_feats, dim=0)
-        cap_src_ids_shard = torch.cat(cap_src_ids, dim=0)
-        cap_freqs_cis_shard = torch.cat(cap_freqs_cis, dim=0)
-        cap_pad_mask_shard = torch.cat(cap_pad_mask, dim=0)
-        del cap_feats
-
-        cap_shard = self.cap_embedder(cap_shard)
-        cap_shard[cap_pad_mask_shard] = self.cap_pad_token
+        cap_pad_mask = torch.zeros(
+            (bsz, cap_max_item_seqlen),
+            dtype=torch.bool,
+            device=device
+        )
+        for i, item in enumerate(cap_feats):
+            seq_len = cap_item_seqlens[i]
+            cap_feats[i] = torch.cat([item, pad_tensor.repeat(cap_max_item_seqlen - seq_len, 1)])
+            cap_pad_mask[i, seq_len:] = 1
+        cap_feats = torch.stack(cap_feats)
         for layer in self.context_refiner:
-            cap_shard = layer(
-                cap_shard,
-                cap_src_ids_shard,
-                cap_freqs_cis_shard,
-                cap_cu_seqlens,
-                cap_max_item_seqlen,
+            cap_feats = layer(
+                cap_feats,
+                cap_pad_mask,
+                cap_freqs_cis,
             )
-        cap_flatten = cap_shard
-
-        # unified
-        def merge_interleave(l1, l2):
-            return list(itertools.chain(*zip(l1, l2)))
 
-        unified = torch.cat(
-            merge_interleave(
-                cap_flatten.split(cap_item_seqlens, dim=0),
-                x_flatten.split(x_item_seqlens, dim=0),
-            ),
-            dim=0,
-        )
+        # unified todo
         unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
-        assert len(unified) == sum(unified_item_seqlens)
         unified_max_item_seqlen = max(unified_item_seqlens)
-        unified_cu_seqlens = F.pad(
-            torch.cumsum(
-                torch.tensor(unified_item_seqlens, dtype=torch.int32, device=device),
-                dim=0,
-                dtype=torch.int32,
-            ),
-            (1, 0),
+
+        pad_tensor = torch.zeros(
+            (1, self.dim),
+            dtype=x[0].dtype,
+            device=device,
         )
-        unified_src_ids = torch.cat(merge_interleave(cap_src_ids, x_src_ids))
-        unified_freqs_cis = torch.cat(merge_interleave(cap_freqs_cis, x_freqs_cis))
+        unified_pad_mask = torch.zeros(
+            (bsz, unified_max_item_seqlen),
+            dtype=torch.bool,
+            device=device
+        )
+
+        unified = []
+        for i in range(bsz):
+            x_len = x_item_seqlens[i]
+            cap_len = cap_item_seqlens[i]
+            unified.append(
+                torch.cat(
+                    [
+                        x[i][:x_item_seqlens[i]],
+                        cap_feats[i][:cap_item_seqlens[i]],
+                        pad_tensor.repeat(unified_max_item_seqlen - x_len - cap_len, 1)
+                    ]
+                )
+            )
+            unified_pad_mask[i, x_len + cap_len:] = 1
+
+        unified_freqs_cis = torch.cat(merge_interleave(cap_freqs_cis, x_freqs_cis)) # todo
 
-        unified_shard = unified
-        unified_src_ids_shard = unified_src_ids
-        unified_freqs_cis_shard = unified_freqs_cis
         for layer in self.layers:
             unified_shard = layer(
-                unified_shard,
-                unified_src_ids_shard,
-                unified_freqs_cis_shard,
-                unified_cu_seqlens,
-                unified_max_item_seqlen,
+                unified,
+                unified_pad_mask,
+                unified_freqs_cis,
                 adaln_input,
             )
-        unified_shard = self.all_final_layer[f"{patch_size}-{f_patch_size}"](
-            unified_shard, unified_src_ids_shard, adaln_input
-        )
-        unified = unified_shard.split(unified_item_seqlens, dim=0)
-        x = [unified[i][cap_item_seqlens[i] :] for i in range(bsz)]
-        assert all(len(x[i]) == x_item_seqlens[i] for i in range(bsz))
 
-        x = self.unpatchify(x, x_size, patch_size, f_patch_size)
+        unified = self.all_final_layer[f"{patch_size}-{f_patch_size}"](
+            unified, adaln_input # todo
+        )
+        unified = unified.split(unified_item_seqlens, dim=0)
+        x = self.unpatchify(unified, x_size, patch_size, f_patch_size)
 
         return x, {}
 
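Note on the change, for readers skimming the diff: the commit swaps the flattened varlen packing (per-item tokens concatenated end to end and addressed through cu_seqlens offsets and src_ids) for dense padded batches with boolean pad masks; the rotary freqs_cis tensors still use the old flattened/interleaved form, which is what the commit message ("freqs_cis left") and the # todo markers refer to. The two sketches below are illustrative only and are not part of the commit; the helper name pad_and_mask, the feature dimension, and the dummy lengths are invented for the example.

# Sketch 1: what the removed x_cu_seqlens / cap_cu_seqlens lines computed,
# namely cumulative sequence boundaries with a leading zero, the offset format
# consumed by varlen attention kernels such as flash_attn_varlen_func.
import torch
import torch.nn.functional as F

item_seqlens = [4, 6, 2]  # illustrative per-item token counts
cu_seqlens = F.pad(
    torch.cumsum(torch.tensor(item_seqlens, dtype=torch.int32), dim=0, dtype=torch.int32),
    (1, 0),
)
# -> tensor([ 0,  4, 10, 12], dtype=torch.int32)

# Sketch 2: the pad-to-max pattern the new forward relies on instead; every item
# is padded to the batch maximum and the padded positions are recorded in a
# boolean mask. pad_and_mask is a hypothetical helper, not a function in the file.
def pad_and_mask(items, dim, device="cpu"):
    seqlens = [item.shape[0] for item in items]
    max_len = max(seqlens)
    pad_tensor = torch.zeros((1, dim), dtype=items[0].dtype, device=device)
    pad_mask = torch.zeros((len(items), max_len), dtype=torch.bool, device=device)
    padded = []
    for i, item in enumerate(items):
        # append zero rows up to max_len, mark those positions in the mask
        padded.append(torch.cat([item, pad_tensor.repeat(max_len - seqlens[i], 1)]))
        pad_mask[i, seqlens[i]:] = True
    return torch.stack(padded), pad_mask

batch, mask = pad_and_mask([torch.randn(5, 8), torch.randn(3, 8)], dim=8)
# batch: (2, 5, 8); mask[1, 3:] is True

The padded layout spends some compute on pad tokens but lets the refiner and transformer layers run standard batched attention with an attention mask, whereas the removed layout avoided padding at the cost of requiring varlen kernels.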