Skip to content

Commit 3e019f2

Browse files
committed
support video-to-world
1 parent 0d56c0c commit 3e019f2

File tree

7 files changed

+819
-22
lines changed

7 files changed

+819
-22
lines changed

scripts/convert_cosmos_to_diffusers.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
182182
"adaln_lora_dim": 256,
183183
"max_size": (128, 240, 240),
184184
"patch_size": (1, 2, 2),
185-
"rope_scale": (1.0, 1.0, 1.0),
185+
"rope_scale": (1.0, 4.0, 4.0),
186186
"concat_padding_mask": True,
187187
"extra_pos_embed_type": None,
188188
},
@@ -197,7 +197,7 @@ def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
197197
"adaln_lora_dim": 256,
198198
"max_size": (128, 240, 240),
199199
"patch_size": (1, 2, 2),
200-
"rope_scale": (20 / 24, 2.0, 2.0),
200+
"rope_scale": (1.0, 4.0, 4.0),
201201
"concat_padding_mask": True,
202202
"extra_pos_embed_type": None,
203203
},
@@ -212,7 +212,7 @@ def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
212212
"adaln_lora_dim": 256,
213213
"max_size": (128, 240, 240),
214214
"patch_size": (1, 2, 2),
215-
"rope_scale": (1.0, 1.0, 1.0),
215+
"rope_scale": (1.0, 3.0, 3.0),
216216
"concat_padding_mask": True,
217217
"extra_pos_embed_type": None,
218218
},
@@ -427,7 +427,7 @@ def save_pipeline_cosmos_2_0(args, transformer, vae):
427427
tokenizer = T5TokenizerFast.from_pretrained(args.tokenizer_path)
428428

429429
scheduler = EDMEulerScheduler(
430-
sigma_min=0.0002,
430+
sigma_min=0.002,
431431
sigma_max=80,
432432
sigma_data=1.0,
433433
sigma_schedule="karras",

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -361,6 +361,7 @@
361361
"CogView4ControlPipeline",
362362
"CogView4Pipeline",
363363
"ConsisIDPipeline",
364+
"Cosmos2VideoToWorldPipeline",
364365
"CosmosTextToImagePipeline",
365366
"CosmosTextToWorldPipeline",
366367
"CosmosVideoToWorldPipeline",
@@ -950,6 +951,7 @@
950951
CogView4ControlPipeline,
951952
CogView4Pipeline,
952953
ConsisIDPipeline,
954+
Cosmos2VideoToWorldPipeline,
953955
CosmosTextToImagePipeline,
954956
CosmosTextToWorldPipeline,
955957
CosmosVideoToWorldPipeline,

src/diffusers/models/transformers/transformer_cosmos.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,15 @@ def forward(
100100
embedded_timestep = self.linear_2(embedded_timestep)
101101

102102
if temb is not None:
103-
embedded_timestep = embedded_timestep + temb[:, : 2 * self.embedding_dim]
103+
embedded_timestep = embedded_timestep + temb[..., : 2 * self.embedding_dim]
104104

105-
shift, scale = embedded_timestep.chunk(2, dim=1)
105+
shift, scale = embedded_timestep.chunk(2, dim=-1)
106106
hidden_states = self.norm(hidden_states)
107-
hidden_states = hidden_states * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
107+
108+
if embedded_timestep.ndim == 2:
109+
shift, scale = (x.unsqueeze(1) for x in (shift, scale))
110+
111+
hidden_states = hidden_states * (1 + scale) + shift
108112
return hidden_states
109113

110114

@@ -135,9 +139,13 @@ def forward(
135139
if temb is not None:
136140
embedded_timestep = embedded_timestep + temb
137141

138-
shift, scale, gate = embedded_timestep.chunk(3, dim=1)
142+
shift, scale, gate = embedded_timestep.chunk(3, dim=-1)
139143
hidden_states = self.norm(hidden_states)
140-
hidden_states = hidden_states * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
144+
145+
if embedded_timestep.ndim == 2:
146+
shift, scale, gate = (x.unsqueeze(1) for x in (shift, scale, gate))
147+
148+
hidden_states = hidden_states * (1 + scale) + shift
141149
return hidden_states, gate
142150

143151

@@ -255,19 +263,19 @@ def forward(
255263
# 1. Self Attention
256264
norm_hidden_states, gate = self.norm1(hidden_states, embedded_timestep, temb)
257265
attn_output = self.attn1(norm_hidden_states, image_rotary_emb=image_rotary_emb)
258-
hidden_states = hidden_states + gate.unsqueeze(1) * attn_output
266+
hidden_states = hidden_states + gate * attn_output
259267

260268
# 2. Cross Attention
261269
norm_hidden_states, gate = self.norm2(hidden_states, embedded_timestep, temb)
262270
attn_output = self.attn2(
263271
norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
264272
)
265-
hidden_states = hidden_states + gate.unsqueeze(1) * attn_output
273+
hidden_states = hidden_states + gate * attn_output
266274

267275
# 3. Feed Forward
268276
norm_hidden_states, gate = self.norm3(hidden_states, embedded_timestep, temb)
269277
ff_output = self.ff(norm_hidden_states)
270-
hidden_states = hidden_states + gate.unsqueeze(1) * ff_output
278+
hidden_states = hidden_states + gate * ff_output
271279

272280
return hidden_states
273281

@@ -513,7 +521,23 @@ def forward(
513521
hidden_states = hidden_states.flatten(1, 3) # [B, T, H, W, C] -> [B, THW, C]
514522

515523
# 4. Timestep embeddings
516-
temb, embedded_timestep = self.time_embed(hidden_states, timestep)
524+
if timestep.ndim == 1:
525+
temb, embedded_timestep = self.time_embed(hidden_states, timestep)
526+
elif timestep.ndim == 5:
527+
assert timestep.shape == (batch_size, 1, num_frames, 1, 1), (
528+
f"Expected timestep to have shape [B, 1, T, 1, 1], but got {timestep.shape}"
529+
)
530+
timestep = timestep.flatten()
531+
temb, embedded_timestep = self.time_embed(hidden_states, timestep)
532+
# We can do this because num_frames == post_patch_num_frames, as p_t is 1
533+
temb, embedded_timestep = (
534+
x.view(batch_size, post_patch_num_frames, 1, 1, -1)
535+
.expand(-1, -1, post_patch_height, post_patch_width, -1)
536+
.flatten(1, 3)
537+
for x in (temb, embedded_timestep)
538+
) # [BT, C] -> [B, T, 1, 1, C] -> [B, T, H, W, C] -> [B, THW, C]
539+
else:
540+
assert False
517541

518542
# 5. Transformer blocks
519543
for block in self.transformer_blocks:

src/diffusers/pipelines/__init__.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,7 @@
161161
"CosmosTextToImagePipeline",
162162
"CosmosTextToWorldPipeline",
163163
"CosmosVideoToWorldPipeline",
164+
"Cosmos2VideoToWorldPipeline",
164165
]
165166
_import_structure["controlnet"].extend(
166167
[
@@ -563,7 +564,12 @@
563564
StableDiffusionControlNetXSPipeline,
564565
StableDiffusionXLControlNetXSPipeline,
565566
)
566-
from .cosmos import CosmosTextToImagePipeline, CosmosTextToWorldPipeline, CosmosVideoToWorldPipeline
567+
from .cosmos import (
568+
Cosmos2VideoToWorldPipeline,
569+
CosmosTextToImagePipeline,
570+
CosmosTextToWorldPipeline,
571+
CosmosVideoToWorldPipeline,
572+
)
567573
from .deepfloyd_if import (
568574
IFImg2ImgPipeline,
569575
IFImg2ImgSuperResolutionPipeline,

src/diffusers/pipelines/cosmos/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
2424
else:
25+
_import_structure["pipeline_cosmos2_video2world"] = ["Cosmos2VideoToWorldPipeline"]
2526
_import_structure["pipeline_cosmos_text2image"] = ["CosmosTextToImagePipeline"]
2627
_import_structure["pipeline_cosmos_text2world"] = ["CosmosTextToWorldPipeline"]
2728
_import_structure["pipeline_cosmos_video2world"] = ["CosmosVideoToWorldPipeline"]
@@ -34,6 +35,7 @@
3435
except OptionalDependencyNotAvailable:
3536
from ...utils.dummy_torch_and_transformers_objects import *
3637
else:
38+
from .pipeline_cosmos2_video2world import Cosmos2VideoToWorldPipeline
3739
from .pipeline_cosmos_text2image import CosmosTextToImagePipeline
3840
from .pipeline_cosmos_text2world import CosmosTextToWorldPipeline
3941
from .pipeline_cosmos_video2world import CosmosVideoToWorldPipeline

0 commit comments

Comments
 (0)