Skip to content

Commit 88faab1

Browse files
committed
add conversion script
1 parent 969dd17 commit 88faab1

File tree

5 files changed

+24
-4
lines changed

5 files changed

+24
-4
lines changed

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@
106106
"ControlNetModel",
107107
"ControlNetUnionModel",
108108
"ControlNetXSAdapter",
109+
"CosmosTransformer3DModel",
109110
"DiTTransformer2DModel",
110111
"FluxControlNetModel",
111112
"FluxMultiControlNetModel",
@@ -620,6 +621,7 @@
620621
ControlNetModel,
621622
ControlNetUnionModel,
622623
ControlNetXSAdapter,
624+
CosmosTransformer3DModel,
623625
DiTTransformer2DModel,
624626
FluxControlNetModel,
625627
FluxMultiControlNetModel,

src/diffusers/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
7070
_import_structure["transformers.transformer_allegro"] = ["AllegroTransformer3DModel"]
7171
_import_structure["transformers.transformer_cogview3plus"] = ["CogView3PlusTransformer2DModel"]
72+
_import_structure["transformers.transformer_cosmos"] = ["CosmosTransformer3DModel"]
7273
_import_structure["transformers.transformer_flux"] = ["FluxTransformer2DModel"]
7374
_import_structure["transformers.transformer_hunyuan_video"] = ["HunyuanVideoTransformer3DModel"]
7475
_import_structure["transformers.transformer_ltx"] = ["LTXVideoTransformer3DModel"]
@@ -133,6 +134,7 @@
133134
CogVideoXTransformer3DModel,
134135
CogView3PlusTransformer2DModel,
135136
ConsisIDTransformer3DModel,
137+
CosmosTransformer3DModel,
136138
DiTTransformer2DModel,
137139
DualTransformer2DModel,
138140
FluxTransformer2DModel,

src/diffusers/models/transformers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from .transformer_2d import Transformer2DModel
1919
from .transformer_allegro import AllegroTransformer3DModel
2020
from .transformer_cogview3plus import CogView3PlusTransformer2DModel
21+
from .transformer_cosmos import CosmosTransformer3DModel
2122
from .transformer_flux import FluxTransformer2DModel
2223
from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
2324
from .transformer_ltx import LTXVideoTransformer3DModel

src/diffusers/models/transformers/transformer_cosmos.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
356356
return (emb / norm).type_as(hidden_states)
357357

358358

359-
class CosmosTransformer(ModelMixin, ConfigMixin):
359+
class CosmosTransformer3DModel(ModelMixin, ConfigMixin):
360360
r"""
361361
A Transformer model for video-like data used in [Cosmos](https://github.com/NVIDIA/Cosmos).
362362
@@ -423,9 +423,9 @@ def __init__(
423423
hidden_size=attention_head_dim, max_size=max_size, patch_size=patch_size, rope_scale=rope_scale
424424
)
425425

426-
self.learnable_pos_embedder = None
426+
self.learnable_pos_embed = None
427427
if extra_pos_embed_type == "learnable":
428-
self.learnable_pos_embedder = CosmosLearnablePositionalEmbed(
428+
self.learnable_pos_embed = CosmosLearnablePositionalEmbed(
429429
hidden_size=hidden_size,
430430
max_size=max_size,
431431
patch_size=patch_size,
@@ -477,7 +477,7 @@ def forward(
477477

478478
# 2. Generate positional embeddings
479479
image_rotary_emb = self.rope(hidden_states, fps=fps)
480-
extra_pos_emb = self.learnable_pos_embedder(hidden_states) if self.config.extra_pos_embed_type else None
480+
extra_pos_emb = self.learnable_pos_embed(hidden_states) if self.config.extra_pos_embed_type else None
481481

482482
# 3. Patchify input
483483
batch_size, num_channels, num_frames, height, width = hidden_states.shape

src/diffusers/utils/dummy_pt_objects.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,21 @@ def from_pretrained(cls, *args, **kwargs):
351351
requires_backends(cls, ["torch"])
352352

353353

354+
class CosmosTransformer3DModel(metaclass=DummyObject):
355+
_backends = ["torch"]
356+
357+
def __init__(self, *args, **kwargs):
358+
requires_backends(self, ["torch"])
359+
360+
@classmethod
361+
def from_config(cls, *args, **kwargs):
362+
requires_backends(cls, ["torch"])
363+
364+
@classmethod
365+
def from_pretrained(cls, *args, **kwargs):
366+
requires_backends(cls, ["torch"])
367+
368+
354369
class DiTTransformer2DModel(metaclass=DummyObject):
355370
_backends = ["torch"]
356371

0 commit comments

Comments (0)