|
13 | 13 | # limitations under the License.
14 | 14 |
|
15 | 15 | import math |
16 | | -from typing import List, Literal, Optional |
| 16 | +from typing import List, Literal, Optional, Tuple |
17 | 17 |
|
18 | 18 | import torch |
19 | 19 | import torch.nn as nn |
@@ -170,6 +170,21 @@ def forward(self, x):
170 | 170 | return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))
171 | 171 |
|
172 | 172 |
|
| 173 | +# Copied from diffusers.models.transformers.transformer_z_image.select_per_token |
| 174 | +def select_per_token( |
| 175 | + value_noisy: torch.Tensor, |
| 176 | + value_clean: torch.Tensor, |
| 177 | + noise_mask: torch.Tensor, |
| 178 | + seq_len: int, |
| 179 | +) -> torch.Tensor: |
| 180 | + noise_mask_expanded = noise_mask.unsqueeze(-1) # (batch, seq_len, 1) |
| 181 | + return torch.where( |
| 182 | + noise_mask_expanded == 1, |
| 183 | + value_noisy.unsqueeze(1).expand(-1, seq_len, -1), |
| 184 | + value_clean.unsqueeze(1).expand(-1, seq_len, -1), |
| 185 | + ) |
| 186 | + |
| 187 | + |
173 | 188 | @maybe_allow_in_graph |
174 | 189 | # Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformerBlock |
175 | 190 | class ZImageTransformerBlock(nn.Module): |
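
For reference, here is a minimal standalone sketch of what the `select_per_token` helper added above computes, assuming per-sample modulation vectors of shape `(batch, dim)` and a token-level mask of shape `(batch, seq_len)`; the concrete sizes and values are illustrative, not taken from the model:

```python
import torch

def select_per_token(value_noisy, value_clean, noise_mask, seq_len):
    # (batch, seq_len) -> (batch, seq_len, 1) so the mask broadcasts over the feature dim
    noise_mask_expanded = noise_mask.unsqueeze(-1)
    return torch.where(
        noise_mask_expanded == 1,
        value_noisy.unsqueeze(1).expand(-1, seq_len, -1),  # (batch, dim) -> (batch, seq_len, dim) view
        value_clean.unsqueeze(1).expand(-1, seq_len, -1),
    )

batch, seq_len, dim = 1, 4, 2
value_noisy = torch.full((batch, dim), 2.0)  # modulation derived from the noisy-timestep embedding
value_clean = torch.full((batch, dim), 5.0)  # modulation derived from the clean (t = 0) embedding
noise_mask = torch.tensor([[1, 1, 0, 0]])    # first two tokens noisy, last two clean

out = select_per_token(value_noisy, value_clean, noise_mask, seq_len)
print(out[0, :, 0])  # tensor([2., 2., 5., 5.])
```

Note that `expand` only creates broadcast views, so a single `torch.where` picks between the noisy and clean variants and only the final `(batch, seq_len, dim)` result is materialized.
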
@@ -220,12 +235,37 @@ def forward(
220 | 235 | attn_mask: torch.Tensor, |
221 | 236 | freqs_cis: torch.Tensor, |
222 | 237 | adaln_input: Optional[torch.Tensor] = None, |
| 238 | + noise_mask: Optional[torch.Tensor] = None, |
| 239 | + adaln_noisy: Optional[torch.Tensor] = None, |
| 240 | + adaln_clean: Optional[torch.Tensor] = None, |
223 | 241 | ): |
224 | 242 | if self.modulation: |
225 | | - assert adaln_input is not None |
226 | | - scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2) |
227 | | - gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh() |
228 | | - scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp |
| 243 | + seq_len = x.shape[1] |
| 244 | + |
| 245 | + if noise_mask is not None: |
| 246 | + # Per-token modulation: different modulation for noisy/clean tokens |
| 247 | + mod_noisy = self.adaLN_modulation(adaln_noisy) |
| 248 | + mod_clean = self.adaLN_modulation(adaln_clean) |
| 249 | + |
| 250 | + scale_msa_noisy, gate_msa_noisy, scale_mlp_noisy, gate_mlp_noisy = mod_noisy.chunk(4, dim=1) |
| 251 | + scale_msa_clean, gate_msa_clean, scale_mlp_clean, gate_mlp_clean = mod_clean.chunk(4, dim=1) |
| 252 | + |
| 253 | + gate_msa_noisy, gate_mlp_noisy = gate_msa_noisy.tanh(), gate_mlp_noisy.tanh() |
| 254 | + gate_msa_clean, gate_mlp_clean = gate_msa_clean.tanh(), gate_mlp_clean.tanh() |
| 255 | + |
| 256 | + scale_msa_noisy, scale_mlp_noisy = 1.0 + scale_msa_noisy, 1.0 + scale_mlp_noisy |
| 257 | + scale_msa_clean, scale_mlp_clean = 1.0 + scale_msa_clean, 1.0 + scale_mlp_clean |
| 258 | + |
| 259 | + scale_msa = select_per_token(scale_msa_noisy, scale_msa_clean, noise_mask, seq_len) |
| 260 | + scale_mlp = select_per_token(scale_mlp_noisy, scale_mlp_clean, noise_mask, seq_len) |
| 261 | + gate_msa = select_per_token(gate_msa_noisy, gate_msa_clean, noise_mask, seq_len) |
| 262 | + gate_mlp = select_per_token(gate_mlp_noisy, gate_mlp_clean, noise_mask, seq_len) |
| 263 | + else: |
| 264 | + # Global modulation: same modulation for all tokens (avoid double select) |
| 265 | + mod = self.adaLN_modulation(adaln_input) |
| 266 | + scale_msa, gate_msa, scale_mlp, gate_mlp = mod.unsqueeze(1).chunk(4, dim=2) |
| 267 | + gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh() |
| 268 | + scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp |
229 | 269 |
|
230 | 270 | # Attention block |
231 | 271 | attn_out = self.attention( |
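
To make the per-token modulation branch above easier to follow, here is a small sketch of the same flow with a stand-in for `self.adaLN_modulation` and random inputs; the dimensions, the plain `nn.Linear` stand-in, and the random embeddings are assumptions for illustration only:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, seq_len, dim = 2, 6, 8

# Stand-in for self.adaLN_modulation: any module mapping a (batch, dim) timestep
# embedding to the four concatenated modulation vectors works for this sketch.
adaLN_modulation = nn.Linear(dim, 4 * dim)

adaln_noisy = torch.randn(batch, dim)               # embedding at the sampled timestep
adaln_clean = torch.randn(batch, dim)               # embedding at t = 0 for clean reference tokens
noise_mask = torch.randint(0, 2, (batch, seq_len))  # 1 = noisy token, 0 = clean token

scale_noisy, gate_noisy, _, _ = adaLN_modulation(adaln_noisy).chunk(4, dim=1)
scale_clean, gate_clean, _, _ = adaLN_modulation(adaln_clean).chunk(4, dim=1)
scale_noisy, scale_clean = 1.0 + scale_noisy, 1.0 + scale_clean
gate_noisy, gate_clean = gate_noisy.tanh(), gate_clean.tanh()

# Broadcast each (batch, dim) vector over the sequence, then pick per token.
mask = noise_mask.unsqueeze(-1) == 1                # (batch, seq_len, 1), bool
scale_msa = torch.where(mask, scale_noisy.unsqueeze(1).expand(-1, seq_len, -1),
                        scale_clean.unsqueeze(1).expand(-1, seq_len, -1))
gate_msa = torch.where(mask, gate_noisy.unsqueeze(1).expand(-1, seq_len, -1),
                       gate_clean.unsqueeze(1).expand(-1, seq_len, -1))

x = torch.randn(batch, seq_len, dim)
print((x * scale_msa).shape)                        # torch.Size([2, 6, 8])
```

The global branch is the degenerate case of this: a single modulation vector is unsqueezed to `(batch, 1, dim)` and broadcast over every token, which is why the `else` path can skip the per-token selection entirely.
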
@@ -493,112 +533,93 @@ def from_transformer(cls, controlnet, transformer):
493 | 533 | def create_coordinate_grid(size, start=None, device=None): |
494 | 534 | if start is None: |
495 | 535 | start = (0 for _ in size) |
496 | | - |
497 | 536 | axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)] |
498 | 537 | grids = torch.meshgrid(axes, indexing="ij") |
499 | 538 | return torch.stack(grids, dim=-1) |
500 | 539 |
|
501 | | - # Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformer2DModel.patchify_and_embed |
502 | | - def patchify_and_embed( |
| 540 | + # Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformer2DModel._patchify_image |
| 541 | + def _patchify_image(self, image: torch.Tensor, patch_size: int, f_patch_size: int): |
| 542 | + """Patchify a single image tensor: (C, F, H, W) -> (num_patches, patch_dim).""" |
| 543 | + pH, pW, pF = patch_size, patch_size, f_patch_size |
| 544 | + C, F, H, W = image.size() |
| 545 | + F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW |
| 546 | + image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) |
| 547 | + image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C) |
| 548 | + return image, (F, H, W), (F_tokens, H_tokens, W_tokens) |
| 549 | + |
| 550 | + # Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformer2DModel._pad_with_ids |
| 551 | + def _pad_with_ids( |
503 | 552 | self, |
504 | | - all_image: List[torch.Tensor], |
505 | | - all_cap_feats: List[torch.Tensor], |
506 | | - patch_size: int, |
507 | | - f_patch_size: int, |
| 553 | + feat: torch.Tensor, |
| 554 | + pos_grid_size: Tuple, |
| 555 | + pos_start: Tuple, |
| 556 | + device: torch.device, |
| 557 | + noise_mask_val: Optional[int] = None, |
508 | 558 | ): |
509 | | - pH = pW = patch_size |
510 | | - pF = f_patch_size |
511 | | - device = all_image[0].device |
512 | | - |
513 | | - all_image_out = [] |
514 | | - all_image_size = [] |
515 | | - all_image_pos_ids = [] |
516 | | - all_image_pad_mask = [] |
517 | | - all_cap_pos_ids = [] |
518 | | - all_cap_pad_mask = [] |
519 | | - all_cap_feats_out = [] |
520 | | - |
521 | | - for i, (image, cap_feat) in enumerate(zip(all_image, all_cap_feats)): |
522 | | - ### Process Caption |
523 | | - cap_ori_len = len(cap_feat) |
524 | | - cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF |
525 | | - # padded position ids |
526 | | - cap_padded_pos_ids = self.create_coordinate_grid( |
527 | | - size=(cap_ori_len + cap_padding_len, 1, 1), |
528 | | - start=(1, 0, 0), |
529 | | - device=device, |
530 | | - ).flatten(0, 2) |
531 | | - all_cap_pos_ids.append(cap_padded_pos_ids) |
532 | | - # pad mask |
533 | | - cap_pad_mask = torch.cat( |
534 | | - [ |
535 | | - torch.zeros((cap_ori_len,), dtype=torch.bool, device=device), |
536 | | - torch.ones((cap_padding_len,), dtype=torch.bool, device=device), |
537 | | - ], |
538 | | - dim=0, |
| 559 | + """Pad feature to SEQ_MULTI_OF, create position IDs and pad mask.""" |
| 560 | + ori_len = len(feat) |
| 561 | + pad_len = (-ori_len) % SEQ_MULTI_OF |
| 562 | + total_len = ori_len + pad_len |
| 563 | + |
| 564 | + # Pos IDs |
| 565 | + ori_pos_ids = self.create_coordinate_grid(size=pos_grid_size, start=pos_start, device=device).flatten(0, 2) |
| 566 | + if pad_len > 0: |
| 567 | + pad_pos_ids = ( |
| 568 | + self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device) |
| 569 | + .flatten(0, 2) |
| 570 | + .repeat(pad_len, 1) |
539 | 571 | ) |
540 | | - all_cap_pad_mask.append( |
541 | | - cap_pad_mask if cap_padding_len > 0 else torch.zeros((cap_ori_len,), dtype=torch.bool, device=device) |
| 572 | + pos_ids = torch.cat([ori_pos_ids, pad_pos_ids], dim=0) |
| 573 | + padded_feat = torch.cat([feat, feat[-1:].repeat(pad_len, 1)], dim=0) |
| 574 | + pad_mask = torch.cat( |
| 575 | + [ |
| 576 | + torch.zeros(ori_len, dtype=torch.bool, device=device), |
| 577 | + torch.ones(pad_len, dtype=torch.bool, device=device), |
| 578 | + ] |
542 | 579 | ) |
| 580 | + else: |
| 581 | + pos_ids = ori_pos_ids |
| 582 | + padded_feat = feat |
| 583 | + pad_mask = torch.zeros(ori_len, dtype=torch.bool, device=device) |
543 | 584 |
|
544 | | - # padded feature |
545 | | - cap_padded_feat = torch.cat([cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)], dim=0) |
546 | | - all_cap_feats_out.append(cap_padded_feat) |
547 | | - |
548 | | - ### Process Image |
549 | | - C, F, H, W = image.size() |
550 | | - all_image_size.append((F, H, W)) |
551 | | - F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW |
552 | | - |
553 | | - image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW) |
554 | | - # "c f pf h ph w pw -> (f h w) (pf ph pw c)" |
555 | | - image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C) |
| 585 | + noise_mask = [noise_mask_val] * total_len if noise_mask_val is not None else None # token level |
| 586 | + return padded_feat, pos_ids, pad_mask, total_len, noise_mask |
556 | 587 |
|
557 | | - image_ori_len = len(image) |
558 | | - image_padding_len = (-image_ori_len) % SEQ_MULTI_OF |
| 588 | + # Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformer2DModel.patchify_and_embed |
| 589 | + def patchify_and_embed( |
| 590 | + self, all_image: List[torch.Tensor], all_cap_feats: List[torch.Tensor], patch_size: int, f_patch_size: int |
| 591 | + ): |
| 592 | + """Patchify for basic mode: single image per batch item.""" |
| 593 | + device = all_image[0].device |
| 594 | + all_img_out, all_img_size, all_img_pos_ids, all_img_pad_mask = [], [], [], [] |
| 595 | + all_cap_out, all_cap_pos_ids, all_cap_pad_mask = [], [], [] |
559 | 596 |
|
560 | | - image_ori_pos_ids = self.create_coordinate_grid( |
561 | | - size=(F_tokens, H_tokens, W_tokens), |
562 | | - start=(cap_ori_len + cap_padding_len + 1, 0, 0), |
563 | | - device=device, |
564 | | - ).flatten(0, 2) |
565 | | - image_padded_pos_ids = torch.cat( |
566 | | - [ |
567 | | - image_ori_pos_ids, |
568 | | - self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device) |
569 | | - .flatten(0, 2) |
570 | | - .repeat(image_padding_len, 1), |
571 | | - ], |
572 | | - dim=0, |
573 | | - ) |
574 | | - all_image_pos_ids.append(image_padded_pos_ids if image_padding_len > 0 else image_ori_pos_ids) |
575 | | - # pad mask |
576 | | - image_pad_mask = torch.cat( |
577 | | - [ |
578 | | - torch.zeros((image_ori_len,), dtype=torch.bool, device=device), |
579 | | - torch.ones((image_padding_len,), dtype=torch.bool, device=device), |
580 | | - ], |
581 | | - dim=0, |
| 597 | + for image, cap_feat in zip(all_image, all_cap_feats): |
| 598 | + # Caption |
| 599 | + cap_out, cap_pos_ids, cap_pad_mask, cap_len, _ = self._pad_with_ids( |
| 600 | + cap_feat, (len(cap_feat) + (-len(cap_feat)) % SEQ_MULTI_OF, 1, 1), (1, 0, 0), device |
582 | 601 | ) |
583 | | - all_image_pad_mask.append( |
584 | | - image_pad_mask |
585 | | - if image_padding_len > 0 |
586 | | - else torch.zeros((image_ori_len,), dtype=torch.bool, device=device) |
587 | | - ) |
588 | | - # padded feature |
589 | | - image_padded_feat = torch.cat( |
590 | | - [image, image[-1:].repeat(image_padding_len, 1)], |
591 | | - dim=0, |
| 602 | + all_cap_out.append(cap_out) |
| 603 | + all_cap_pos_ids.append(cap_pos_ids) |
| 604 | + all_cap_pad_mask.append(cap_pad_mask) |
| 605 | + |
| 606 | + # Image |
| 607 | + img_patches, size, (F_t, H_t, W_t) = self._patchify_image(image, patch_size, f_patch_size) |
| 608 | + img_out, img_pos_ids, img_pad_mask, _, _ = self._pad_with_ids( |
| 609 | + img_patches, (F_t, H_t, W_t), (cap_len + 1, 0, 0), device |
592 | 610 | ) |
593 | | - all_image_out.append(image_padded_feat if image_padding_len > 0 else image) |
| 611 | + all_img_out.append(img_out) |
| 612 | + all_img_size.append(size) |
| 613 | + all_img_pos_ids.append(img_pos_ids) |
| 614 | + all_img_pad_mask.append(img_pad_mask) |
594 | 615 |
|
595 | 616 | return ( |
596 | | - all_image_out, |
597 | | - all_cap_feats_out, |
598 | | - all_image_size, |
599 | | - all_image_pos_ids, |
| 617 | + all_img_out, |
| 618 | + all_cap_out, |
| 619 | + all_img_size, |
| 620 | + all_img_pos_ids, |
600 | 621 | all_cap_pos_ids, |
601 | | - all_image_pad_mask, |
| 622 | + all_img_pad_mask, |
602 | 623 | all_cap_pad_mask, |
603 | 624 | ) |
604 | 625 |
|
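
To make the patchify and padding arithmetic concrete, the sketch below reproduces the same reshape and `SEQ_MULTI_OF` padding outside the model; the latent shape, patch sizes, and the value used for `SEQ_MULTI_OF` are illustrative assumptions, not values read from the config:

```python
import torch

SEQ_MULTI_OF = 32  # illustrative; the real constant is defined in the transformer module

def patchify_image(image, patch_size, f_patch_size):
    # (C, F, H, W) -> (num_patches, patch_dim), matching the permute/reshape in _patchify_image
    pH = pW = patch_size
    pF = f_patch_size
    C, F, H, W = image.size()
    F_t, H_t, W_t = F // pF, H // pH, W // pW
    image = image.view(C, F_t, pF, H_t, pH, W_t, pW)
    image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_t * H_t * W_t, pF * pH * pW * C)
    return image, (F_t, H_t, W_t)

image = torch.randn(16, 1, 60, 60)           # toy latent: C=16, F=1, H=W=60
patches, (F_t, H_t, W_t) = patchify_image(image, patch_size=2, f_patch_size=1)
print(patches.shape)                          # torch.Size([900, 64]): 30*30 patches, 2*2*16 features each

# Pad the token sequence up to a multiple of SEQ_MULTI_OF by repeating the last patch,
# and build the boolean pad mask (True = padding), as _pad_with_ids does.
ori_len = patches.shape[0]
pad_len = (-ori_len) % SEQ_MULTI_OF
padded = torch.cat([patches, patches[-1:].repeat(pad_len, 1)], dim=0) if pad_len else patches
pad_mask = torch.cat([torch.zeros(ori_len, dtype=torch.bool), torch.ones(pad_len, dtype=torch.bool)])
print(padded.shape[0], pad_mask.sum().item())  # 928 28
```

In the model itself, `_pad_with_ids` additionally builds rotary position IDs via `create_coordinate_grid`: caption tokens occupy positions starting at 1 along the first axis, and image tokens start right after the padded caption length, which is why the image call passes `(cap_len + 1, 0, 0)` as its start.
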
|