Skip to content

Commit a592f74

Browse files
committed
update
1 parent 64fc4fe commit a592f74

File tree

6 files changed

+661
-11
lines changed

6 files changed

+661
-11
lines changed

src/diffusers/pipelines/cosmos/pipeline_cosmos.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@
4747
4848
>>> model_id = "nvidia/Cosmos-1.0-Diffusion-7B-Text2World"
4949
>>> pipe = CosmosPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
50-
>>> pipe.vae.enable_tiling()
5150
>>> pipe.to("cuda")
5251
5352
>>> prompt = "A sleek, humanoid robot stands in a vast warehouse filled with neatly stacked cardboard boxes on industrial shelves. The robot's metallic body gleams under the bright, even lighting, highlighting its futuristic design and intricate joints. A glowing blue light emanates from its chest, adding a touch of advanced technology. The background is dominated by rows of boxes, suggesting a highly organized storage system. The floor is lined with wooden pallets, enhancing the industrial setting. The camera remains static, capturing the robot's poised stance amidst the orderly environment, with a shallow depth of field that keeps the focus on the robot while subtly blurring the background for a cinematic effect."
@@ -540,6 +539,8 @@ def __call__(
540539
padding_mask=padding_mask,
541540
return_dict=False,
542541
)[0]
542+
543+
sample = latents
543544
if self.do_classifier_free_guidance:
544545
noise_pred_uncond = self.transformer(
545546
hidden_states=latent_model_input,
@@ -550,9 +551,10 @@ def __call__(
550551
return_dict=False,
551552
)[0]
552553
noise_pred = torch.cat([noise_pred_uncond, noise_pred])
554+
sample = torch.cat([sample, sample])
553555

554556
# pred_original_sample (x0)
555-
noise_pred = self.scheduler.step(noise_pred, t, latents, return_dict=False)[1]
557+
noise_pred = self.scheduler.step(noise_pred, t, sample, return_dict=False)[1]
556558
self.scheduler._step_index -= 1
557559

558560
if self.do_classifier_free_guidance:

src/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,20 +41,47 @@
4141

4242
EXAMPLE_DOC_STRING = """
4343
Examples:
44+
Image conditioning:
45+
46+
```python
47+
>>> import torch
48+
>>> from diffusers import CosmosVideoToWorldPipeline
49+
>>> from diffusers.utils import export_to_video, load_image
50+
51+
>>> model_id = "nvidia/Cosmos-1.0-Diffusion-7B-Video2World"
52+
>>> pipe = CosmosVideoToWorldPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
53+
>>> pipe.to("cuda")
54+
55+
>>> prompt = "The video depicts a long, straight highway stretching into the distance, flanked by metal guardrails. The road is divided into multiple lanes, with a few vehicles visible in the far distance. The surrounding landscape features dry, grassy fields on one side and rolling hills on the other. The sky is mostly clear with a few scattered clouds, suggesting a bright, sunny day."
56+
>>> image = load_image(
57+
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input.jpg"
58+
... )
59+
60+
>>> video = pipe(image=image, prompt=prompt).frames[0]
61+
>>> export_to_video(video, "output.mp4", fps=30)
62+
```
63+
64+
Video conditioning:
65+
4466
```python
4567
>>> import torch
46-
>>> from diffusers import CosmosPipeline
47-
>>> from diffusers.utils import export_to_video
68+
>>> from diffusers import CosmosVideoToWorldPipeline
69+
>>> from diffusers.utils import export_to_video, load_video
4870
49-
>>> model_id = "nvidia/Cosmos-1.0-Diffusion-7B-Text2World"
50-
>>> pipe = CosmosPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
51-
>>> pipe.vae.enable_tiling()
71+
>>> model_id = "nvidia/Cosmos-1.0-Diffusion-7B-Video2World"
72+
>>> pipe = CosmosVideoToWorldPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
73+
>>> pipe.transformer = torch.compile(pipe.transformer)
5274
>>> pipe.to("cuda")
5375
54-
>>> prompt = "A sleek, humanoid robot stands in a vast warehouse filled with neatly stacked cardboard boxes on industrial shelves. The robot's metallic body gleams under the bright, even lighting, highlighting its futuristic design and intricate joints. A glowing blue light emanates from its chest, adding a touch of advanced technology. The background is dominated by rows of boxes, suggesting a highly organized storage system. The floor is lined with wooden pallets, enhancing the industrial setting. The camera remains static, capturing the robot's poised stance amidst the orderly environment, with a shallow depth of field that keeps the focus on the robot while subtly blurring the background for a cinematic effect."
76+
>>> prompt = "The video depicts a winding mountain road covered in snow, with a single vehicle traveling along it. The road is flanked by steep, rocky cliffs and sparse vegetation. The landscape is characterized by rugged terrain and a river visible in the distance. The scene captures the solitude and beauty of a winter drive through a mountainous region."
77+
>>> video = load_video(
78+
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/cosmos/cosmos-video2world-input-vid.mp4"
79+
... )[
80+
... :21
81+
... ] # This example uses only the first 21 frames
5582
56-
>>> output = pipe(prompt=prompt).frames[0]
57-
>>> export_to_video(output, "output.mp4", fps=30)
83+
>>> video = pipe(video=video, prompt=prompt).frames[0]
84+
>>> export_to_video(video, "output.mp4", fps=30)
5885
```
5986
"""
6087

@@ -654,6 +681,7 @@ def __call__(
654681
return_dict=False,
655682
)[0]
656683

684+
sample = latents
657685
if self.do_classifier_free_guidance:
658686
current_uncond_indicator = uncond_indicator * 0 if is_augment_sigma_greater else uncond_indicator
659687
uncond_noise = randn_tensor(latents.shape, generator=generator, device=device, dtype=torch.float32)
@@ -673,9 +701,10 @@ def __call__(
673701
return_dict=False,
674702
)[0]
675703
noise_pred = torch.cat([noise_pred_uncond, noise_pred])
704+
sample = torch.cat([sample, sample])
676705

677706
# pred_original_sample (x0)
678-
noise_pred = self.scheduler.step(noise_pred, t, latents, return_dict=False)[1]
707+
noise_pred = self.scheduler.step(noise_pred, t, sample, return_dict=False)[1]
679708
self.scheduler._step_index -= 1
680709

681710
if self.do_classifier_free_guidance:

tests/models/transformers/test_models_transformer_cosmos.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,68 @@ def prepare_init_args_and_inputs_for_common(self):
8686
def test_gradient_checkpointing_is_applied(self):
8787
expected_set = {"CosmosTransformer3DModel"}
8888
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
89+
90+
91+
class CosmosTransformer3DModelVideoToWorldTests(ModelTesterMixin, unittest.TestCase):
92+
model_class = CosmosTransformer3DModel
93+
main_input_name = "hidden_states"
94+
uses_custom_attn_processor = True
95+
96+
@property
97+
def dummy_input(self):
98+
batch_size = 1
99+
num_channels = 4
100+
num_frames = 1
101+
height = 16
102+
width = 16
103+
text_embed_dim = 16
104+
sequence_length = 12
105+
fps = 30
106+
107+
hidden_states = torch.randn((batch_size, num_channels, num_frames, height, width)).to(torch_device)
108+
timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)
109+
encoder_hidden_states = torch.randn((batch_size, sequence_length, text_embed_dim)).to(torch_device)
110+
attention_mask = torch.ones((batch_size, sequence_length)).to(torch_device)
111+
condition_mask = torch.ones(batch_size, 1, num_frames, height, width).to(torch_device)
112+
padding_mask = torch.zeros(batch_size, 1, height, width).to(torch_device)
113+
114+
return {
115+
"hidden_states": hidden_states,
116+
"timestep": timestep,
117+
"encoder_hidden_states": encoder_hidden_states,
118+
"attention_mask": attention_mask,
119+
"fps": fps,
120+
"condition_mask": condition_mask,
121+
"padding_mask": padding_mask,
122+
}
123+
124+
@property
125+
def input_shape(self):
126+
return (4, 1, 16, 16)
127+
128+
@property
129+
def output_shape(self):
130+
return (4, 1, 16, 16)
131+
132+
def prepare_init_args_and_inputs_for_common(self):
133+
init_dict = {
134+
"in_channels": 4 + 1,
135+
"out_channels": 4,
136+
"num_attention_heads": 2,
137+
"attention_head_dim": 12,
138+
"num_layers": 2,
139+
"mlp_ratio": 2,
140+
"text_embed_dim": 16,
141+
"adaln_lora_dim": 4,
142+
"max_size": (4, 32, 32),
143+
"patch_size": (1, 2, 2),
144+
"rope_scale": (2.0, 1.0, 1.0),
145+
"concat_padding_mask": True,
146+
"extra_pos_embed_type": "learnable",
147+
}
148+
inputs_dict = self.dummy_input
149+
return init_dict, inputs_dict
150+
151+
def test_gradient_checkpointing_is_applied(self):
152+
expected_set = {"CosmosTransformer3DModel"}
153+
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

tests/pipelines/cosmos/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)