
Commit 2fe5bec

Merge branch 'main' into cp-fixes-attn-backends
2 parents: 6b5b3f7 + 1cdb872

15 files changed: +1575 −335 lines

examples/community/pipeline_hunyuandit_differential_img2img.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -21,8 +21,8 @@
     BertModel,
     BertTokenizer,
     CLIPImageProcessor,
-    MT5Tokenizer,
     T5EncoderModel,
+    T5Tokenizer,
 )

 from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
@@ -260,7 +260,7 @@ class HunyuanDiTDifferentialImg2ImgPipeline(DiffusionPipeline):
             The HunyuanDiT model designed by Tencent Hunyuan.
         text_encoder_2 (`T5EncoderModel`):
             The mT5 embedder. Specifically, it is 't5-v1_1-xxl'.
-        tokenizer_2 (`MT5Tokenizer`):
+        tokenizer_2 (`T5Tokenizer`):
             The tokenizer for the mT5 embedder.
         scheduler ([`DDPMScheduler`]):
             A scheduler to be used in combination with HunyuanDiT to denoise the encoded image latents.
@@ -295,7 +295,7 @@ def __init__(
         feature_extractor: CLIPImageProcessor,
         requires_safety_checker: bool = True,
         text_encoder_2=T5EncoderModel,
-        tokenizer_2=MT5Tokenizer,
+        tokenizer_2=T5Tokenizer,
     ):
         super().__init__()
```
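The `MT5Tokenizer` → `T5Tokenizer` swap works because mT5 checkpoints ship a standard SentencePiece vocabulary that transformers' generic `T5Tokenizer` loads directly, so no MT5-specific class is needed. A minimal sketch, assuming the `google/mt5-small` repo id purely for illustration (it is not referenced by this commit):

```python
# Sketch: load an mT5 tokenizer with the generic T5Tokenizer class.
from transformers import T5Tokenizer

tokenizer_2 = T5Tokenizer.from_pretrained("google/mt5-small")  # repo id assumed for illustration
print(tokenizer_2("a photo of an astronaut").input_ids)
```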

scripts/convert_cosmos_to_diffusers.py

Lines changed: 59 additions & 1 deletion
````diff
@@ -29,13 +29,52 @@

 Convert checkpoint
 ```bash
+# pre-trained
 transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-2B/snapshots/865baf084d4c9e850eac59a021277d5a9b9e8b63/base/pre-trained/d20b7120-df3e-4911-919d-db6e08bad31c_ema_bf16.pt

 python scripts/convert_cosmos_to_diffusers.py \
     --transformer_type Cosmos-2.5-Predict-Base-2B \
     --transformer_ckpt_path $transformer_ckpt_path \
     --vae_type wan2.1 \
-    --output_path converted/cosmos-p2.5-base-2b \
+    --output_path converted/2b/d20b7120-df3e-4911-919d-db6e08bad31c \
+    --save_pipeline
+
+# post-trained
+transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-2B/snapshots/865baf084d4c9e850eac59a021277d5a9b9e8b63/base/post-trained/81edfebe-bd6a-4039-8c1d-737df1a790bf_ema_bf16.pt
+
+python scripts/convert_cosmos_to_diffusers.py \
+    --transformer_type Cosmos-2.5-Predict-Base-2B \
+    --transformer_ckpt_path $transformer_ckpt_path \
+    --vae_type wan2.1 \
+    --output_path converted/2b/81edfebe-bd6a-4039-8c1d-737df1a790bf \
+    --save_pipeline
+```
+
+## 14B
+
+```bash
+hf download nvidia/Cosmos-Predict2.5-14B
+```
+
+```bash
+# pre-trained
+transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-14B/snapshots/71ebf3e8af30ecfe440bf0481115975fcc052b46/base/pre-trained/54937b8c-29de-4f04-862c-e67b04ec41e8_ema_bf16.pt
+
+python scripts/convert_cosmos_to_diffusers.py \
+    --transformer_type Cosmos-2.5-Predict-Base-14B \
+    --transformer_ckpt_path $transformer_ckpt_path \
+    --vae_type wan2.1 \
+    --output_path converted/14b/54937b8c-29de-4f04-862c-e67b04ec41e8/ \
+    --save_pipeline
+
+# post-trained
+transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Predict2.5-14B/snapshots/71ebf3e8af30ecfe440bf0481115975fcc052b46/base/post-trained/e21d2a49-4747-44c8-ba44-9f6f9243715f_ema_bf16.pt
+
+python scripts/convert_cosmos_to_diffusers.py \
+    --transformer_type Cosmos-2.5-Predict-Base-14B \
+    --transformer_ckpt_path $transformer_ckpt_path \
+    --vae_type wan2.1 \
+    --output_path converted/14b/e21d2a49-4747-44c8-ba44-9f6f9243715f/ \
     --save_pipeline
 ```
````
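A quick way to sanity-check a converted checkpoint (not part of this diff) is to load the directory written by `--save_pipeline` and inspect the transformer config; the path below is the 2B pre-trained output used above.

```python
# Hypothetical smoke test: load the converted pipeline and print its transformer config.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "converted/2b/d20b7120-df3e-4911-919d-db6e08bad31c",  # --output_path from the command above
    torch_dtype=torch.bfloat16,
)
print(pipe.transformer.config)
```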
```diff
@@ -298,6 +337,25 @@ def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
         "crossattn_proj_in_channels": 100352,
         "encoder_hidden_states_channels": 1024,
     },
+    "Cosmos-2.5-Predict-Base-14B": {
+        "in_channels": 16 + 1,
+        "out_channels": 16,
+        "num_attention_heads": 40,
+        "attention_head_dim": 128,
+        "num_layers": 36,
+        "mlp_ratio": 4.0,
+        "text_embed_dim": 1024,
+        "adaln_lora_dim": 256,
+        "max_size": (128, 240, 240),
+        "patch_size": (1, 2, 2),
+        "rope_scale": (1.0, 3.0, 3.0),
+        "concat_padding_mask": True,
+        # NOTE: source config has pos_emb_learnable: 'True' - but params are missing
+        "extra_pos_embed_type": None,
+        "use_crossattn_projection": True,
+        "crossattn_proj_in_channels": 100352,
+        "encoder_hidden_states_channels": 1024,
+    },
 }

 VAE_KEYS_RENAME_DICT = {
```
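For orientation, the inner width implied by the new 14B entry is `num_attention_heads * attention_head_dim = 40 * 128 = 5120`, spread over 36 layers; a quick check of that arithmetic:

```python
# Illustrative arithmetic on the Cosmos-2.5-Predict-Base-14B config entry above.
num_attention_heads, attention_head_dim, num_layers = 40, 128, 36
hidden_size = num_attention_heads * attention_head_dim
print(hidden_size, num_layers)  # 5120 36
```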

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -675,6 +675,7 @@
             "ZImageControlNetInpaintPipeline",
             "ZImageControlNetPipeline",
             "ZImageImg2ImgPipeline",
+            "ZImageOmniPipeline",
             "ZImagePipeline",
         ]
     )
@@ -1386,6 +1387,7 @@
         ZImageControlNetInpaintPipeline,
         ZImageControlNetPipeline,
         ZImageImg2ImgPipeline,
+        ZImageOmniPipeline,
         ZImagePipeline,
     )
```
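With the two registrations above, the new pipeline becomes importable from the package root; a sketch only, since this commit does not add weights or a repo id for it:

```python
# Verify the export is wired up; no model download happens here.
from diffusers import ZImageOmniPipeline

print(ZImageOmniPipeline.__name__)
```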

src/diffusers/models/controlnets/controlnet_z_image.py

Lines changed: 116 additions & 95 deletions
```diff
@@ -13,7 +13,7 @@
 # limitations under the License.

 import math
-from typing import List, Literal, Optional
+from typing import List, Literal, Optional, Tuple

 import torch
 import torch.nn as nn
```
```diff
@@ -170,6 +170,21 @@ def forward(self, x):
        return self.w2(self._forward_silu_gating(self.w1(x), self.w3(x)))


+# Copied from diffusers.models.transformers.transformer_z_image.select_per_token
+def select_per_token(
+    value_noisy: torch.Tensor,
+    value_clean: torch.Tensor,
+    noise_mask: torch.Tensor,
+    seq_len: int,
+) -> torch.Tensor:
+    noise_mask_expanded = noise_mask.unsqueeze(-1)  # (batch, seq_len, 1)
+    return torch.where(
+        noise_mask_expanded == 1,
+        value_noisy.unsqueeze(1).expand(-1, seq_len, -1),
+        value_clean.unsqueeze(1).expand(-1, seq_len, -1),
+    )
+
+
 @maybe_allow_in_graph
 # Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformerBlock
 class ZImageTransformerBlock(nn.Module):
```
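A standalone sketch of what `select_per_token` computes: each `(batch, dim)` vector is broadcast across the sequence, and the mask picks the noisy variant where `noise_mask == 1` and the clean variant elsewhere (shapes below are chosen only for illustration):

```python
import torch

batch, seq_len, dim = 2, 4, 8
value_noisy = torch.randn(batch, dim)
value_clean = torch.randn(batch, dim)
noise_mask = torch.tensor([[1, 1, 0, 0], [0, 1, 0, 1]])

out = torch.where(
    noise_mask.unsqueeze(-1) == 1,                     # (batch, seq_len, 1)
    value_noisy.unsqueeze(1).expand(-1, seq_len, -1),  # (batch, seq_len, dim)
    value_clean.unsqueeze(1).expand(-1, seq_len, -1),
)
print(out.shape)  # torch.Size([2, 4, 8])
```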
```diff
@@ -220,12 +235,37 @@ def forward(
        attn_mask: torch.Tensor,
        freqs_cis: torch.Tensor,
        adaln_input: Optional[torch.Tensor] = None,
+        noise_mask: Optional[torch.Tensor] = None,
+        adaln_noisy: Optional[torch.Tensor] = None,
+        adaln_clean: Optional[torch.Tensor] = None,
    ):
        if self.modulation:
-            assert adaln_input is not None
-            scale_msa, gate_msa, scale_mlp, gate_mlp = self.adaLN_modulation(adaln_input).unsqueeze(1).chunk(4, dim=2)
-            gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
-            scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp
+            seq_len = x.shape[1]
+
+            if noise_mask is not None:
+                # Per-token modulation: different modulation for noisy/clean tokens
+                mod_noisy = self.adaLN_modulation(adaln_noisy)
+                mod_clean = self.adaLN_modulation(adaln_clean)
+
+                scale_msa_noisy, gate_msa_noisy, scale_mlp_noisy, gate_mlp_noisy = mod_noisy.chunk(4, dim=1)
+                scale_msa_clean, gate_msa_clean, scale_mlp_clean, gate_mlp_clean = mod_clean.chunk(4, dim=1)
+
+                gate_msa_noisy, gate_mlp_noisy = gate_msa_noisy.tanh(), gate_mlp_noisy.tanh()
+                gate_msa_clean, gate_mlp_clean = gate_msa_clean.tanh(), gate_mlp_clean.tanh()
+
+                scale_msa_noisy, scale_mlp_noisy = 1.0 + scale_msa_noisy, 1.0 + scale_mlp_noisy
+                scale_msa_clean, scale_mlp_clean = 1.0 + scale_msa_clean, 1.0 + scale_mlp_clean
+
+                scale_msa = select_per_token(scale_msa_noisy, scale_msa_clean, noise_mask, seq_len)
+                scale_mlp = select_per_token(scale_mlp_noisy, scale_mlp_clean, noise_mask, seq_len)
+                gate_msa = select_per_token(gate_msa_noisy, gate_msa_clean, noise_mask, seq_len)
+                gate_mlp = select_per_token(gate_mlp_noisy, gate_mlp_clean, noise_mask, seq_len)
+            else:
+                # Global modulation: same modulation for all tokens (avoid double select)
+                mod = self.adaLN_modulation(adaln_input)
+                scale_msa, gate_msa, scale_mlp, gate_mlp = mod.unsqueeze(1).chunk(4, dim=2)
+                gate_msa, gate_mlp = gate_msa.tanh(), gate_mlp.tanh()
+                scale_msa, scale_mlp = 1.0 + scale_msa, 1.0 + scale_mlp

            # Attention block
            attn_out = self.attention(
```
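The two branches differ mainly in where the chunking happens: the per-token path splits a `(batch, 4 * dim)` modulation tensor along `dim=1` and then routes noisy/clean variants per token via `select_per_token`, while the global path unsqueezes first and splits along `dim=2`. A shape check under assumed sizes:

```python
import torch

batch, dim = 2, 8
mod = torch.randn(batch, 4 * dim)

per_token_parts = mod.chunk(4, dim=1)            # four (batch, dim) tensors
global_parts = mod.unsqueeze(1).chunk(4, dim=2)  # four (batch, 1, dim) tensors
print(per_token_parts[0].shape, global_parts[0].shape)
```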
```diff
@@ -493,112 +533,93 @@ def from_transformer(cls, controlnet, transformer):
    def create_coordinate_grid(size, start=None, device=None):
        if start is None:
            start = (0 for _ in size)
-
        axes = [torch.arange(x0, x0 + span, dtype=torch.int32, device=device) for x0, span in zip(start, size)]
        grids = torch.meshgrid(axes, indexing="ij")
        return torch.stack(grids, dim=-1)

-    # Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformer2DModel.patchify_and_embed
-    def patchify_and_embed(
+    # Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformer2DModel._patchify_image
+    def _patchify_image(self, image: torch.Tensor, patch_size: int, f_patch_size: int):
+        """Patchify a single image tensor: (C, F, H, W) -> (num_patches, patch_dim)."""
+        pH, pW, pF = patch_size, patch_size, f_patch_size
+        C, F, H, W = image.size()
+        F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
+        image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
+        image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
+        return image, (F, H, W), (F_tokens, H_tokens, W_tokens)
+
+    # Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformer2DModel._pad_with_ids
+    def _pad_with_ids(
        self,
-        all_image: List[torch.Tensor],
-        all_cap_feats: List[torch.Tensor],
-        patch_size: int,
-        f_patch_size: int,
+        feat: torch.Tensor,
+        pos_grid_size: Tuple,
+        pos_start: Tuple,
+        device: torch.device,
+        noise_mask_val: Optional[int] = None,
    ):
-        pH = pW = patch_size
-        pF = f_patch_size
-        device = all_image[0].device
-
-        all_image_out = []
-        all_image_size = []
-        all_image_pos_ids = []
-        all_image_pad_mask = []
-        all_cap_pos_ids = []
-        all_cap_pad_mask = []
-        all_cap_feats_out = []
-
-        for i, (image, cap_feat) in enumerate(zip(all_image, all_cap_feats)):
-            ### Process Caption
-            cap_ori_len = len(cap_feat)
-            cap_padding_len = (-cap_ori_len) % SEQ_MULTI_OF
-            # padded position ids
-            cap_padded_pos_ids = self.create_coordinate_grid(
-                size=(cap_ori_len + cap_padding_len, 1, 1),
-                start=(1, 0, 0),
-                device=device,
-            ).flatten(0, 2)
-            all_cap_pos_ids.append(cap_padded_pos_ids)
-            # pad mask
-            cap_pad_mask = torch.cat(
-                [
-                    torch.zeros((cap_ori_len,), dtype=torch.bool, device=device),
-                    torch.ones((cap_padding_len,), dtype=torch.bool, device=device),
-                ],
-                dim=0,
+        """Pad feature to SEQ_MULTI_OF, create position IDs and pad mask."""
+        ori_len = len(feat)
+        pad_len = (-ori_len) % SEQ_MULTI_OF
+        total_len = ori_len + pad_len
+
+        # Pos IDs
+        ori_pos_ids = self.create_coordinate_grid(size=pos_grid_size, start=pos_start, device=device).flatten(0, 2)
+        if pad_len > 0:
+            pad_pos_ids = (
+                self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device)
+                .flatten(0, 2)
+                .repeat(pad_len, 1)
            )
-            all_cap_pad_mask.append(
-                cap_pad_mask if cap_padding_len > 0 else torch.zeros((cap_ori_len,), dtype=torch.bool, device=device)
+            pos_ids = torch.cat([ori_pos_ids, pad_pos_ids], dim=0)
+            padded_feat = torch.cat([feat, feat[-1:].repeat(pad_len, 1)], dim=0)
+            pad_mask = torch.cat(
+                [
+                    torch.zeros(ori_len, dtype=torch.bool, device=device),
+                    torch.ones(pad_len, dtype=torch.bool, device=device),
+                ]
            )
+        else:
+            pos_ids = ori_pos_ids
+            padded_feat = feat
+            pad_mask = torch.zeros(ori_len, dtype=torch.bool, device=device)

-            # padded feature
-            cap_padded_feat = torch.cat([cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)], dim=0)
-            all_cap_feats_out.append(cap_padded_feat)
-
-            ### Process Image
-            C, F, H, W = image.size()
-            all_image_size.append((F, H, W))
-            F_tokens, H_tokens, W_tokens = F // pF, H // pH, W // pW
-
-            image = image.view(C, F_tokens, pF, H_tokens, pH, W_tokens, pW)
-            # "c f pf h ph w pw -> (f h w) (pf ph pw c)"
-            image = image.permute(1, 3, 5, 2, 4, 6, 0).reshape(F_tokens * H_tokens * W_tokens, pF * pH * pW * C)
+        noise_mask = [noise_mask_val] * total_len if noise_mask_val is not None else None  # token level
+        return padded_feat, pos_ids, pad_mask, total_len, noise_mask

-            image_ori_len = len(image)
-            image_padding_len = (-image_ori_len) % SEQ_MULTI_OF
+    # Copied from diffusers.models.transformers.transformer_z_image.ZImageTransformer2DModel.patchify_and_embed
+    def patchify_and_embed(
+        self, all_image: List[torch.Tensor], all_cap_feats: List[torch.Tensor], patch_size: int, f_patch_size: int
+    ):
+        """Patchify for basic mode: single image per batch item."""
+        device = all_image[0].device
+        all_img_out, all_img_size, all_img_pos_ids, all_img_pad_mask = [], [], [], []
+        all_cap_out, all_cap_pos_ids, all_cap_pad_mask = [], [], []

-            image_ori_pos_ids = self.create_coordinate_grid(
-                size=(F_tokens, H_tokens, W_tokens),
-                start=(cap_ori_len + cap_padding_len + 1, 0, 0),
-                device=device,
-            ).flatten(0, 2)
-            image_padded_pos_ids = torch.cat(
-                [
-                    image_ori_pos_ids,
-                    self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device)
-                    .flatten(0, 2)
-                    .repeat(image_padding_len, 1),
-                ],
-                dim=0,
-            )
-            all_image_pos_ids.append(image_padded_pos_ids if image_padding_len > 0 else image_ori_pos_ids)
-            # pad mask
-            image_pad_mask = torch.cat(
-                [
-                    torch.zeros((image_ori_len,), dtype=torch.bool, device=device),
-                    torch.ones((image_padding_len,), dtype=torch.bool, device=device),
-                ],
-                dim=0,
+        for image, cap_feat in zip(all_image, all_cap_feats):
+            # Caption
+            cap_out, cap_pos_ids, cap_pad_mask, cap_len, _ = self._pad_with_ids(
+                cap_feat, (len(cap_feat) + (-len(cap_feat)) % SEQ_MULTI_OF, 1, 1), (1, 0, 0), device
            )
-            all_image_pad_mask.append(
-                image_pad_mask
-                if image_padding_len > 0
-                else torch.zeros((image_ori_len,), dtype=torch.bool, device=device)
-            )
-            # padded feature
-            image_padded_feat = torch.cat(
-                [image, image[-1:].repeat(image_padding_len, 1)],
-                dim=0,
+            all_cap_out.append(cap_out)
+            all_cap_pos_ids.append(cap_pos_ids)
+            all_cap_pad_mask.append(cap_pad_mask)
+
+            # Image
+            img_patches, size, (F_t, H_t, W_t) = self._patchify_image(image, patch_size, f_patch_size)
+            img_out, img_pos_ids, img_pad_mask, _, _ = self._pad_with_ids(
+                img_patches, (F_t, H_t, W_t), (cap_len + 1, 0, 0), device
            )
-            all_image_out.append(image_padded_feat if image_padding_len > 0 else image)
+            all_img_out.append(img_out)
+            all_img_size.append(size)
+            all_img_pos_ids.append(img_pos_ids)
+            all_img_pad_mask.append(img_pad_mask)

        return (
-            all_image_out,
-            all_cap_feats_out,
-            all_image_size,
-            all_image_pos_ids,
+            all_img_out,
+            all_cap_out,
+            all_img_size,
+            all_img_pos_ids,
            all_cap_pos_ids,
-            all_image_pad_mask,
+            all_img_pad_mask,
            all_cap_pad_mask,
        )
```
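The padding rule in `_pad_with_ids` rounds every sequence up to the next multiple of `SEQ_MULTI_OF` via `pad_len = (-ori_len) % SEQ_MULTI_OF`. A quick illustration, with an assumed value for the constant (the real one is defined elsewhere in the module):

```python
SEQ_MULTI_OF = 32  # assumed for illustration only

for ori_len in (32, 33, 100):
    pad_len = (-ori_len) % SEQ_MULTI_OF
    print(ori_len, pad_len, ori_len + pad_len)  # padded length is a multiple of SEQ_MULTI_OF
```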
